Dataset Classes

class multimodal_compare.models.datasets.BaseDataset(pth, testpth, mod_type)

Bases: object

Abstract dataset class shared by all datasets

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type
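The two hooks above describe a dispatch on mod_type. The following is a minimal sketch of that pattern, assuming the hooks return a mapping from modality name to a bound method (the source only says they assign a function). ToyDataset and its toy loaders and savers are hypothetical stand-ins, not the library's code:

```python
class ToyDataset:
    """Hypothetical stand-in for a BaseDataset subclass (illustration only)."""

    def __init__(self, pth, testpth, mod_type):
        self.path = pth
        self.testpth = testpth
        self.mod_type = mod_type

    def _mod_specific_loaders(self):
        # Assumption: one preprocessing method per supported modality name.
        return {"image": self._preprocess_images, "text": self._preprocess_text}

    def _mod_specific_savers(self):
        # Assumption: one postprocessing method per supported modality name.
        return {"image": self._postprocess_images, "text": self._postprocess_text}

    def _preprocess_images(self):
        return ["image loaded from " + self.path]

    def _preprocess_text(self):
        return ["text loaded from " + self.path]

    def _postprocess_images(self, data):
        return [("image", item) for item in data]

    def _postprocess_text(self, data):
        return [("text", item) for item in data]

    def _preprocess(self):
        # Look up the loader for the configured modality and call it.
        loaders = self._mod_specific_loaders()
        if self.mod_type not in loaders:
            raise ValueError(f"unsupported mod_type: {self.mod_type!r}")
        return loaders[self.mod_type]()

    def _postprocess(self, output_data):
        # Look up the saver for the configured modality and apply it.
        return self._mod_specific_savers()[self.mod_type](output_data)


dataset = ToyDataset("data/train", "data/test", "text")
loaded = dataset._preprocess()
restored = dataset._postprocess(["a reconstruction"])
```

With this layout a subclass would only need to override the two mapping hooks and supply its own per-modality methods.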

_postprocess(output_data)

Postprocesses the output data according to modality type

Returns:

postprocessed data

Return type:

list

_postprocess_all2img(data)

Converts any kind of data to images to save traversal visualizations

Parameters:

data (torch.tensor) – input data

Returns:

processed images

Return type:

torch.tensor

_preprocess()

Preprocesses the loaded data according to modality type

Returns:

preprocessed data

Return type:

list

_preprocess_images(dimensions)

General function for loading images and preparing them as torch tensors

Parameters:

dimensions (list) – feature_dim for the image modality

Returns:

preprocessed data

Return type:

torch.tensor

_preprocess_text_onehot()

General function for loading text strings and preparing them as torch one-hot encodings

Returns:

torch tensor with text encodings and masks

Return type:

torch.tensor
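To illustrate what a character-level one-hot encoding with a padding mask can look like, here is a minimal sketch. It assumes a 27-symbol alphabet (a-z plus space) and uses plain Python lists in place of torch tensors; the helper name text_to_onehot and the handling of unknown characters are illustrative assumptions, not the library's implementation:

```python
ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # assumed: 26 letters plus space


def text_to_onehot(sentence, max_len, alphabet=ALPHABET):
    """Return (encoding, mask) for one sentence.

    encoding: max_len rows, each a one-hot list over the alphabet
    mask: max_len booleans, True where a real character is stored
    """
    index = {ch: i for i, ch in enumerate(alphabet)}
    # Assumption: characters outside the alphabet are dropped before encoding.
    chars = [ch for ch in sentence.lower() if ch in index][:max_len]
    encoding = [[0] * len(alphabet) for _ in range(max_len)]
    mask = [False] * max_len
    for pos, ch in enumerate(chars):
        encoding[pos][index[ch]] = 1
        mask[pos] = True
    return encoding, mask


encoding, mask = text_to_onehot("ab c", max_len=6)
```

Rows past the end of the sentence stay all-zero, and the mask marks them as padding so they can be ignored later, for example when computing a reconstruction loss.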

current_datatype()

Returns whether the current data path points to test data or train data

eval_statistics_fn()

(optional) Returns a dataset-specific function that runs systematic evaluation

get_data()

Returns processed data

Returns:

processed data

Return type:

list

get_data_raw()

Loads raw data from path

Returns:

loaded raw data

Return type:

list

get_labels(split='train')

Returns labels for the given split: train or test

get_processed_recons(recons_raw)

Returns the postprocessed data that came from the decoders

Parameters:

recons_raw (torch.tensor) – tensor with output reconstructions

Returns:

postprocessed data as returned by the specific _postprocess function

Return type:

list

get_test_data()

Returns processed test data if available

Returns:

processed data

Return type:

list

labels()

Returns labels for the whole dataset

save_traversals(recons, path, num_dims)

Makes a grid of traversals and saves it as an image

Parameters:
  • recons (torch.tensor) – data to save

  • path (str) – path to save the traversal to

  • num_dims (int) – number of latent dimensions
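As a rough illustration of the grid layout, the sketch below tiles equally sized images into rows. It assumes one grid row per latent dimension and that recons holds the traversal steps for each dimension consecutively; neither is confirmed by the source. The helper make_grid is hypothetical and uses nested lists instead of torch tensors:

```python
def make_grid(images, num_rows):
    """Tile equally sized 2D images (lists of pixel rows) into num_rows rows."""
    if num_rows <= 0 or len(images) % num_rows != 0:
        raise ValueError("number of images must be a multiple of num_rows")
    per_row = len(images) // num_rows
    grid = []
    for r in range(num_rows):
        row_images = images[r * per_row:(r + 1) * per_row]
        height = len(row_images[0])
        for y in range(height):
            # Concatenate the y-th pixel row of every image in this grid row.
            line = []
            for image in row_images:
                line.extend(image[y])
            grid.append(line)
    return grid


# Four 2x2 "images", each filled with its own index, for num_dims = 2.
images = [[[i, i], [i, i]] for i in range(4)]
grid = make_grid(images, num_rows=2)
```

In a torch-based pipeline the same layout is commonly produced with torchvision.utils.make_grid before writing the result to disk.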

class multimodal_compare.models.datasets.CDSPRITESPLUS(pth, testpth, mod_type)

Bases: BaseDataset

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_images(data)

_postprocess_text(data)

_preprocess_images()

General function for loading images and preparing them as torch tensors

Parameters:

dimensions (list) – feature_dim for the image modality

Returns:

preprocessed data

Return type:

torch.tensor

_preprocess_text()

eval_statistics_fn()

(optional) Returns a dataset-specific function that runs systematic evaluation

feature_dims = {'image': [64, 64, 3], 'text': [45, 27, 1]}

labels()

Extract text labels based on the dataset level

Returns:

list of labels as strings

Return type:

list

save_recons(data, recons, path, mod_names)

set_vis_image_shape()

class multimodal_compare.models.datasets.CELEBA(pth, testpth, mod_type)

Bases: BaseDataset

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_all2img(data)

Converts any kind of data to images to save traversal visualizations

Parameters:

data (torch.tensor) – input data

Returns:

processed images

Return type:

torch.tensor

_postprocess_atts(data)

_postprocess_images(data)

_preprocess_atts()

_preprocess_images()

General function for loading images and preparing them as torch tensors

Parameters:

dimensions (list) – feature_dim for the image modality

Returns:

preprocessed data

Return type:

torch.tensor

feature_dims = {'atts': [4, 2], 'image': [64, 64, 3]}

save_recons(data, recons, path, mod_names)

save_traversals(recons, path, num_dims)

Makes a grid of traversals and saves it as an image

Parameters:
  • recons (torch.tensor) – data to save

  • path (str) – path to save the traversal to

  • num_dims (int) – number of latent dimensions

class multimodal_compare.models.datasets.CUB(pth, testpth, mod_type)

Bases: BaseDataset

Dataset class for our processed version of the Caltech-UCSD Birds dataset. We use the original images and text represented as sequences of one-hot encodings, one per character (including spaces)

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_text(data)

_preprocess_images()

General function for loading images and preparing them as torch tensors

Parameters:

dimensions (list) – feature_dim for the image modality

Returns:

preprocessed data

Return type:

torch.tensor

_preprocess_text()

_preprocess_text_onehot()

General function for loading text strings and preparing them as torch one-hot encodings

Returns:

torch tensor with text encodings and masks

Return type:

torch.tensor

feature_dims = {'image': [64, 64, 3], 'text': [246, 27, 1]}

labels()

No labels for T-SNE available

save_recons(data, recons, path, mod_names)

class multimodal_compare.models.datasets.FASHIONMNIST(pth, testpth, mod_type)

Bases: BaseDataset

Dataset class for the FashionMNIST dataset

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_image(data)

_postprocess_label(data)

_process_image()

_process_label()

feature_dims = {'image': [28, 28, 1], 'label': [10]}

get_data_raw()

Loads raw data from path

Returns:

loaded raw data

Return type:

list

labels()

Returns labels for the whole dataset

save_recons(data, recons, path, mod_names)

class multimodal_compare.models.datasets.MNIST_SVHN(pth, testpth, mod_type)

Bases: BaseDataset

Dataset class for the MNIST-SVHN bimodal dataset (can also be used for unimodal training)

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_all2img(data)

Converts any kind of data to images to save traversal visualizations

Parameters:

data (torch.tensor) – input data

Returns:

processed images

Return type:

torch.tensor

_postprocess_mnist(data)

_postprocess_svhn(data)

_process_mnist()

_process_svhn()

check_indices_present()
feature_dims = {'mnist': [28, 28, 1], 'svhn': [32, 32, 3]}
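Each feature_dims entry lists the per-sample shape of one modality. As a small worked example, the number of values per sample follows from the product of the listed dimensions. Whether the library itself ever flattens the data this way is an assumption made only for illustration:

```python
from math import prod

# Copied from the MNIST_SVHN class attribute above.
feature_dims = {"mnist": [28, 28, 1], "svhn": [32, 32, 3]}

# Number of scalar values in one sample of each modality.
flat_sizes = {name: prod(dims) for name, dims in feature_dims.items()}

print(flat_sizes)  # {'mnist': 784, 'svhn': 3072}
```

The same calculation applies to the feature_dims of the other dataset classes.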
labels()

Returns labels for the whole dataset

save_recons(data, recons, path, mod_names)

class multimodal_compare.models.datasets.POLYMNIST(pth, testpth, mod_type)

Bases: BaseDataset

Dataset class for the POLYMNIST dataset

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_mnist(data)

_process_mnist()

feature_dims = {'m0': [28, 28, 3], 'm1': [28, 28, 3], 'm2': [28, 28, 3], 'm3': [28, 28, 3], 'm4': [28, 28, 3]}

save_recons(data, recons, path, mod_names)

save_traversals(recons, path, num_dims)

Makes a grid of traversals and saves it as an image

Parameters:
  • recons (torch.tensor) – data to save

  • path (str) – path to save the traversal to

  • num_dims (int) – number of latent dimensions

class multimodal_compare.models.datasets.SPRITES(pth, testpth, mod_type)

Bases: BaseDataset

_mod_specific_loaders()

Assigns the preprocessing function based on the mod_type

_mod_specific_savers()

Assigns the postprocessing function based on the mod_type

_postprocess_actions(data)

_postprocess_all2img(data)

Converts any kind of data to images to save traversal visualizations

Parameters:

data (torch.tensor) – input data

Returns:

processed images

Return type:

torch.tensor

_postprocess_attributes(data)

_postprocess_frames(data)

eval_statistics_fn()

(optional) Returns a dataset-specific function that runs systematic evaluation

feature_dims = {'actions': [9], 'attributes': [4, 6], 'frames': [8, 64, 64, 3]}

get_actions()

get_attributes()

get_frames()

iter_over_inputs(outs, data, mod_names, f=0)

labels()

Returns labels for the whole dataset

make_masks(shape)

save_recons(data, recons, path, mod_names)

save_traversals(recons, path)

Makes a grid of traversals and saves it as an animated GIF

Parameters:
  • recons (torch.tensor) – data to save

  • path (str) – path to save the traversal to