Dataset Classes
- class multimodal_compare.models.datasets.BaseDataset(pth, testpth, mod_type)
Bases:
object
Abstract dataset class shared by all datasets
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess(output_data)
Postprocesses the output data according to modality type
- Returns:
postprocessed data
- Return type:
list
- _postprocess_all2img(data)
Converts any kind of data to images to save traversal visualizations
- Parameters:
data (torch.tensor) – input data
- Returns:
processed images
- Return type:
torch.tensor
- _preprocess()
Preprocesses the loaded data according to modality type
- Returns:
preprocessed data
- Return type:
list
- _preprocess_images(dimensions)
General function for loading images and preparing them as torch tensors
- Parameters:
dimensions (list) – feature_dim for the image modality
- Returns:
preprocessed data
- Return type:
torch.tensor
- _preprocess_text_onehot()
General function for loading text strings and preparing them as torch one-hot encodings
- Returns:
tensor with text encodings and masks
- Return type:
torch.tensor
- current_datatype()
Returns whether the current data path points to test data or train data
- eval_statistics_fn()
(optional) Returns a dataset-specific function that runs systematic evaluation
- get_data()
Returns processed data
- Returns:
processed data
- Return type:
list
- get_data_raw()
Loads raw data from path
- Returns:
loaded raw data
- Return type:
list
- get_labels(split='train')
Returns labels for the given split: train or test
- get_processed_recons(recons_raw)
Returns the postprocessed data that came from the decoders
- Parameters:
recons_raw (torch.tensor) – tensor with output reconstructions
- Returns:
postprocessed data as returned by the specific _postprocess function
- Return type:
list
- get_test_data()
Returns processed test data if available
- Returns:
processed data
- Return type:
list
- labels()
Returns labels for the whole dataset
- save_traversals(recons, path, num_dims)
Makes a grid of traversals and saves it as an image
- Parameters:
recons (torch.tensor) – data to save
path (str) – path to save the traversal to
num_dims (int) – number of latent dimensions
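A minimal usage sketch of the shared interface, assuming a concrete subclass such as CDSPRITESPLUS (documented below); the paths and modality name are placeholders for a local dataset layout:

```python
from multimodal_compare.models.datasets import CDSPRITESPLUS

# Placeholder paths and modality name; adjust to the local dataset layout.
dataset = CDSPRITESPLUS(pth="./data/cdsprites/train",
                        testpth="./data/cdsprites/test",
                        mod_type="image")

train_data = dataset.get_data()      # preprocessed training data
test_data = dataset.get_test_data()  # preprocessed test data (if testpth is set)
labels = dataset.labels()            # dataset-wide labels, where available
```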
- class multimodal_compare.models.datasets.CDSPRITESPLUS(pth, testpth, mod_type)
Bases:
BaseDataset
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_images(data)
- _postprocess_text(data)
- _preprocess_images()
General function for loading images and preparing them as torch tensors
- Returns:
preprocessed data
- Return type:
torch.tensor
- _preprocess_text()
- eval_statistics_fn()
(optional) Returns a dataset-specific function that runs systematic evaluation
- feature_dims = {'image': [64, 64, 3], 'text': [45, 27, 1]}
- labels()
Extracts text labels based on the dataset level
- Returns:
list of labels as strings
- Return type:
list
- save_recons(data, recons, path, mod_names)
- set_vis_image_shape()
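The paired _mod_specific_loaders()/_mod_specific_savers() methods suggest a dictionary dispatch from mod_type to the matching pre- and postprocessing functions. A hedged sketch of that pattern (the keys mirror feature_dims; the class and method bodies are illustrative, not the package's actual code):

```python
class ModDispatchSketch:
    """Illustrative mod_type dispatch; not the package's actual implementation."""

    feature_dims = {"image": [64, 64, 3], "text": [45, 27, 1]}

    def __init__(self, mod_type):
        assert mod_type in self.feature_dims, f"unknown modality: {mod_type}"
        self.mod_type = mod_type

    def _mod_specific_loaders(self):
        # Map each modality name to its preprocessing function.
        return {"image": self._preprocess_images,
                "text": self._preprocess_text}[self.mod_type]

    def _preprocess_images(self):
        return []  # placeholder

    def _preprocess_text(self):
        return []  # placeholder
```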
- class multimodal_compare.models.datasets.CELEBA(pth, testpth, mod_type)
Bases:
BaseDataset
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_all2img(data)
Converts any kind of data to images to save traversal visualizations
- Parameters:
data (torch.tensor) – input data
- Returns:
processed images
- Return type:
torch.tensor
- _postprocess_atts(data)
- _postprocess_images(data)
- _preprocess_atts()
- _preprocess_images()
General function for loading images and preparing them as torch tensors
- Returns:
preprocessed data
- Return type:
torch.tensor
- feature_dims = {'atts': [4, 2], 'image': [64, 64, 3]}
- save_recons(data, recons, path, mod_names)
- save_traversals(recons, path, num_dims)
Makes a grid of traversals and saves it as an image
- Parameters:
recons (torch.tensor) – data to save
path (str) – path to save the traversal to
num_dims (int) – number of latent dimensions
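save_traversals arranges the decoded traversals into an image grid before writing them to disk. A hedged sketch of that kind of grid-saving step using torchvision; the tensor shape, nrow choice, and filename are illustrative, not the package's internals:

```python
import torch
from torchvision.utils import make_grid, save_image

num_dims = 8                                   # latent dimensions -> one row per dimension
recons = torch.rand(num_dims * 8, 3, 64, 64)   # decoded traversal images in [0, 1]

grid = make_grid(recons, nrow=num_dims, padding=2)
save_image(grid, "traversals_celeba.png")
```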
- class multimodal_compare.models.datasets.CUB(pth, testpth, mod_type)
Bases:
BaseDataset
Dataset class for our processed version of the Caltech-UCSD Birds dataset. We use the original images and text represented as sequences of one-hot encodings, one per character (incl. spaces)
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_text(data)
- _preprocess_images()
General function for loading images and preparing them as torch tensors
- Returns:
preprocessed data
- Return type:
torch.tensor
- _preprocess_text()
- _preprocess_text_onehot()
General function for loading text strings and preparing them as torch one-hot encodings
- Returns:
tensor with text encodings and masks
- Return type:
torch.tensor
- feature_dims = {'image': [64, 64, 3], 'text': [246, 27, 1]}
- labels()
No labels available for T-SNE
- save_recons(data, recons, path, mod_names)
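The CUB text modality is stored as character-level one-hot encodings: feature_dims['text'] = [246, 27, 1], i.e. up to 246 characters over a 27-symbol alphabet, presumably lowercase letters plus space. A hedged sketch of producing such an encoding together with a padding mask; the helper below is illustrative, not the package's _preprocess_text_onehot:

```python
import torch

ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # assumed 27 symbols: a-z plus space
MAX_LEN = 246

def one_hot_caption(caption: str):
    """Encode a caption as a (MAX_LEN, 27) one-hot tensor plus a length mask."""
    caption = caption.lower()[:MAX_LEN]
    enc = torch.zeros(MAX_LEN, len(ALPHABET))
    mask = torch.zeros(MAX_LEN, dtype=torch.bool)
    for i, ch in enumerate(caption):
        idx = ALPHABET.find(ch)
        if idx >= 0:
            enc[i, idx] = 1.0
        mask[i] = True
    return enc, mask

enc, mask = one_hot_caption("a small bird with yellow wings")
print(enc.shape)  # torch.Size([246, 27])
```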
- class multimodal_compare.models.datasets.FASHIONMNIST(pth, testpth, mod_type)
Bases:
BaseDataset
Dataset class for the FashionMNIST dataset
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_image(data)
- _postprocess_label(data)
- _process_image()
- _process_label()
- feature_dims = {'image': [28, 28, 1], 'label': [10]}
- get_data_raw()
Loads raw data from path
- Returns:
loaded raw data
- Return type:
list
- labels()
Returns labels for the whole dataset
- save_recons(data, recons, path, mod_names)
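feature_dims['label'] = [10] suggests the label modality is a length-10 one-hot vector accompanying each 28x28 grayscale image. A minimal sketch of producing such pairs, assuming torchvision's FashionMNIST as the raw source; the package's get_data_raw may read from a different path layout:

```python
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import FashionMNIST

raw = FashionMNIST(root="./data", train=True, download=True,
                   transform=transforms.ToTensor())

images = torch.stack([img for img, _ in raw])              # (N, 1, 28, 28)
labels = torch.tensor([lbl for _, lbl in raw])             # (N,)
labels_onehot = F.one_hot(labels, num_classes=10).float()  # (N, 10)
```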
- class multimodal_compare.models.datasets.MNIST_SVHN(pth, testpth, mod_type)
Bases:
BaseDataset
Dataset class for the MNIST-SVHN bimodal dataset (can also be used for unimodal training)
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_all2img(data)
Converts any kind of data to images to save traversal visualizations
- Parameters:
data (torch.tensor) – input data
- Returns:
processed images
- Return type:
torch.tensor
- _postprocess_mnist(data)
- _postprocess_svhn(data)
- _process_mnist()
- _process_svhn()
- check_indices_present()
- feature_dims = {'mnist': [28, 28, 1], 'svhn': [32, 32, 3]}
- labels()
Returns labels for the whole dataset
- save_recons(data, recons, path, mod_names)
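MNIST-SVHN pairs the two digit datasets, and check_indices_present hints that the pairing is defined by precomputed index files. A hedged sketch of aligning the modalities from such index tensors; all file names are placeholders and the real loading lives in _process_mnist/_process_svhn:

```python
import torch

# Placeholder index files mapping each pair to an MNIST and an SVHN sample.
mnist_idx = torch.load("mnist_idx.pt")        # (num_pairs,) indices into MNIST
svhn_idx = torch.load("svhn_idx.pt")          # (num_pairs,) indices into SVHN

mnist_images = torch.load("mnist_images.pt")  # (N_mnist, 1, 28, 28), placeholder
svhn_images = torch.load("svhn_images.pt")    # (N_svhn, 3, 32, 32), placeholder

# Row i of both tensors then shows the same digit class in both modalities.
paired_mnist = mnist_images[mnist_idx]
paired_svhn = svhn_images[svhn_idx]
```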
- class multimodal_compare.models.datasets.POLYMNIST(pth, testpth, mod_type)
Bases:
BaseDataset
Dataset class for the POLYMNIST dataset
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_mnist(data)
- _process_mnist()
- feature_dims = {'m0': [28, 28, 3], 'm1': [28, 28, 3], 'm2': [28, 28, 3], 'm3': [28, 28, 3], 'm4': [28, 28, 3]}
- save_recons(data, recons, path, mod_names)
- save_traversals(recons, path, num_dims)
Makes a grid of traversals and saves it as an image
- Parameters:
recons (torch.tensor) – data to save
path (str) – path to save the traversal to
num_dims (int) – number of latent dimensions
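POLYMNIST provides five image modalities m0-m4, each with shape [28, 28, 3], that share the underlying digit. A small sketch that derives a per-modality configuration from feature_dims; the config fields themselves are illustrative, not part of the package:

```python
feature_dims = {f"m{i}": [28, 28, 3] for i in range(5)}

# Illustrative: one entry per modality, e.g. for building encoder/decoder configs.
mod_configs = [
    {"mod_type": name, "feature_dim": dims, "channels": dims[-1]}
    for name, dims in feature_dims.items()
]
for cfg in mod_configs:
    print(cfg["mod_type"], cfg["feature_dim"], cfg["channels"])
```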
- class multimodal_compare.models.datasets.SPRITES(pth, testpth, mod_type)
Bases:
BaseDataset
- _mod_specific_loaders()
Assigns the preprocessing function based on the mod_type
- _mod_specific_savers()
Assigns the postprocessing function based on the mod_type
- _postprocess_actions(data)
- _postprocess_all2img(data)
Converts any kind of data to images to save traversal visualizations
- Parameters:
data (torch.tensor) – input data
- Returns:
processed images
- Return type:
torch.tensor
- _postprocess_attributes(data)
- _postprocess_frames(data)
- eval_statistics_fn()
(optional) Returns a dataset-specific function that runs systematic evaluation
- feature_dims = {'actions': [9], 'attributes': [4, 6], 'frames': [8, 64, 64, 3]}
- get_actions()
- get_attributes()
- get_frames()
- iter_over_inputs(outs, data, mod_names, f=0)
- labels()
Returns labels for the whole dataset
- make_masks(shape)
- save_recons(data, recons, path, mod_names)
- save_traversals(recons, path)
Makes a grid of traversals and saves it as an animated GIF
- Parameters:
recons (torch.tensor) – data to save
path (str) – path to save the traversal to
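Because the frames modality is a short animation (feature_dims['frames'] = [8, 64, 64, 3]), save_traversals writes an animated GIF instead of a static grid. A hedged sketch of saving such a frame sequence with imageio; the tensor and output path are illustrative:

```python
import imageio
import torch

# Illustrative: one traversal rendered as 8 RGB frames of 64x64 pixels in [0, 1].
frames = torch.rand(8, 64, 64, 3)

frames_uint8 = (frames.clamp(0, 1) * 255).byte().numpy()
imageio.mimsave("traversal_sprites.gif", list(frames_uint8), duration=0.2)
```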