VAE class

class multimodal_compare.models.vae.BaseVae(enc, dec, prior_dist=<class 'torch.distributions.normal.Normal'>, likelihood_dist=<class 'torch.distributions.normal.Normal'>, post_dist=<class 'torch.distributions.normal.Normal'>)

Bases: Module

Base VAE class for all implementations.

decode(inp)

Decodes the latent samples

Parameters:

inp (dict) – Samples dictionary

Returns:

decoded distribution parameters (means and logvars)

Return type:

tuple

encode(inp)

Encodes the inputs

Parameters:

inp (dict) – Inputs dictionary

Returns:

encoded distribution parameters (means and logvars)

Return type:

tuple
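A minimal round-trip sketch combining both methods; the modality key "mod_1", the input shape, and the dictionary key passed to decode are illustrative assumptions, and model stands for any instantiated subclass (see VAE below):

    import torch

    inp = {"mod_1": torch.randn(8, 64)}                   # batch of 8 inputs (assumed key and shape)
    mu, logvar = model.encode(inp)                        # posterior parameters of q(z|x)
    z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization trick
    dec_mu, dec_logvar = model.decode({"mod_1": z})       # likelihood parameters (assumed dict key)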

forward(x, K=1)

Forward pass

Parameters:
  • x (torch.tensor) – input modality

  • K (int) – number of samples to draw from the posterior

Returns:

the posterior distribution, the reconstruction and latent samples

Return type:

tuple(torch.dist, torch.dist, torch.tensor)
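For example, a single pass with five posterior samples might look as follows (a sketch; the input shape is an assumption):

    x = torch.randn(8, 64)                  # assumed input shape
    qz_x, px_z, zs = model.forward(x, K=5)
    # qz_x: posterior distribution q(z|x)
    # px_z: reconstruction distribution p(x|z)
    # zs:   K latent samples drawn from the posterior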

class multimodal_compare.models.vae.DencoderFactory

Bases: object

classmethod get_nework_classes(enc_name, dec_name, n_latents, private_latents, data_dim: tuple)

Instantiates the encoder and decoder networks

Parameters:
  • enc_name (str) – encoder name

  • dec_name (str) – decoder name

  • n_latents (int) – dimensionality of the latent space

  • private_latents – dimensionality of the private latent space, if any

  • data_dim (tuple) – dimensions of the input data

Returns:

the encoder and decoder classes

Return type:

tuple(object, object)
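A usage sketch; the network name "CNN" and the dimensions are placeholder assumptions, not names guaranteed to be registered in the package:

    from multimodal_compare.models.vae import DencoderFactory

    enc_net, dec_net = DencoderFactory.get_nework_classes(
        enc_name="CNN",           # placeholder encoder name
        dec_name="CNN",           # placeholder decoder name
        n_latents=16,
        private_latents=None,
        data_dim=(3, 64, 64),     # assumed image-shaped input
    )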

class multimodal_compare.models.vae.VAE(enc, dec, feature_dim, n_latents, ltype, private_latents=None, prior_dist='normal', likelihood_dist='normal', post_dist='normal', obj_fn=None, beta=1, id_name='mod_1', llik_scaling='auto')

Bases: BaseVae

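A construction sketch based on the signature above; the network name, the feature_dim format, and the ltype value are illustrative assumptions:

    from multimodal_compare.models.vae import VAE

    model = VAE(
        enc="CNN", dec="CNN",       # placeholder network names
        feature_dim=(3, 64, 64),    # assumed input dimensions
        n_latents=16,
        ltype="bernoulli",          # placeholder likelihood type
        beta=1,
        id_name="mod_1",
    )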
generate_samples(N, traversals=False, traversal_range=(-1, 1))

Generates samples from the latent space

Parameters:
  • N (int) – how many samples to make

  • traversals (bool) – whether to make latent traversals (True) or random samples (False)

  • traversal_range (tuple) – range of the traversals (if applicable)

Returns:

output reconstructions

Return type:

torch.tensor
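Both sampling modes in one sketch (assumes a trained model; the traversal range is arbitrary):

    random_recons = model.generate_samples(N=16)
    traversal_recons = model.generate_samples(N=16, traversals=True,
                                              traversal_range=(-2, 2))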

objective(data)

Objective function for the unimodal VAE scenario (not used with multimodal VAEs)

Parameters:

data (dict) – input data with modalities as keys

Returns:

loss calculated using self.loss_fn

Return type:

torch.tensor
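A training-step sketch; the modality key, the shape, and the assumption that the returned loss supports backward() are illustrative:

    batch = {"mod_1": torch.randn(8, 64)}   # assumed modality key and shape
    loss = model.objective(batch)           # scalar loss from self.loss_fn
    loss.backward()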

property pz_params

Returns:

prior distribution parameters

Return type:

list(torch.tensor, torch.tensor)

property pz_params_private

Returns:

prior distribution parameters for the private latent space

Return type:

list(torch.tensor, torch.tensor)

property qz_x_params

Returns:

posterior distribution parameters

Return type:

list(torch.tensor, torch.tensor)
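The three properties can be read directly, e.g. to inspect the distributions (a sketch; it assumes qz_x_params is populated by a preceding forward pass):

    prior_mu, prior_var = model.pz_params    # prior p(z) parameters
    post_mu, post_var = model.qz_x_params    # posterior q(z|x) parameters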

set_objective_fn(obj_fn, beta)

Sets the loss function for the unimodal VAE case

Parameters:
  • obj_fn – objective (loss) function to use

  • beta (float) – weight of the KL divergence term
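A sketch of wiring up the unimodal objective; the value "elbo" is a placeholder, as the accepted identifiers are defined elsewhere in the package:

    model.set_objective_fn(obj_fn="elbo", beta=1.0)   # "elbo" is a placeholder identifier
    loss = model.objective(batch)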
