MultimodalVAE class

class multimodal_compare.models.trainer.MultimodalVAE(cfg, feature_dims: dict)

Bases: LightningModule

Multimodal VAE trainer common to all architectures. Configures, trains, and tests the model.

Parameters:
  • feature_dims (dict) – dictionary with feature dimensions of training data

  • cfg (object) – instance of Config class

analyse_data(data=None, labels=None, num_samples=250, path_name='', savedir=None, split='val')

Encodes data and plots t-SNE. If no data is passed, a dataloader (selected by split="val" or split="test") is used.

Parameters:
  • data (torch.tensor) – test data

  • labels (list) – optional list of strings with labels for the data, used for labelled t-SNE

  • num_samples (int) – number of samples to use for visualization

  • path_name (str) – label under which to save the visualizations

  • savedir (str) – where to save the visualizations

  • split (str) – "val" or "test"; whether to take samples from the validation or the test dataloader
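The fallback described above (use the passed data if present, otherwise draw from the dataloader for the chosen split) can be sketched in plain Python. This is an illustration only, not the library's code: `select_data` and the `loaders` dictionary are hypothetical names, and plain lists stand in for tensors and dataloaders.

```python
def select_data(data=None, loaders=None, split="val", num_samples=250):
    """Return up to num_samples items from data, or from the chosen split."""
    if data is None:
        if split not in ("val", "test"):
            raise ValueError("split must be 'val' or 'test'")
        # Collect items batch by batch until enough samples are gathered.
        data = []
        for batch in loaders[split]:
            data.extend(batch)
            if len(data) >= num_samples:
                break
    return list(data[:num_samples])


# Hypothetical loaders: each one is a list of batches.
loaders = {"val": [[1, 2], [3, 4], [5, 6]], "test": [[7, 8]]}
val_subset = select_data(loaders=loaders, split="val", num_samples=3)
```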

check_config(cfg)

Creates a Config instance from the provided argument parser or config path

Parameters:

cfg (argparse.ArgumentParser or str) – argument parser, or path to the config file

Returns:

Config instance

Return type:

object
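Accepting either a parser or a path amounts to a type dispatch. The sketch below shows one way this could look. It is not the library's implementation: the `Config` class here is a hypothetical stand-in that only records its source, and passing through an existing `Config` instance unchanged is an assumption.

```python
import argparse


class Config:
    """Hypothetical stand-in for the library's Config class."""

    def __init__(self, source):
        if isinstance(source, argparse.ArgumentParser):
            # Use the parser's defaults; a real class would parse sys.argv.
            self.origin = "parser"
            self.params = vars(source.parse_args([]))
        elif isinstance(source, str):
            # A real class would load the file; here only the path is kept.
            self.origin = "path"
            self.params = {"cfg_path": source}
        else:
            raise TypeError("cfg must be an ArgumentParser or a str path")


def check_config(cfg):
    """Return a Config instance, building one if cfg is not one already."""
    if isinstance(cfg, Config):
        return cfg
    return Config(cfg)


parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=32)
from_parser = check_config(parser)
from_path = check_config("config.yml")
```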

configure_optimizers()

Sets up the optimizer specified in the config

property datamod

Accessor for the data module. When the class is used for inference, there is no PyTorch Lightning (pl) trainer attached.

Returns:

an instance of DataModule class

Return type:

eval_forward(data)

Forward pass used outside training, e.g. during evaluation

get_mod_names()

Creates a dictionary with modality numbers and their names based on dataset

Returns:

Dict with modality numbers as keys and names as values

Return type:

dict
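As an illustration of the returned structure, the sketch below builds such a dictionary from a list of names. The 1-based integer keys and the example modality names are assumptions. The documentation only states that modality numbers map to names.

```python
def get_mod_names(modality_names):
    """Map modality numbers to names (1-based integer keys are assumed)."""
    return {i: name for i, name in enumerate(modality_names, start=1)}


mod_names = get_mod_names(["image", "text"])
print(mod_names)  # {1: 'image', 2: 'text'}
```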

get_model()

Sets up the model according to the config file

save_joint_samples(num_samples=16, savedir=None, traversals=False)

Generates joint samples from random vectors and saves them

Parameters:
  • num_samples (int) – number of samples to generate

  • savedir (str) – where to save the samples

  • traversals (bool) – whether to make traversals for each dimension (True) or randomly sample latents (False)
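The two modes selected by traversals differ only in how the latent vectors are built before decoding. A minimal sketch follows, with several assumptions: `make_latents`, `latent_dim`, `span` and `seed` are hypothetical names, the random mode draws from a standard normal, and in traversal mode `num_samples` is interpreted as the number of steps per dimension. The library may define these differently.

```python
import random


def make_latents(num_samples, latent_dim, traversals=False, span=3.0, seed=0):
    """Build latent vectors by random sampling or per-dimension traversal."""
    if not traversals:
        # Random mode: every coordinate drawn from a standard normal.
        rng = random.Random(seed)
        return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
                for _ in range(num_samples)]
    latents = []
    for dim in range(latent_dim):
        for step in range(num_samples):
            # Sweep one dimension over [-span, span]; keep the others at zero.
            frac = step / (num_samples - 1) if num_samples > 1 else 0.5
            z = [0.0] * latent_dim
            z[dim] = -span + 2.0 * span * frac
            latents.append(z)
    return latents


random_latents = make_latents(16, 4)
traversal_latents = make_latents(5, 3, traversals=True)
```

Each vector would then be passed to the decoder. With traversals, consecutive outputs show the effect of changing a single latent dimension.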

save_reconstructions(num_samples=10, savedir=None, split='val')

Reconstructs data and saves the output. Also iterates over inputs with missing modalities to produce cross-generations.

Parameters:
  • num_samples (int) – number of samples to take from the dataloader for reconstruction

  • savedir (str) – where to save the reconstructions

  • split (str) – "val" or "test"; whether to take samples from the validation or the test dataloader
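Iterating over missing modalities can be pictured as enumerating every non-empty subset of the modalities and masking out the rest before the forward pass. The sketch below is illustrative only: `input_subsets` and `mask_inputs` are hypothetical helpers, the `mod_1`/`mod_2` keys are example names, and representing a missing modality as `None` is an assumption about the input format.

```python
from itertools import combinations


def input_subsets(mod_names):
    """Yield every non-empty subset of modalities to feed as input."""
    for size in range(1, len(mod_names) + 1):
        for subset in combinations(mod_names, size):
            yield subset


def mask_inputs(batch, present):
    """Keep modalities listed in `present`; mark the others as missing."""
    return {name: (value if name in present else None)
            for name, value in batch.items()}


batch = {"mod_1": [0.1, 0.2], "mod_2": [0.3, 0.4]}
subsets = list(input_subsets(list(batch)))
masked = [mask_inputs(batch, subset) for subset in subsets]
```

With two modalities this gives three input configurations: each modality alone (cross-generation of the other) and both together (plain reconstruction).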

test_epoch_end(outputs)

Creates visualizations at the end of the testing epoch

test_step(test_batch, batch_idx)

Iterates over the test loader if test data is provided, otherwise over the val loader

training_step(train_batch, batch_idx)

Iterates over the train loader

validation_epoch_end(outputs)

Saves visualizations at the end of the validation epoch

Parameters:

outputs (torch.tensor) – loss returned by validation_step

validation_step(val_batch, batch_idx)

Iterates over the val loader