Multimodal VAE Base Class
- class multimodal_compare.models.mmvae_base.TorchMMVAE(vaes: list, n_latents: int, obj: str, beta=1, K=1)
Bases: torch.nn.Module
Base class for all PyTorch-based MMVAE implementations.
- _is_full_backward_hook: bool | None
- add_vaes(vae_dict: ModuleDict)
This function updates the VAEs of the MMVAE with the given dictionary.
- Parameters:
vae_dict (nn.ModuleDict) – A dictionary with the modality names as keys and BaseVAEs as values
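A minimal sketch of how a modality-to-VAE mapping might be passed to `add_vaes`. `ToyMMVAE` is a hypothetical stand-in (not the library class), the attribute name `vaes` is an assumption, and plain `nn.Linear` layers take the place of real BaseVAE instances:

```python
import torch.nn as nn


class ToyMMVAE(nn.Module):
    """Hypothetical stand-in that mimics only the add_vaes behaviour described above."""

    def __init__(self):
        super().__init__()
        self.vaes = nn.ModuleDict()  # assumed attribute name, not taken from the library

    def add_vaes(self, vae_dict: nn.ModuleDict):
        # Replace the current modality -> VAE mapping with the given one;
        # nn.Module registers the new ModuleDict as a submodule automatically
        self.vaes = vae_dict


# Modality names as keys; nn.Linear layers stand in for BaseVAE instances
vaes = nn.ModuleDict({"image": nn.Linear(784, 16), "text": nn.Linear(32, 16)})
model = ToyMMVAE()
model.add_vaes(vaes)
print(list(model.vaes.keys()))  # ['image', 'text']
```

Because the dictionary is an `nn.ModuleDict`, the parameters of every modality's VAE become parameters of the multimodal model and are picked up by its optimizer.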
- decode(samples)
Decode the latent samples into reconstructions with the appropriate VAE decoders
- Parameters:
samples (dict) – Dictionary with modalities as keys and latent sample tensors as values
- Returns:
dictionary with modalities as keys and torch distributions as values
- Return type:
dict
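The dictionary-in, dictionary-out contract of `decode` can be illustrated with a self-contained sketch. The per-modality `nn.Linear` decoders and the unit-scale `Normal` likelihood are assumptions for illustration only; the actual class uses the decoders of its BaseVAEs, and its likelihood scales may differ (compare `set_likelihood_scales`):

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical per-modality decoders mapping a 16-dim latent to the data space
decoders = nn.ModuleDict({"image": nn.Linear(16, 784), "text": nn.Linear(16, 32)})


def decode(samples: dict) -> dict:
    """Map each modality's latent samples to a distribution over that modality."""
    out = {}
    for mod, z in samples.items():
        loc = decoders[mod](z)
        # Assumption: fixed unit scale; the real class may configure or learn it
        out[mod] = Normal(loc, torch.ones_like(loc))
    return out


samples = {"image": torch.randn(8, 16), "text": torch.randn(8, 16)}
px_zs = decode(samples)
print(px_zs["image"].mean.shape)  # torch.Size([8, 784])
```

Returning distributions rather than tensors lets the objective score the data directly with `log_prob`.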
- encode(inputs)
Encode the inputs with the appropriate VAE encoders
- Parameters:
inputs (dict) – input dictionary with modalities as keys and data tensors as values
- Returns:
qz_xs dictionary with modalities as keys and distribution parameters as values
- Return type:
dict
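A minimal sketch of per-modality encoding, assuming (hypothetically) that each encoder outputs the mean and log-variance of a Gaussian posterior q(z|x). `GaussianEncoder` and the layer sizes are illustrative, not the library's encoders:

```python
import torch
import torch.nn as nn


class GaussianEncoder(nn.Module):
    """Hypothetical encoder returning the mean and log-variance of q(z|x)."""

    def __init__(self, in_dim: int, n_latents: int):
        super().__init__()
        self.net = nn.Linear(in_dim, 2 * n_latents)

    def forward(self, x):
        # Split the output into the two halves: mean and log-variance
        mu, logvar = self.net(x).chunk(2, dim=-1)
        return mu, logvar


encoders = nn.ModuleDict({"image": GaussianEncoder(784, 16), "text": GaussianEncoder(32, 16)})


def encode(inputs: dict) -> dict:
    # One encoder per modality; each returns the parameters of its posterior
    return {mod: encoders[mod](x) for mod, x in inputs.items()}


qz_xs = encode({"image": torch.randn(8, 784), "text": torch.randn(8, 32)})
mu, logvar = qz_xs["text"]
print(mu.shape, logvar.shape)  # torch.Size([8, 16]) torch.Size([8, 16])
```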
- forward(inputs, K=1)
The general forward pass of the multimodal VAE
- Parameters:
inputs (dict) – input dictionary with modalities as keys and data tensors as values
K (int) – number of latent samples to draw
- Returns:
dictionary with modalities as keys and namedtuples as values
- Return type:
dict[str,VaeOutput]
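A simplified sketch of a forward pass that draws `K` reparameterized samples per modality. It is an illustration under several assumptions, not the library implementation: each unimodal posterior is decoded independently (there is no `modality_mixing` step), encoders and decoders are plain `nn.Linear` layers, and `ToyOutput` is a hypothetical stand-in for `VaeOutput`:

```python
from collections import namedtuple

import torch
import torch.nn as nn
from torch.distributions import Normal

# Hypothetical output container; the real VaeOutput may hold more fields
ToyOutput = namedtuple("ToyOutput", ["encoder_dist", "latent_samples", "decoder_dist"])

n_latents = 16
encoders = nn.ModuleDict({"image": nn.Linear(784, 2 * n_latents), "text": nn.Linear(32, 2 * n_latents)})
decoders = nn.ModuleDict({"image": nn.Linear(n_latents, 784), "text": nn.Linear(n_latents, 32)})


def forward(inputs: dict, K: int = 1) -> dict:
    outputs = {}
    for mod, x in inputs.items():
        mu, logvar = encoders[mod](x).chunk(2, dim=-1)
        qz_x = Normal(mu, torch.exp(0.5 * logvar))
        # Reparameterized samples of shape (K, batch, n_latents)
        z = qz_x.rsample(torch.Size([K]))
        loc = decoders[mod](z)
        px_z = Normal(loc, torch.ones_like(loc))  # assumed unit-scale likelihood
        outputs[mod] = ToyOutput(qz_x, z, px_z)
    return outputs


out = forward({"image": torch.randn(8, 784), "text": torch.randn(8, 32)}, K=5)
print(out["image"].latent_samples.shape)  # torch.Size([5, 8, 16])
```

In a real subclass, the sampled latents would typically pass through `modality_mixing` before decoding, which is where the MMVAE variants differ.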
- get_missing_modalities(mods)
Get the indices of modalities that are missing from the input
- Parameters:
mods (list) – list of modalities
- Returns:
list of indices of missing modalities
- Return type:
list
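A sketch of the index lookup, assuming (hypothetically) that a missing modality is represented by `None` in the input list; the library may mark missing data differently:

```python
import torch


def get_missing_modalities(mods: list) -> list:
    """Return the positions of entries that are missing (assumed to be None)."""
    return [i for i, m in enumerate(mods) if m is None]


print(get_missing_modalities([torch.zeros(2, 3), None, torch.ones(2, 5), None]))  # [1, 3]
```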
- property latent_factorization
Returns True if the latent space is factorized into shared and modality-specific subspaces, else False
- make_output_dict(encoder_dist=None, decoder_dist=None, latent_samples=None, joint_dist=None, enc_dist_private=None, dec_dist_private=None, joint_decoder_dist=None, cross_decoder_dist=None)
Prepares the output of the forward pass
- Parameters:
encoder_dist (dict) – dict with modalities as keys and encoder distributions as values
decoder_dist (dict) – dict with modalities as keys and decoder distributions as values
latent_samples (dict) – dict with modalities as keys and dicts with latent samples as values
joint_dist (dict) – dict with modalities as keys and joint distribution as values
enc_dist_private (dict) – dict with modalities as keys and dicts with single latent distributions as values
dec_dist_private (dict) – dict with modalities as keys and dicts with single decoder distributions as values
joint_decoder_dist (dict) – dict with modalities as keys and dicts with decoder distributions coming from the joint distribution as values
cross_decoder_dist (dict) – dict with modalities as keys and dicts with cross-modal decoder distributions as values
- Returns:
VAEOutput object
- Return type:
object
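One way to picture `make_output_dict` is as a thin constructor over a record whose fields default to `None`, so each model fills in only what it produces. `ToyVAEOutput` is a hypothetical stand-in for `VAEOutput`, and the plain strings stand in for distributions and tensors:

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ToyVAEOutput:
    """Hypothetical container mirroring the keyword arguments of make_output_dict."""
    encoder_dist: Optional[dict] = None
    decoder_dist: Optional[dict] = None
    latent_samples: Optional[dict] = None
    joint_dist: Optional[dict] = None
    enc_dist_private: Optional[dict] = None
    dec_dist_private: Optional[dict] = None
    joint_decoder_dist: Optional[dict] = None
    cross_decoder_dist: Optional[dict] = None


def make_output_dict(**kwargs) -> ToyVAEOutput:
    # Unknown keyword names raise TypeError, so typos surface immediately
    return ToyVAEOutput(**kwargs)


out = make_output_dict(encoder_dist={"image": "q(z|x_image)"}, latent_samples={"image": {"latents": "z"}})
print([f.name for f in fields(out) if getattr(out, f.name) is not None])
# ['encoder_dist', 'latent_samples']
```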
- abstract modality_mixing(mods)
Mix the encoded distributions according to the chosen approach
- Parameters:
mods (dict) – qz_xs dictionary with modalities as keys and distribution parameters as values
- Returns:
latent samples dictionary with modalities as keys and latent sample tensors as values
- Return type:
dict
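A sketch of how a subclass might fill in the abstract `modality_mixing`. As an illustrative choice (not necessarily what any library subclass does), it uses mixture-of-experts style stratified sampling, drawing reparameterized samples from every unimodal posterior separately. `ToyMixingBase` and `MixtureOfExpertsMixing` are hypothetical names, and posteriors are assumed to be given as `(mu, logvar)` pairs:

```python
from abc import ABC, abstractmethod

import torch
from torch.distributions import Normal


class ToyMixingBase(ABC):
    """Hypothetical stand-in for the abstract interface above."""

    @abstractmethod
    def modality_mixing(self, mods: dict) -> dict:
        ...


class MixtureOfExpertsMixing(ToyMixingBase):
    """Mixture-of-experts style mixing: each unimodal posterior is sampled separately."""

    def __init__(self, K: int = 1):
        self.K = K

    def modality_mixing(self, mods: dict) -> dict:
        samples = {}
        for mod, (mu, logvar) in mods.items():
            # Reparameterized draws keep gradients flowing back to the encoders
            samples[mod] = Normal(mu, torch.exp(0.5 * logvar)).rsample(torch.Size([self.K]))
        return samples


qz_xs = {"image": (torch.zeros(4, 16), torch.zeros(4, 16)), "text": (torch.ones(4, 16), torch.zeros(4, 16))}
z = MixtureOfExpertsMixing(K=3).modality_mixing(qz_xs)
print(z["text"].shape)  # torch.Size([3, 4, 16])
```

A product-of-experts subclass would instead fuse the posteriors first (for example with `product_of_experts`) and sample once from the joint.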
- abstract objective(mods)
Runs the forward pass and calculates the loss
- Parameters:
mods (dict) – dictionary with input data with modalities as keys
- Returns:
dictionary containing the loss
- Return type:
dict
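As an example of what a concrete `objective` could compute, the sketch below returns a negative β-ELBO for a single modality with a standard normal prior, packed into a dictionary. The function name, the returned keys and the choice of objective are assumptions; the actual loss depends on the `obj` and `beta` arguments passed to the subclass:

```python
import torch
from torch.distributions import Normal, kl_divergence


def beta_elbo_objective(qz_x: Normal, px_z: Normal, x: torch.Tensor, beta: float = 1.0) -> dict:
    """Negative beta-ELBO for one modality, averaged over the batch (illustrative only)."""
    prior = Normal(torch.zeros_like(qz_x.loc), torch.ones_like(qz_x.scale))
    # Sum over feature / latent dimensions, then average over the batch
    rec = px_z.log_prob(x).sum(-1).mean()
    kld = kl_divergence(qz_x, prior).sum(-1).mean()
    loss = -(rec - beta * kld)
    return {"loss": loss, "reconstruction": rec, "kld": kld}


x = torch.randn(8, 32)
qz_x = Normal(torch.zeros(8, 16), torch.ones(8, 16))  # posterior equal to the prior
px_z = Normal(x, torch.ones_like(x))                  # decoder centred on the data
out = beta_elbo_objective(qz_x, px_z, x, beta=1.0)
print(out["loss"])  # ≈ 29.406, i.e. 32 · ½ log(2π)
```

With the posterior equal to the prior the KL term vanishes, and a decoder centred exactly on the data leaves only the Gaussian normalizing constant as the loss.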
- static product_of_experts(mu, logvar)
Calculate the product of Gaussian experts (PoE) given their means and log-variances
- Parameters:
mu (list) – list of means
logvar (list) – list of log-variances
- Returns:
mean and log-variance of the joint posterior
- Return type:
tuple(torch.Tensor, torch.Tensor)
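For Gaussian experts N(μᵢ, σᵢ²), the product is again Gaussian, with precision Σᵢ 1/σᵢ² and precision-weighted mean μ = (Σᵢ μᵢ/σᵢ²) / (Σᵢ 1/σᵢ²). A minimal sketch of that computation follows; the `eps` term is an assumption added for numerical stability, and whether the library also includes a standard normal prior expert is not stated here:

```python
import torch


def product_of_experts(mu: list, logvar: list, eps: float = 1e-8):
    """Combine Gaussian experts into one Gaussian with a precision-weighted mean."""
    mu = torch.stack(mu)            # (n_experts, batch, n_latents)
    logvar = torch.stack(logvar)
    var = torch.exp(logvar) + eps   # eps guards against division by zero
    precision = 1.0 / var
    joint_var = 1.0 / precision.sum(dim=0)
    joint_mu = (mu * precision).sum(dim=0) * joint_var
    return joint_mu, torch.log(joint_var)


# Two unit-variance experts centred at 0 and 2
m, lv = product_of_experts([torch.zeros(1, 3), torch.full((1, 3), 2.0)],
                           [torch.zeros(1, 3), torch.zeros(1, 3)])
print(m, lv.exp())  # mean 1.0 and variance 0.5 in every dimension
```

Combining two equally confident experts halves the variance and places the mean midway between them; a more confident expert (smaller variance) would pull the mean towards itself.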
- property pz_params
- set_likelihood_scales()
- training: bool