MultimodalVAE class
- class multimodal_compare.models.trainer.MultimodalVAE(cfg, feature_dims: dict)
Bases:
LightningModule
Multimodal VAE trainer common for all architectures. Configures, trains and tests the model.
- Parameters:
feature_dims (dict) – dictionary with feature dimensions of training data
cfg (object) – instance of Config class
- analyse_data(data=None, labels=None, num_samples=250, path_name='', savedir=None, split='val')
Encodes data and plots t-SNE. If no data is passed, a dataloader (selected by split="val"/"test") is used.
- Parameters:
data (torch.Tensor) – test data
labels (list) – optional list of string labels for the data, used for labelled t-SNE
num_samples (int) – number of samples to use for visualization
path_name (str) – label under which to save the visualizations
savedir (str) – where to save the visualizations
split (str) – "val" or "test"; whether to take samples from the validation or test dataloader
- check_config(cfg)
Creates a Config instance from the provided argument parser or config path
- Parameters:
cfg (argparse.ArgumentParser or str) – argument parser or path to the config
- Returns:
Config instance
- Return type:
object
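As a rough illustration of the dispatch that check_config describes (an argument parser or a path string in, a config object out), the sketch below uses only the standard library. SimpleConfig and check_config_sketch are hypothetical names, not part of multimodal_compare, and the JSON file format is an assumption chosen to keep the example dependency-free; the library's own Config class and file format may differ.

```python
import argparse
import json


class SimpleConfig:
    """Hypothetical stand-in for the library's Config class."""

    def __init__(self, params: dict):
        self.params = dict(params)


def check_config_sketch(cfg):
    """Return a SimpleConfig built from a parser, a path, or an existing config."""
    if isinstance(cfg, SimpleConfig):
        # Already a config object: pass it through unchanged.
        return cfg
    if isinstance(cfg, str):
        # Assumption: the path points to a JSON file holding a flat dict.
        with open(cfg) as handle:
            return SimpleConfig(json.load(handle))
    if isinstance(cfg, argparse.ArgumentParser):
        # Parse an empty argument list so that only the defaults are used.
        return SimpleConfig(vars(cfg.parse_args([])))
    raise TypeError(f"Unsupported config type: {type(cfg).__name__}")
```

Passing an unsupported type raises a TypeError in this sketch, which makes a wrongly typed cfg argument fail early instead of surfacing later during training.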
- configure_optimizers()
Sets up the optimizer specified in the config
- property datamod
Accessor for the data module. It is needed because no PyTorch Lightning (pl) trainer is attached when the class is used for inference.
- Returns:
an instance of DataModule class
- Return type:
DataModule
- eval_forward(data)
Forward pass used outside training, e.g. during evaluation
- get_mod_names()
Creates a dictionary with modality numbers and their names based on dataset
- Returns:
Dict with modality numbers as keys and names as values
- Return type:
dict
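The returned mapping can be pictured with a small sketch. The function name mod_names_sketch, the example modality names, and the choice to start numbering at 1 are all assumptions made for illustration; the source does not state how get_mod_names numbers the modalities.

```python
def mod_names_sketch(modality_names):
    """Map modality numbers to names, e.g. {1: "image", 2: "text"}.

    Assumption: numbering starts at 1 and follows the order of the input list.
    """
    return {number: name for number, name in enumerate(modality_names, start=1)}
```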
- get_model()
Sets up the model according to the config file
- save_joint_samples(num_samples=16, savedir=None, traversals=False)
Generate joint samples from random vectors and save them
- Parameters:
num_samples (int) – number of samples to generate
savedir (str) – where to save the generated samples
traversals (bool) – whether to make traversals for each dimension (True) or randomly sample latents (False)
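The difference between the two values of traversals can be sketched in plain Python. Both helpers below are hypothetical and only build the latent vectors; decoding them into samples is left out. The traversal range of -2 to 2, the number of steps, and holding the other dimensions at zero are assumptions, not values taken from the library.

```python
import random


def sample_latents(num_samples, n_latents, seed=0):
    """Random latents: each vector is drawn from a standard normal prior."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(n_latents)]
            for _ in range(num_samples)]


def traverse_latents(n_latents, steps=5, low=-2.0, high=2.0):
    """Traversals: sweep one dimension at a time, keeping the others at zero."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    values = [low + (high - low) * i / (steps - 1) for i in range(steps)]
    grid = []
    for dim in range(n_latents):
        for value in values:
            latent = [0.0] * n_latents
            latent[dim] = value  # only this dimension changes
            grid.append(latent)
    return grid
```

A traversal grid has n_latents * steps vectors, so decoded traversals are usually arranged with one row per latent dimension to show what each dimension controls.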
- save_reconstructions(num_samples=10, savedir=None, split='val')
Reconstructs data and saves the output. Also iterates over inputs with missing modalities to cross-generate them.
- Parameters:
num_samples (int) – number of samples to take from the dataloader for reconstruction
savedir (str) – where to save the reconstructions
split (str) – "val" or "test"; whether to take samples from the validation or test dataloader
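One way to picture "iterating over missing modalities" is to enumerate every non-empty subset of the input modalities and blank out the rest. The helper below is a hypothetical sketch: representing a missing modality as None and the keys mod_1/mod_2 are assumptions, and the library may mask inputs differently.

```python
from itertools import combinations


def missing_modality_inputs(batch):
    """List input variants in which only a subset of modalities is present.

    Assumption: a missing modality is represented by None.
    """
    names = sorted(batch)
    variants = []
    for size in range(1, len(names) + 1):
        for present in combinations(names, size):
            variants.append(
                {name: (batch[name] if name in present else None)
                 for name in names}
            )
    return variants


# Usage: two modalities give three variants (each alone, then both together).
for variant in missing_modality_inputs({"mod_1": "img", "mod_2": "txt"}):
    print(variant)
```

Feeding each variant to the model would yield cross-generated outputs for the blanked modalities and plain reconstructions for the variant in which everything is present.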
- test_epoch_end(outputs)
Creates visualizations at the end of the test epoch
- test_step(test_batch, batch_idx)
Test step, called for each batch from the test loader (or from the val loader if no test data is provided)
- training_step(train_batch, batch_idx)
Training step, called for each batch from the train loader
- validation_epoch_end(outputs)
Saves visualizations at the end of the validation epoch
- Parameters:
outputs (torch.Tensor) – loss returned by validation_step
- validation_step(val_batch, batch_idx)
Validation step, called for each batch from the val loader
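The relationship between validation_step and validation_epoch_end can be mimicked without PyTorch Lightning. In Lightning releases before 2.0, the values returned by validation_step are collected across batches and passed as outputs to validation_epoch_end. TinyModule and run_validation_epoch below are hypothetical stand-ins that imitate this contract with plain Python lists; they are not the trainer's implementation, and the mean used as a "loss" is a placeholder.

```python
class TinyModule:
    """Hypothetical stand-in for a LightningModule (no Lightning required)."""

    def validation_step(self, val_batch, batch_idx):
        # Placeholder "loss": the mean of the batch values.
        return sum(val_batch) / len(val_batch)

    def validation_epoch_end(self, outputs):
        # Receives every value returned by validation_step in this epoch.
        self.epoch_mean = sum(outputs) / len(outputs)


def run_validation_epoch(module, loader):
    """Imitate the trainer: call the step per batch, then the epoch-end hook."""
    outputs = [module.validation_step(batch, idx)
               for idx, batch in enumerate(loader)]
    module.validation_epoch_end(outputs)
    return outputs
```

The test_step and test_epoch_end hooks follow the same per-batch and per-epoch pattern.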