VAE class
- class multimodal_compare.models.vae.BaseVae(enc, dec, prior_dist=<class 'torch.distributions.normal.Normal'>, likelihood_dist=<class 'torch.distributions.normal.Normal'>, post_dist=<class 'torch.distributions.normal.Normal'>)
Bases:
Module
Base VAE class for all implementations.
- decode(inp)
Decodes the latent samples
- Parameters:
inp (dict) – Samples dictionary
- Returns:
decoded distribution parameters (means and logvars)
- Return type:
tuple
- encode(inp)
Encodes the inputs
- Parameters:
inp (dict) – Inputs dictionary
- Returns:
encoded distribution parameters (means and logvars)
- Return type:
tuple
- forward(x, K=1)
Forward pass through the VAE (encode, sample from the posterior, decode)
- Parameters:
x (torch.tensor) – input modality
K (int) – number of samples to draw from the posterior
- Returns:
the posterior distribution, the reconstruction and latent samples
- Return type:
tuple(torch.dist, torch.dist, torch.tensor)
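The posterior sampling step that forward() relies on is the standard reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1), which keeps the sampling differentiable with respect to the encoder outputs. A minimal stdlib-only sketch (plain lists instead of torch tensors; the function name is illustrative, not part of this API):

```python
import math
import random

def reparameterize(mu, logvar, K=1):
    """Draw K latent samples z = mu + sigma * eps, eps ~ N(0, 1).

    Sketch of the reparameterization trick used when sampling from the
    posterior; mu and logvar are plain Python lists here, not tensors.
    """
    samples = []
    for _ in range(K):
        # sigma = exp(0.5 * logvar), so logvar can be any real number
        z = [m + math.exp(0.5 * lv) * random.gauss(0.0, 1.0)
             for m, lv in zip(mu, logvar)]
        samples.append(z)
    return samples
```

With K > 1 this yields multiple draws from the same posterior, matching the K argument of forward().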
- class multimodal_compare.models.vae.DencoderFactory
Bases:
object
- classmethod get_nework_classes(enc_name, dec_name, n_latents, private_latents, data_dim: tuple)
Instantiates the encoder and decoder networks
- Parameters:
enc_name (str) – encoder name
dec_name (str) – decoder name
n_latents (int) – dimensionality of the latent space
private_latents – dimensionality of the private latent space (if any)
data_dim (tuple) – dimensions of the input data
- Returns:
the encoder and decoder classes
- Return type:
tuple(object, object)
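Name-based factories like this are typically backed by a registry that maps string names to network classes. A minimal sketch of that pattern (the class and registry names below are hypothetical, not the actual networks in this repo):

```python
# Hypothetical placeholder network classes for illustration only
class CNNEncoder: pass
class CNNDecoder: pass
class MLPEncoder: pass
class MLPDecoder: pass

# Registries mapping names to classes, as a factory might hold internally
_ENCODERS = {"CNN": CNNEncoder, "MLP": MLPEncoder}
_DECODERS = {"CNN": CNNDecoder, "MLP": MLPDecoder}

def get_network_classes(enc_name, dec_name):
    """Look up encoder/decoder classes by name (KeyError on unknown names)."""
    return _ENCODERS[enc_name], _DECODERS[dec_name]
```

This allows encoder and decoder architectures to be selected from a config file by string name, without the caller importing the network modules directly.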
- class multimodal_compare.models.vae.VAE(enc, dec, feature_dim, n_latents, ltype, private_latents=None, prior_dist='normal', likelihood_dist='normal', post_dist='normal', obj_fn=None, beta=1, id_name='mod_1', llik_scaling='auto')
Bases:
BaseVae
- generate_samples(N, traversals=False, traversal_range=(-1, 1))
- generate_samples(N, traversals=False, traversal_range=(-1, 1))
Generates samples from the latent space
- Parameters:
N (int) – how many samples to generate
traversals (bool) – whether to make latent traversals (True) or random samples (False)
traversal_range (tuple) – range of the traversals (if applicable)
- Returns:
output reconstructions
- Return type:
torch.tensor
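A latent traversal sweeps one latent dimension at a time across traversal_range while holding the others fixed, which visualizes what each dimension encodes. A stdlib-only sketch of building such a traversal grid (the helper name is illustrative; real code would decode each vector):

```python
def traversal_grid(n_latents, steps, traversal_range=(-1, 1)):
    """Build latent vectors traversing one dimension at a time.

    For each latent dimension, hold the others at 0 and sweep the chosen
    dimension linearly over traversal_range. Returns a list of
    n_latents * steps plain-list vectors.
    """
    lo, hi = traversal_range
    vectors = []
    for dim in range(n_latents):
        for i in range(steps):
            val = lo + (hi - lo) * i / (steps - 1)
            z = [0.0] * n_latents
            z[dim] = val
            vectors.append(z)
    return vectors
```

Decoding the resulting vectors row by row produces the familiar traversal image grids, one row per latent dimension.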
- objective(data)
Objective function for unimodal VAE scenario (not used with multimodal VAEs)
- Parameters:
data (dict) – input data with modalities as keys
- Returns:
loss calculated using self.loss_fn
- Return type:
torch.tensor
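For a unimodal VAE with a diagonal-Gaussian posterior and a standard-normal prior, the objective is typically the negative beta-weighted ELBO, where the KL term has the closed form -0.5 * sum(1 + logvar - mu^2 - exp(logvar)). A stdlib-only sketch of that computation (illustrative only; the actual loss here is whatever self.loss_fn was configured to be):

```python
import math

def beta_elbo_loss(recon_log_lik, mu, logvar, beta=1.0):
    """Negative beta-ELBO for a diagonal-Gaussian q(z|x) vs N(0, I) prior.

    recon_log_lik is the summed reconstruction log-likelihood;
    KL(q||p) uses the closed form for diagonal Gaussians.
    """
    kl = -0.5 * sum(1 + lv - m * m - math.exp(lv)
                    for m, lv in zip(mu, logvar))
    # Minimizing this maximizes recon_log_lik while penalizing KL
    return -(recon_log_lik - beta * kl)
```

Setting beta > 1 recovers the beta-VAE objective, trading reconstruction quality for a more factorized latent space.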
- property pz_params
- Returns:
prior distribution parameters
- Return type:
list(torch.tensor, torch.tensor)
- property pz_params_private
- Returns:
prior distribution parameters for the private latent space
- Return type:
list(torch.tensor, torch.tensor)
- property qz_x_params
- Returns:
posterior distribution parameters
- Return type:
list(torch.tensor, torch.tensor)
- set_objective_fn(obj_fn, beta)
Sets up the loss function for the unimodal VAE scenario
- Parameters:
obj_fn – objective (loss) function to use
beta – beta weighting term of the objective