Multimodal VAE models
- class multimodal_compare.models.mmvae_models.DMVAE(vaes: list, n_latents: int, obj_config: dict, model_config=None)
Bases:
TorchMMVAE
- forward(x, K=1)
Forward pass that takes input data and outputs a list of private and shared posteriors, reconstructions and latent samples
- Parameters:
x (list) – input data, a list of modalities where missing modalities are replaced with None
K (int) – sample K samples from the posterior
- Returns:
a list of posterior distributions, a list of reconstructions and latent samples
- Return type:
tuple(list, list, list)
- get_remaining_mods_data(qz_xs: dict, exclude_mod: str)
- logsumexp(x, dim=None, keepdim=False)
A smooth maximum function
- Parameters:
x (torch.tensor) – input data
dim (int) – dimension to reduce over
keepdim (bool) – whether to keep the reduced dimension or squeeze it
- Returns:
the log-sum-exp of x along the given dimension
- Return type:
torch.tensor
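The log-sum-exp "smooth maximum" is usually computed in a numerically stable way by subtracting the maximum before exponentiating. A pure-Python sketch of the math (ignoring the tensor `dim`/`keepdim` handling of the actual method) might look like:

```python
import math

def logsumexp(xs):
    # Numerically stable log(sum(exp(x_i))): with m = max_i x_i,
    # log sum_i exp(x_i) = m + log sum_i exp(x_i - m),
    # so no exponent ever overflows.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))
```

For well-separated inputs the result is close to the plain maximum, which is why it acts as a smooth (differentiable) max.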
- objective(mods)
Objective for the DMVAE model. Source: https://github.com/seqam-lab/
- Parameters:
mods (dict) – input data with modalities as keys
- Return obj:
dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log
- Rtype obj:
dict
- property pz_params
- training: bool
- class multimodal_compare.models.mmvae_models.MOE(vaes: list, n_latents: int, obj_config: dict, model_config=None)
Bases:
TorchMMVAE
- forward(x, K=1)
Forward pass that takes input data and outputs a list of posteriors, reconstructions and latent samples
- Parameters:
x (list) – input data, a list of modalities where missing modalities are replaced with None
K (int) – sample K samples from the posterior
- Returns:
a list of posterior distributions, a list of reconstructions and latent samples
- Return type:
tuple(list, list, list)
- objective(data)
Objective function for MoE
- Parameters:
data (dict) – input data with modalities as keys
- Return obj:
dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log
- Rtype obj:
dict
- property pz_params
- reconstruct(data, runPath, epoch)
Reconstruct data for individual experts
- Parameters:
data (list) – list of input modalities
runPath (str) – path to save data to
epoch (str) – current epoch to name the data
- training: bool
- class multimodal_compare.models.mmvae_models.MoPOE(vaes: list, n_latents: int, obj_config: dict, model_config=None)
Bases:
TorchMMVAE
- forward(inputs, K=1)
Forward pass that takes input data and outputs a list of posteriors, reconstructions and latent samples
- Parameters:
inputs (list) – input data, a list of modalities where missing modalities are replaced with None
K (int) – sample K samples from the posterior
- Returns:
a list of posterior distributions, a list of reconstructions and latent samples
- Return type:
tuple(list, list, list)
- mixture_component_selection(mus, logvars, w_modalities=None)
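The method above is undocumented; one common realization of mixture-component selection (following the MoPoE reference implementation, but sketched here with plain Python lists and assumed argument shapes, not the library's tensors) is stratified selection: the batch is split into contiguous chunks with sizes proportional to the mixture weights, and each chunk takes its Gaussian parameters from one modality:

```python
def mixture_component_selection(mus, logvars, weights):
    # mus/logvars: one list of per-sample parameters per modality.
    # Split the batch into contiguous chunks proportional to the weights
    # and take each chunk from the corresponding modality's distribution.
    batch = len(mus[0])
    out_mu, out_lv = [], []
    start = 0
    for k, w in enumerate(weights):
        # last chunk absorbs any rounding remainder
        end = batch if k == len(weights) - 1 else start + int(w * batch)
        out_mu.extend(mus[k][start:end])
        out_lv.extend(logvars[k][start:end])
        start = end
    return out_mu, out_lv
```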
- modality_mixing(input_batch)
Mix the encoded distributions according to the chosen approach
- Parameters:
input_batch (dict) – qz_xs dictionary with modalities as keys and distribution parameters as values
- Returns:
latent samples dictionary with modalities as keys and latent sample tensors as values
- Return type:
dict
- moe_fusion(mus, logvars, weights)
- objective(mods)
Objective function for MoPoE. Computes the generalized multimodal ELBO: https://arxiv.org/pdf/2105.02470.pdf
- Parameters:
mods (dict) – input data with modalities as keys
- Return obj:
dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log
- Rtype obj:
dict
- poe_fusion(mus, logvars)
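Product-of-experts fusion of Gaussians has a closed form: precisions (inverse variances) add, and the joint mean is the precision-weighted average of the expert means. A scalar sketch (the library's method operates on tensors and typically also includes the prior expert):

```python
import math

def poe_fusion(mus, logvars):
    # Product of Gaussian experts N(mu_k, var_k):
    # joint precision = sum_k 1/var_k,
    # joint mean = (sum_k mu_k / var_k) / joint precision.
    precisions = [math.exp(-lv) for lv in logvars]  # 1/var_k
    joint_prec = sum(precisions)
    joint_mu = sum(m * p for m, p in zip(mus, precisions)) / joint_prec
    joint_logvar = -math.log(joint_prec)
    return joint_mu, joint_logvar
```

Two unit-variance experts agreeing on the mean thus yield the same mean with halved variance, reflecting the increased confidence of the product.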
- property pz_params
- reparameterize(mu, logvar)
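The reparameterization trick rewrites a sample from N(mu, sigma²) as a deterministic function of (mu, logvar) plus independent noise, so gradients can flow through the distribution parameters. A scalar, stdlib-only sketch (the actual method works on torch tensors):

```python
import math
import random

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(logvar / 2);
    # the randomness lives in eps, so mu and logvar stay differentiable.
    eps = random.gauss(0.0, 1.0)
    return mu + math.exp(0.5 * logvar) * eps
```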
- reweight_weights(w)
- set_subsets()
powerset([1,2,3]) –> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
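The enumeration in the docstring matches the standard itertools powerset recipe, which MoPoE uses to form all subsets of modalities for its mixture of subset-products:

```python
from itertools import chain, combinations

def powerset(iterable):
    # All subsets, from the empty set up to the full set:
    # powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
```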
- training: bool
- class multimodal_compare.models.mmvae_models.POE(vaes: list, n_latents: int, obj_config: dict, model_config=None)
Bases:
TorchMMVAE
- forward(inputs, K=1)
Forward pass that takes input data and outputs a dict with posteriors, reconstructions and latent samples
- Parameters:
inputs (dict) – input data, a dict of modalities where missing modalities are replaced with None
K (int) – sample K samples from the posterior
- Returns:
dict where keys are modalities and values are a named tuple
- Return type:
dict
- modality_mixing(x)
Inference module, calculates the joint posterior
- Parameters:
x (dict) – input data, a dict of modalities where missing modalities are replaced with None
- Returns:
joint posterior and individual posteriors
- Return type:
tuple(torch.tensor, torch.tensor, list, list)
- objective(mods)
Objective function for PoE
- Parameters:
mods (dict) – input data with modalities as keys
- Return obj:
dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log
- Rtype obj:
dict
- prior_expert(size, use_cuda=False)
Universal prior expert. Here we use a spherical Gaussian: N(0, 1).
- Parameters:
size (int) – dimensionality of the Gaussian
use_cuda (boolean) – move the variables to CUDA
- Returns:
mean and logvar of the expert
- Return type:
tuple
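A spherical Gaussian N(0, I) prior expert has zero mean and unit variance in every dimension, i.e. a zero logvar. A minimal stdlib sketch of the idea (the actual method returns torch tensors and optionally moves them to CUDA):

```python
def prior_expert(size):
    # N(0, I): zero mean, unit variance => logvar = log(1) = 0
    # in each of the `size` latent dimensions.
    mu = [0.0] * size
    logvar = [0.0] * size
    return mu, logvar
```

In the product-of-experts joint posterior, this prior acts as one extra expert, keeping the product well-defined even when every modality is missing.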
- property pz_params
- training: bool