DataLoader

class multimodal_compare.models.dataloader.DataModule(config)

Bases: LightningDataModule

check_load_testdata()
check_testdata_avail()
collate_fn(batch)

Custom collate function that puts data in a dictionary and prepares masks if needed

Parameters:

batch (list) – input batch

Returns:

dictionary with data batch

Return type:

dict
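The general pattern of a dictionary-returning collate function can be sketched as below. This is not the implementation used by DataModule: the per-sample layout (a tuple of tensors, one per modality), the "mod_1"/"mod_2" keys, the nested "data"/"masks" entries, and the name dict_collate are all assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader


def dict_collate(batch):
    """Stack each modality across samples and return a dict keyed by modality.

    Assumes every sample is a tuple of fixed-size tensors, one per modality.
    """
    per_modality = zip(*batch)  # regroup: one tuple of tensors per modality
    return {
        f"mod_{i + 1}": {"data": torch.stack(tensors), "masks": None}
        for i, tensors in enumerate(per_modality)
    }


# Hypothetical toy data: 4 samples, each an image-like tensor and a vector
samples = [(torch.zeros(3, 8, 8), torch.zeros(5)) for _ in range(4)]
loader = DataLoader(samples, batch_size=2, collate_fn=dict_collate)
first = next(iter(loader))
```

Here the masks are left as None because every tensor has the same size; sequential data of varying length would need padding and masks instead.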

get_dataset_class()

Get the dataset class object according to the dataset name

Returns:

dataset class object

Return type:

object
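One common way to resolve a class from a dataset name is a name-to-class registry. The sketch below only illustrates that idea; the class names, the registry, and lookup_dataset_class are hypothetical and are not taken from multimodal_compare.

```python
class ImageDataset:
    """Hypothetical dataset class, for illustration only."""


class TextDataset:
    """Hypothetical dataset class, for illustration only."""


# Hypothetical registry mapping dataset names (as they might appear in a config) to classes
DATASET_REGISTRY = {"images": ImageDataset, "text": TextDataset}


def lookup_dataset_class(name):
    """Return the class registered under `name` (case-insensitive)."""
    try:
        return DATASET_REGISTRY[name.lower()]
    except KeyError:
        known = sorted(DATASET_REGISTRY)
        raise ValueError(f"Unknown dataset name {name!r}; expected one of {known}") from None


dataset_cls = lookup_dataset_class("Images")
```

Note that the lookup returns the class itself, not an instance, so the caller decides when and with which arguments to instantiate it.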

get_label_for_indices(indices, split)

Get labels for the given data split according to indices

get_labels(split='all')

Return the data labels for the given split, if available

Parameters:

split (str) – "all", "train", "val" or "test", selecting the data split

Returns:

list of labels for the given split

Return type:

list

get_num_samples(num_samples, split='val')

Return a batch from predict_dataloader for the given split, together with the indices of its samples

make_masks(batch)

Makes masks for sequential data

Parameters:

batch (torch.Tensor) – data batch

Returns:

dictionary with data and masks

Return type:

dict
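For variable-length sequences, masks typically mark which time steps hold real data and which are padding. The following is a minimal sketch of that idea, assuming the input is a list of tensors of shape (length, features). The function name make_padding_masks and the "data"/"masks" keys are illustrative assumptions and may differ from what make_masks actually does.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def make_padding_masks(sequences):
    """Pad sequences to a common length and build a boolean validity mask.

    sequences: list of tensors, each of shape (length_i, features).
    Returns a dict with padded data (batch, max_len, features) and
    masks (batch, max_len), True where the time step is real data.
    """
    lengths = torch.tensor([len(seq) for seq in sequences])
    padded = pad_sequence(sequences, batch_first=True)  # zero-padded
    positions = torch.arange(padded.shape[1])
    masks = positions[None, :] < lengths[:, None]  # broadcast to (batch, max_len)
    return {"data": padded, "masks": masks}


# Two toy sequences of lengths 3 and 1, each with 2 features per step
out = make_padding_masks([torch.ones(3, 2), torch.ones(1, 2)])
```

A mask of this form can be used to ignore padded positions, for example when averaging over time steps or computing a reconstruction loss.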

predict_dataloader(batch_size, split='val') → DataLoader

Return a DataLoader for the given split (validation by default) with a custom batch size

prepare_data_classes()
prepare_singlemodal(batch, mod_index)

Prepares singlemodal data for given modality

Parameters:
  • batch (list) – input batch

  • mod_index (int) – index of the modality (following the order in the config)

Returns:

prepared data for training

Return type:

dict

setup(stage: str | None = None) → None

Loads appropriate dataset classes and makes data splits
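How the splits are made is not documented here. As a rough illustration only, a reproducible train/validation split in PyTorch can be sketched with random_split; the 90/10 ratio, the seed, and the toy dataset below are assumptions, not values taken from DataModule.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 100 one-dimensional samples (stand-in for a real dataset class)
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))

n_val = len(dataset) // 10       # assumed 10 % validation share
n_train = len(dataset) - n_val

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)
```

The resulting subsets can then be wrapped in DataLoader objects, which is the role the train_dataloader and val_dataloader methods play in a LightningDataModule.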

test_dataloader()

Return Test DataLoader

train_dataloader() → DataLoader

Return Train DataLoader

val_dataloader() → DataLoader

Return Val DataLoader