DataLoader
- class multimodal_compare.models.dataloader.DataModule(config)
Bases:
LightningDataModule
- check_load_testdata()
- check_testdata_avail()
- collate_fn(batch)
Custom collate function that puts data in a dictionary and prepares masks if needed
- Parameters:
batch (list) – input batch
- Returns:
dictionary with data batch
- Return type:
dict
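The sketch below shows the dictionary layout this describes by grouping a batch of multimodal samples by modality. It makes several assumptions:

- Each sample is a tuple with one entry per modality, in config order.
- Modality names come from a hypothetical `modality_names` list.
- The function name `collate_to_dict` and the `{"data": ..., "masks": None}` layout are illustrative only, not the library's actual structure.

It uses plain lists so it runs without PyTorch. Masks are left as `None` here.

```python
from typing import Any, Dict, List, Sequence


def collate_to_dict(
    batch: List[Sequence[Any]], modality_names: List[str]
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical collate: regroup per-sample tuples into per-modality lists."""
    if not batch:
        return {}
    if any(len(sample) != len(modality_names) for sample in batch):
        raise ValueError("each sample must have exactly one entry per modality")
    # zip(*batch) transposes [(m1, m2), (m1, m2), ...] into [(m1, m1, ...), (m2, m2, ...)]
    per_modality = zip(*batch)
    return {
        # masks stay None here; a mask-building step would fill them for sequential data
        name: {"data": list(values), "masks": None}
        for name, values in zip(modality_names, per_modality)
    }


batch = [("img_0", [1, 2]), ("img_1", [3])]
print(collate_to_dict(batch, ["image", "text"]))
# {'image': {'data': ['img_0', 'img_1'], 'masks': None}, 'text': {'data': [[1, 2], [3]], 'masks': None}}
```

With a real `torch.utils.data.DataLoader`, a function like this is passed as the `collate_fn` argument. The per-modality lists would then typically be stacked into tensors, for example with `torch.stack`.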
- get_dataset_class()
Get the dataset class object according to the dataset name
- Returns:
dataset class object
- Return type:
object
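A name-to-class lookup like this is commonly implemented with a registry dictionary. The sketch below assumes the dataset name is a string read from the config. The registry, the `Toy*Dataset` classes, and the lowercase normalization are all hypothetical and are not the library's actual dataset names.

```python
from typing import Dict, Type


class ToyImageDataset:  # hypothetical dataset class
    def __init__(self, path: str) -> None:
        self.path = path


class ToyTextDataset:  # hypothetical dataset class
    def __init__(self, path: str) -> None:
        self.path = path


# Hypothetical registry mapping config dataset names to classes
DATASET_REGISTRY: Dict[str, Type] = {
    "toy_image": ToyImageDataset,
    "toy_text": ToyTextDataset,
}


def get_dataset_class(name: str) -> Type:
    """Look up a dataset class by name; fail loudly on unknown names."""
    try:
        return DATASET_REGISTRY[name.lower()]
    except KeyError:
        known = ", ".join(sorted(DATASET_REGISTRY))
        raise ValueError(f"unknown dataset {name!r}; expected one of: {known}") from None


cls = get_dataset_class("toy_image")
dataset = cls("data/")  # the class object is returned; the caller instantiates it
```

The lookup returns the class rather than an instance, which matches the documented return value ("dataset class object"). The caller can then construct one instance per split.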
- get_label_for_indices(indices, split)
Get labels for the given data split according to indices
- get_labels(split='all')
Return the data labels for the given split, if available
- Parameters:
split (str) – “all”, “train”, “val” or “test”, depending on the data split
- Returns:
list of labels for the given split
- Return type:
list
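The two label getters could relate to each other as in the following sketch. It assumes labels are stored once for the whole dataset and each split keeps a list of dataset-level indices. `LabelStore`, `split_indices`, and the convention that `indices` are positions within the split are hypothetical.

```python
from typing import Dict, List, Optional, Sequence


class LabelStore:
    """Hypothetical helper: labels for the full dataset plus per-split index lists."""

    def __init__(self, labels: Optional[List[int]], split_indices: Dict[str, List[int]]):
        self.labels = labels  # None when the dataset has no labels
        self.split_indices = split_indices  # e.g. {"train": [...], "val": [...]}

    def get_labels(self, split: str = "all") -> Optional[List[int]]:
        if self.labels is None:
            return None  # labels not available
        if split == "all":
            return list(self.labels)
        return [self.labels[i] for i in self.split_indices[split]]

    def get_label_for_indices(self, indices: Sequence[int], split: str) -> Optional[List[int]]:
        # `indices` are positions within the split, not dataset-level indices
        split_labels = self.get_labels(split)
        if split_labels is None:
            return None
        return [split_labels[i] for i in indices]


store = LabelStore(labels=[0, 1, 1, 0, 2], split_indices={"train": [0, 2, 4], "val": [1, 3]})
print(store.get_labels("val"))                    # [1, 0]
print(store.get_label_for_indices([2], "train"))  # [2]
```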
- get_num_samples(num_samples, split='val')
Return a batch from predict_dataloader together with the indices of its samples
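One reading of this method is that it requests a batch of `num_samples` items for a split and also returns where those items sit in the split. The indices would, for example, allow looking the samples up with `get_label_for_indices`. The sketch below assumes the predict loader is not shuffled, so its first batch covers positions `0..num_samples-1`. It emulates that with a hypothetical `first_batch_with_indices` helper.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def first_batch_with_indices(split_data: Sequence[T], num_samples: int) -> Tuple[List[T], List[int]]:
    """Hypothetical: emulate the first batch of an unshuffled loader with batch_size=num_samples."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    n = min(num_samples, len(split_data))  # a split shorter than num_samples yields a smaller batch
    indices = list(range(n))  # without shuffling, the first batch covers positions 0..n-1
    return [split_data[i] for i in indices], indices


batch, idx = first_batch_with_indices(["a", "b", "c", "d"], num_samples=3)
# batch == ["a", "b", "c"], idx == [0, 1, 2]
```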
- make_masks(batch)
Makes masks for sequential data
- Parameters:
batch (torch.Tensor) – data batch
- Returns:
dictionary with data and masks
- Return type:
dict
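For variable-length (sequential) data, masks usually mark which positions hold real elements and which are padding. The sketch below makes two assumptions: sequences are right-padded with a constant, and in the boolean masks `True` means "real element". The convention the library actually uses (for example inverted masks, or masks stored as tensors) is not documented here. `pad_and_mask` is a hypothetical name.

```python
from typing import Dict, List, Sequence


def pad_and_mask(sequences: Sequence[Sequence[float]], pad_value: float = 0.0) -> Dict[str, List[list]]:
    """Hypothetical: right-pad sequences to equal length and build boolean masks."""
    max_len = max((len(s) for s in sequences), default=0)
    data, masks = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        data.append(list(seq) + [pad_value] * n_pad)
        masks.append([True] * len(seq) + [False] * n_pad)  # True marks real elements
    return {"data": data, "masks": masks}


out = pad_and_mask([[1.0, 2.0, 3.0], [4.0]])
# out["data"]  == [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]]
# out["masks"] == [[True, True, True], [True, False, False]]
```

With PyTorch tensors, `torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)` performs the padding step. A mask can then be derived from the original sequence lengths.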
- predict_dataloader(batch_size, split='val') → DataLoader
Return a DataLoader for the given split (validation by default) with a custom batch size
- prepare_data_classes()
- prepare_singlemodal(batch, mod_index)
Prepares unimodal data for the given modality
- Parameters:
batch (list) – input batch
mod_index (int) – index of the modality (following the order in the config)
- Returns:
prepared data for training
- Return type:
dict
- setup(stage: str | None = None) → None
Loads appropriate dataset classes and makes data splits
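The split step could look like the following sketch. It assumes a single pool of samples divided into train and validation index lists with a fixed seed. `make_splits`, `val_fraction`, and the seed handling are hypothetical, and the library may instead rely on predefined splits shipped with each dataset.

```python
import random
from typing import Dict, List


def make_splits(n_samples: int, val_fraction: float = 0.1, seed: int = 0) -> Dict[str, List[int]]:
    """Hypothetical: shuffle sample indices reproducibly and split them into train/val."""
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # a local RNG leaves the global random state untouched
    n_val = int(round(n_samples * val_fraction))
    return {"train": sorted(indices[n_val:]), "val": sorted(indices[:n_val])}


splits = make_splits(100, val_fraction=0.2, seed=42)
```

Storing index lists rather than copies of the data keeps the splits cheap. Labels can also be looked up per split by index.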
- test_dataloader()
Return Test DataLoader
- train_dataloader() → DataLoader
Return Train DataLoader
- val_dataloader() → DataLoader
Return Val DataLoader