Multimodal VAE models

class multimodal_compare.models.mmvae_models.DMVAE(vaes: list, n_latents: int, obj_config: dict, model_config=None)

Bases: TorchMMVAE

_is_full_backward_hook: bool | None
forward(x, K=1)

Forward pass that takes input data and outputs a list of private and shared posteriors, reconstructions and latent samples

Parameters:
  • x (list) – input data, a list of modalities where missing modalities are replaced with None

  • K (int) – number of samples to draw from the posterior

Returns:

a list of posterior distributions, a list of reconstructions and latent samples

Return type:

tuple(list, list, list)

get_remaining_mods_data(qz_xs: dict, exclude_mod: str)
logsumexp(x, dim=None, keepdim=False)

A numerically stable smooth maximum (log-sum-exp)

Parameters:
  • x (torch.tensor) – input data

  • dim (int) – dimension along which to reduce

  • keepdim (bool) – whether to keep the reduced dimension or squeeze it

Returns:

log-sum-exp of the input along dim

Return type:

torch.tensor
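The function described above can be sketched as a numerically stable implementation (an illustration under the stated signature, not necessarily the module's exact code):

```python
import torch

def logsumexp(x, dim=None, keepdim=False):
    """Numerically stable log(sum(exp(x))) along `dim`.

    Subtracting the maximum before exponentiating avoids overflow,
    which is why log-sum-exp behaves as a smooth maximum.
    """
    if dim is None:
        x, dim = x.view(-1), 0
    xm, _ = torch.max(x, dim, keepdim=True)
    out = xm + torch.log(torch.sum(torch.exp(x - xm), dim, keepdim=True))
    return out if keepdim else out.squeeze(dim)
```

With large inputs such as `logsumexp(torch.tensor([1000.0, 1000.0]))`, a naive `log(sum(exp(x)))` would overflow to `inf`, while this version stays finite.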

objective(mods)

Objective for the DMVAE model. Source: https://github.com/seqam-lab/

Parameters:

mods (dict) – input data with modalities as keys

Returns:

dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log

Return type:

dict
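The return contract described above can be illustrated with a hypothetical sketch; only the “loss” key is required, and every other key name and value here is made up for illustration:

```python
import torch

def example_objective_output():
    """Hypothetical shape of an objective's return value.

    "loss" is mandatory (the optimizer steps on it); any additional
    keys are only logged. The terms and values below are made up.
    """
    recon = torch.tensor(1.25)  # illustrative reconstruction term
    kld = torch.tensor(0.40)    # illustrative KL term
    return {"loss": recon + kld, "reconstruction": recon, "kld": kld}
```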

property pz_params
training: bool
class multimodal_compare.models.mmvae_models.MOE(vaes: list, n_latents: int, obj_config: dict, model_config=None)

Bases: TorchMMVAE

_is_full_backward_hook: bool | None
forward(x, K=1)

Forward pass that takes input data and outputs a list of posteriors, reconstructions and latent samples

Parameters:
  • x (list) – input data, a list of modalities where missing modalities are replaced with None

  • K (int) – number of samples to draw from the posterior

Returns:

a list of posterior distributions, a list of reconstructions and latent samples

Return type:

tuple(list, list, list)

objective(data)

Objective function for MoE

Parameters:

data (dict) – input data with modalities as keys

Returns:

dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log

Return type:

dict

property pz_params
reconstruct(data, runPath, epoch)

Reconstruct data for individual experts

Parameters:
  • data (list) – list of input modalities

  • runPath (str) – path to save data to

  • epoch (str) – current epoch to name the data

training: bool
class multimodal_compare.models.mmvae_models.MoPOE(vaes: list, n_latents: int, obj_config: dict, model_config=None)

Bases: TorchMMVAE

_is_full_backward_hook: bool | None
forward(inputs, K=1)

Forward pass that takes input data and outputs a list of posteriors, reconstructions and latent samples

Parameters:
  • inputs (list) – input data, a list of modalities where missing modalities are replaced with None

  • K (int) – number of samples to draw from the posterior

Returns:

a list of posterior distributions, a list of reconstructions and latent samples

Return type:

tuple(list, list, list)

mixture_component_selection(mus, logvars, w_modalities=None)
modality_mixing(input_batch)

Mix the encoded distributions according to the chosen approach

Parameters:

input_batch (dict) – qz_xs dictionary with modalities as keys and distribution parameters as values

Returns:

latent samples dictionary with modalities as keys and latent sample tensors as values

Return type:

dict

moe_fusion(mus, logvars, weights)
objective(mods)

Objective function for MoPoE. Computes the Generalized Multimodal ELBO (https://arxiv.org/pdf/2105.02470.pdf)

Parameters:

mods (dict) – input data with modalities as keys

Returns:

dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log

Return type:

dict

poe_fusion(mus, logvars)
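`poe_fusion` presumably combines Gaussian experts by the standard product-of-experts rule, in which precisions add and the fused mean is the precision-weighted average of the expert means; a minimal sketch under that assumption:

```python
import torch

def poe_fusion(mus, logvars, eps=1e-8):
    """Product of Gaussian experts (sketch, one expert per row of dim 0).

    The product of Gaussians is Gaussian with
    precision = sum of precisions, mean = precision-weighted mean.
    """
    var = torch.exp(logvars) + eps
    T = 1.0 / var                                   # per-expert precisions
    pd_mu = torch.sum(mus * T, dim=0) / torch.sum(T, dim=0)
    pd_var = 1.0 / torch.sum(T, dim=0)
    return pd_mu, torch.log(pd_var)
```

Fusing two identical unit-variance experts halves the variance, reflecting that agreeing experts sharpen the joint posterior.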
property pz_params
reparameterize(mu, logvar)
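Judging by its signature, `reparameterize` is the standard VAE reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I); a minimal sketch:

```python
import torch

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Moving the randomness into eps keeps the sample differentiable
    with respect to mu and logvar.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```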
reweight_weights(w)
set_subsets()

powerset([1,2,3]) –> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
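The powerset shown above matches the classic `itertools` recipe; `set_subsets` presumably enumerates the modality subsets this way:

```python
from itertools import chain, combinations

def powerset(iterable):
    """powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"""
    s = list(iterable)
    # One combinations() call per subset size, chained into a single sequence
    return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)))
```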

training: bool
class multimodal_compare.models.mmvae_models.POE(vaes: list, n_latents: int, obj_config: dict, model_config=None)

Bases: TorchMMVAE

_is_full_backward_hook: bool | None
forward(inputs, K=1)

Forward pass that takes input data and outputs a dict with posteriors, reconstructions and latent samples

Parameters:
  • inputs (dict) – input data, a dict of modalities where missing modalities are replaced with None

  • K (int) – number of samples to draw from the posterior

Returns:

dict where keys are modalities and values are a named tuple

Return type:

dict

modality_mixing(x)

Inference module that calculates the joint posterior

Parameters:

x (dict) – input data, a dict of modalities where missing modalities are replaced with None

Returns:

joint posterior and individual posteriors

Return type:

tuple(torch.tensor, torch.tensor, list, list)

objective(mods)

Objective function for PoE

Parameters:

mods (dict) – input data with modalities as keys

Returns:

dictionary with the obligatory “loss” key on which the model is optimized, plus any other keys that you wish to log

Return type:

dict

prior_expert(size, use_cuda=False)

Universal prior expert. Here we use a spherical Gaussian: N(0, 1).

Parameters:
  • size (int) – dimensionality of the Gaussian

  • use_cuda (boolean) – whether to place the returned tensors on CUDA

Returns:

mean and logvar of the expert

Return type:

tuple
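Given the description above, a minimal sketch of such a universal prior expert (zero mean, zero log-variance, i.e. unit variance):

```python
import torch

def prior_expert(size, use_cuda=False):
    """Spherical Gaussian prior N(0, I): zero mean, zero log-variance."""
    mu = torch.zeros(size)
    logvar = torch.zeros(size)  # log(1) = 0, i.e. unit variance
    if use_cuda:
        mu, logvar = mu.cuda(), logvar.cuda()
    return mu, logvar
```

In a product-of-experts model, including this prior expert ensures the product is well defined even when every modality is missing.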

property pz_params
training: bool