Multimodal VAE Base Class

class multimodal_compare.models.mmvae_base.TorchMMVAE(vaes: list, n_latents: int, obj: str, beta=1, K=1)

Bases: Module

Base class for all PyTorch based MMVAE implementations.

add_vaes(vae_dict: ModuleDict)

This function updates the VAEs of the MMVAE with a given dictionary.

Parameters:

vae_dict (nn.ModuleDict) – A dictionary with the modality names as keys and BaseVAEs as values
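The expected dictionary can be sketched with a toy module standing in for the repository's BaseVAE (the ToyVAE class and its layer sizes below are hypothetical, for illustration only):

```python
import torch.nn as nn

# ToyVAE is a hypothetical stand-in for the repository's BaseVAE class
class ToyVAE(nn.Module):
    def __init__(self, input_dim, n_latents):
        super().__init__()
        # encoder outputs mu and logvar, hence 2 * n_latents
        self.encoder = nn.Linear(input_dim, 2 * n_latents)
        self.decoder = nn.Linear(n_latents, input_dim)

# Modality names as keys, VAEs as values, as add_vaes expects
vae_dict = nn.ModuleDict({"image": ToyVAE(784, 16), "text": ToyVAE(300, 16)})
```

Using nn.ModuleDict (rather than a plain dict) ensures the per-modality VAE parameters are registered with the parent module and picked up by the optimizer.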

decode(samples)

Make reconstructions for the given latent samples

Parameters:

samples (dict) – Dictionary with modalities as keys and latent sample tensors as values

Returns:

dictionary with modalities as keys and torch distributions as values

Return type:

dict

encode(inputs)

Encode the inputs with the appropriate per-modality VAE encoders

Parameters:

inputs (dict) – input dictionary with modalities as keys and data tensors as values

Returns:

qz_xs dictionary with modalities as keys and distribution parameters as values

Return type:

dict
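The dispatch logic can be sketched generically; the encoders dict of callables below is an assumption standing in for the per-modality VAE encoder networks:

```python
def encode(inputs, encoders):
    # Run each modality's encoder on its data tensor; modalities whose
    # input is missing (None) are skipped entirely
    qz_xs = {}
    for mod, x in inputs.items():
        if x is not None:
            qz_xs[mod] = encoders[mod](x)
    return qz_xs

# Toy callables standing in for real encoder networks
encoders = {"image": lambda x: ("mu_img", x), "text": lambda x: ("mu_txt", x)}
out = encode({"image": 1.0, "text": None}, encoders)
```

Skipping missing modalities here is what lets the forward pass handle partially observed inputs.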

forward(inputs, K=1)

The general forward pass of the multimodal VAE

Parameters:
  • inputs (dict) – input dictionary with modalities as keys and data tensors as values

  • K (int) – number of latent samples drawn per input

Returns:

dictionary with modalities as keys and namedtuples as values

Return type:

dict[str,VaeOutput]
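Putting encode, latent sampling, and decode together, the pipeline might look like the following sketch. The encoder/decoder callables and the per-modality reparameterized sampling are assumptions; the actual mixing step is defined by subclasses via modality_mixing:

```python
import torch

def forward(inputs, encoders, decoders):
    # Encode each present modality into posterior parameters (mu, logvar)
    qz_xs = {m: encoders[m](x) for m, x in inputs.items() if x is not None}
    # One reparameterized latent sample per modality (stand-in for the
    # subclass-specific modality_mixing step)
    z = {m: mu + torch.exp(0.5 * lv) * torch.randn_like(mu)
         for m, (mu, lv) in qz_xs.items()}
    # Decode every latent sample into a reconstruction
    px_zs = {m: decoders[m](s) for m, s in z.items()}
    return {"encoder_dist": qz_xs, "latent_samples": z, "decoder_dist": px_zs}

# Toy encoder: mean over features as mu, near-zero variance (logvar = -20)
enc = {"image": lambda x: (x.mean(dim=-1, keepdim=True),
                           torch.full_like(x[..., :1], -20.0))}
dec = {"image": lambda z: z * 2}
out = forward({"image": torch.ones(4, 8)}, enc, dec)
```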

get_missing_modalities(mods)

Get indices of modalities that are missing from the input

Parameters:

mods (list) – list of modalities

Returns:

list of indices of missing modalities

Return type:

list
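A minimal sketch, assuming a missing modality is marked by None at its position in the input list:

```python
def get_missing_modalities(mods):
    # Indices of positions whose data is absent (None)
    return [i for i, m in enumerate(mods) if m is None]

missing = get_missing_modalities([[1.0, 2.0], None, [3.0], None])
```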

property latent_factorization

Returns True if the latent space is factorized into shared and modality-specific subspaces, else False

make_output_dict(encoder_dist=None, decoder_dist=None, latent_samples=None, joint_dist=None, enc_dist_private=None, dec_dist_private=None, joint_decoder_dist=None, cross_decoder_dist=None)

Prepares the output of the forward pass

Parameters:
  • encoder_dist (dict) – dict with modalities as keys and encoder distributions as values

  • decoder_dist (dict) – dict with modalities as keys and decoder distributions as values

  • latent_samples (dict) – dict with modalities as keys and dicts with latent samples as values

  • joint_dist (dict) – dict with modalities as keys and joint distributions as values

  • enc_dist_private (dict) – dict with modalities as keys and dicts with single latent distributions as values

  • dec_dist_private (dict) – dict with modalities as keys and dicts with single decoder distributions as values

  • joint_decoder_dist (dict) – dict with modalities as keys and dicts with decoder distributions derived from the joint distribution as values

  • cross_decoder_dist (dict) – dict with modalities as keys and dicts with cross-modal decoder distributions as values

Returns:

VAEOutput object

Return type:

object
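The output container can be sketched as a simple dataclass. The fields below mirror some of the documented arguments, but the VAEOutput class itself is hypothetical, not the repository's actual definition:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class VAEOutput:  # hypothetical container mirroring the documented arguments
    encoder_dist: Optional[dict] = None
    decoder_dist: Optional[dict] = None
    latent_samples: Optional[dict] = None
    joint_dist: Optional[dict] = None

def make_output_dict(**kwargs):
    # Bundle whichever forward-pass products were provided into one object;
    # omitted arguments remain None
    return VAEOutput(**kwargs)

out = make_output_dict(encoder_dist={"image": "q(z|x)"})
```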

abstract modality_mixing(mods)

Mix the encoded distributions according to the chosen approach

Parameters:

mods (dict) – qz_xs dictionary with modalities as keys and distribution parameters as values

Returns:

latent samples dictionary with modalities as keys and latent sample tensors as values

Return type:

dict
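As one concrete (hypothetical) instance, a no-sharing baseline could sample each modality's latent from its own posterior via the reparameterization trick; PoE- or MoE-style subclasses would instead combine the posteriors before sampling:

```python
import torch

def modality_mixing(mods):
    # mods maps modality -> (mu, logvar); draw z ~ N(mu, exp(logvar))
    # independently per modality via the reparameterization trick
    samples = {}
    for mod, (mu, logvar) in mods.items():
        std = torch.exp(0.5 * logvar)
        samples[mod] = mu + std * torch.randn_like(std)
    return samples

# With near-zero variance (logvar = -20) the sample collapses onto the mean
z = modality_mixing({"image": (torch.ones(2, 3), torch.full((2, 3), -20.0))})
```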

abstract objective(mods)

Runs the forward pass and calculates the loss

Parameters:

mods (dict) – dictionary with input data with modalities as keys

Returns:

dictionary with the computed loss

Return type:

dict
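A per-modality ELBO is one common choice of objective. The form below, with a standard-normal prior and a beta-weighted KL term, is an assumption for illustration, not necessarily the repository's exact loss:

```python
import torch
from torch.distributions import Normal, kl_divergence

def elbo_objective(qz_xs, px_zs, targets, beta=1.0):
    # Sum over modalities: beta-weighted KL(q(z|x) || N(0, I)) minus the
    # reconstruction log-likelihood, averaged over the batch
    loss = 0.0
    for mod, qz in qz_xs.items():
        prior = Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale))
        kld = kl_divergence(qz, prior).sum(-1)
        lpx = px_zs[mod].log_prob(targets[mod]).sum(-1)
        loss = loss + (beta * kld - lpx).mean()
    return {"loss": loss}

# With q equal to the prior and a perfect reconstruction at the mode, the
# loss reduces to the Gaussian normalization constant 0.5 * log(2 * pi)
qz = Normal(torch.zeros(1, 1), torch.ones(1, 1))
px = Normal(torch.zeros(1, 1), torch.ones(1, 1))
out = elbo_objective({"image": qz}, {"image": px}, {"image": torch.zeros(1, 1)})
```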

static product_of_experts(mu, logvar)

Calculate the product of Gaussian experts for the given means and log-variances

Parameters:
  • mu (list) – list of means

  • logvar (list) – list of logvars

Returns:

joint posterior

Return type:

tuple(torch.Tensor, torch.Tensor)
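For Gaussian experts the product is again Gaussian: the joint precision is the sum of the expert precisions, and the joint mean is the precision-weighted average of the expert means. A minimal sketch of that computation (whether the repository's version also multiplies in a standard-normal prior expert is not shown here):

```python
import torch

def product_of_experts(mu, logvar, eps=1e-8):
    mu = torch.stack(mu)          # (n_experts, batch, latent_dim)
    logvar = torch.stack(logvar)
    precision = 1.0 / (torch.exp(logvar) + eps)   # per-expert 1 / sigma^2
    joint_var = 1.0 / precision.sum(dim=0)        # inverse of summed precisions
    joint_mu = joint_var * (mu * precision).sum(dim=0)  # precision-weighted mean
    return joint_mu, torch.log(joint_var + eps)

# Two unit-variance experts with means 1 and 3 -> joint mean 2, variance 0.5
mu_j, logvar_j = product_of_experts(
    [torch.full((1, 2), 1.0), torch.full((1, 2), 3.0)],
    [torch.zeros(1, 2), torch.zeros(1, 2)],
)
```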

property pz_params
set_likelihood_scales()