Objectives

class multimodal_compare.models.objectives.BaseObjective

Base objective class shared by all loss functions

calc_kld(dist1, dist2)

Calculate KL divergence between two distributions

Parameters:
  • dist1 (torch.dist) – distribution 1

  • dist2 (torch.dist) – distribution 2

Returns:

KL divergence

Return type:

torch.tensor
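A minimal sketch of such a computation with `torch.distributions` (the distributions here are illustrative stand-ins for `dist1` and `dist2`, not the library's own objects):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Two diagonal Gaussians standing in for dist1 and dist2
dist1 = Normal(loc=torch.zeros(4), scale=torch.ones(4))
dist2 = Normal(loc=torch.ones(4), scale=torch.ones(4))

# Analytic KL(dist1 || dist2), one value per latent dimension
kld = kl_divergence(dist1, dist2)
```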

calc_klds(latent_dists, model)

Calculates the KL-divergence between each distribution and the posterior distribution.

Parameters:
  • latent_dists (list) – list of the two distributions

  • model (object) – model object

Returns:

list of klds

Return type:

list

static compute_microbatch_split(x, K)

Checks whether the batch needs to be broken down further to fit in memory.

Parameters:
  • x (torch.tensor) – input data

  • K (int) – K samples will be made from each distribution

Returns:

microbatch split

Return type:

torch.tensor
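A hypothetical reimplementation of this idea, assuming a fixed element budget (`max_elems` and the exact policy are illustrative, not the library's):

```python
import torch

def compute_microbatch_split(x, K, max_elems=int(1e7)):
    # Cap the number of tensor elements processed at once when each
    # input item is expanded into K posterior samples.
    per_item = K * x[0].numel()
    return max(1, min(x.shape[0], max_elems // per_item))

x = torch.zeros(8, 3, 64, 64)       # batch of 8 images
split = compute_microbatch_split(x, K=1000)
microbatches = x.split(split)       # iterate these to bound peak memory
```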

elbo(lpx_z, kld, beta=1)

The most general ELBO function

Parameters:
  • lpx_z (torch.tensor) – reconstruction loss(es)

  • kld (torch.tensor) – KL divergence(s)

  • beta (torch.float) – disentangling factor

Returns:

ELBO loss

Return type:

torch.tensor
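A sketch of such a β-ELBO loss in plain PyTorch (the batch-mean reduction is an assumption about the API, not confirmed by the source):

```python
import torch

def elbo(lpx_z, kld, beta=1.0):
    # Negative beta-ELBO: reconstruction log-likelihood minus the
    # beta-weighted KL term, averaged over the batch.
    return -(lpx_z - beta * kld).mean()

lpx_z = torch.tensor([-10.0, -12.0])   # per-sample log-likelihoods
kld = torch.tensor([2.0, 4.0])         # per-sample KL divergences
loss = elbo(lpx_z, kld, beta=1.0)
```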

iwae(lp_z, lpx_z, lqz_x)

The most general IWAE function.

Parameters:
  • lp_z (torch.tensor) – log probability of latent samples coming from the prior

  • lpx_z (torch.tensor) – reconstruction loss(es)

  • lqz_x (torch.tensor) – log probability of latent samples coming from the learned posterior

Returns:

IWAE loss

Return type:

torch.tensor
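The IWAE bound takes a log-mean-exp of the importance weights over the K samples. A minimal sketch, assuming inputs of shape `(K, batch)` and a batch-mean reduction:

```python
import math
import torch

def iwae(lp_z, lpx_z, lqz_x):
    # Log importance weights: log p(z) + log p(x|z) - log q(z|x)
    lw = lp_z + lpx_z - lqz_x
    K = lw.shape[0]
    # Negative log-mean-exp over the K samples, averaged over the batch
    return -(torch.logsumexp(lw, dim=0) - math.log(K)).mean()

loss = iwae(torch.zeros(3, 2), torch.full((3, 2), -1.0), torch.zeros(3, 2))
```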

static normalize(target, data=None)

Normalize data between 0 and 1

Parameters:
  • target (torch.tensor) – target data

  • data (torch.tensor) – output data (optional)

Returns:

normalized data

Return type:

list
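A plausible min-max reading of this method, assuming the optional `data` is rescaled with the target's statistics (that sharing is an assumption, not confirmed by the source):

```python
import torch

def normalize(target, data=None):
    # Min-max scaling into [0, 1]; `data`, when given, is scaled
    # using the target's min and max.
    t_min, t_max = target.min(), target.max()
    scaled = (target - t_min) / (t_max - t_min)
    if data is None:
        return [scaled]
    return [scaled, (data - t_min) / (t_max - t_min)]

out = normalize(torch.tensor([0.0, 5.0, 10.0]))
```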

recon_loss_fn(output, target, K=1)

Calculate reconstruction loss

Parameters:
  • output (torch.tensor) – Output data, torch.dist or list

  • target (torch.tensor) – Target data

  • K (int) – K samples from posterior distribution

Returns:

computed loss

Return type:

torch.tensor

static reshape_for_loss(output, target, K=1)

Reshapes output and target to calculate reconstruction loss

Parameters:
  • output (torch.dist) – output likelihood

  • target (torch.dist) – target likelihood

  • ltype (str) – reconstruction loss

  • K (int) – K samples from posterior distribution

Returns:

reshaped data

Return type:

tuple(torch.tensor, torch.tensor, str)

set_ltype(ltype)

Sets the reconstruction loss type and checks the objective setup through a set of asserts

weighted_group_kld(latent_dists, model, weights)

Calculates the weighted group KL-divergence.

Parameters:
  • latent_dists (list) – list of the two distributions

  • model (object) – model object

  • weights (torch.Tensor) – tensor with weights for each distribution

Returns:

group divergence, list of klds

Return type:

tuple
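One way such a weighted group divergence could look, sketched against a standard-normal prior (the per-distribution sum and the weighting scheme are assumptions):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Two modality posteriors against a shared standard-normal prior
prior = Normal(torch.zeros(2), torch.ones(2))
dists = [Normal(torch.ones(2), torch.ones(2)),
         Normal(2 * torch.ones(2), torch.ones(2))]
weights = torch.tensor([0.25, 0.75])

# Per-distribution KLDs, then their weighted combination
klds = [kl_divergence(d, prior).sum() for d in dists]
group_kld = sum(w * k for w, k in zip(weights, klds))
```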

class multimodal_compare.models.objectives.MultimodalObjective(obj: str, beta=1)

Common class for multimodal objectives

_m_dreg_looser(lpx_zs, pz, pz_params, zss, qz_xs, K=1)

DREG estimate for log p_θ(x) for a multimodal VAE, fully vectorised. Source: https://github.com/iffsid/mmvae. This version is the looser bound, with the average over modalities outside the log.

calculate_loss(data)

Calculates the loss using self.objective

Parameters:

data (dict) – dictionary with the required data for loss calculation

Returns:

calculated losses

Return type:

dict

dreg(data)

Computes the DREG estimate for log p_θ(x) for a multimodal VAE. Source: https://github.com/iffsid/mmvae. This version is the looser bound, with the average over modalities outside the log.
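The core of the DREG trick (Tucker et al., 2018) is that the log-weights are computed with detached posterior parameters, and the normalized weights are detached again before reweighting, so gradients flow only through the reparameterized samples. A minimal single-modality sketch with the reconstruction term omitted for brevity:

```python
import torch
from torch.distributions import Normal

mu = torch.zeros(2, requires_grad=True)
qz_x = Normal(mu, torch.ones(2))
zs = qz_x.rsample((5,))                       # K=5 reparameterized samples

qz_x_stop = Normal(mu.detach(), torch.ones(2))
lw = (Normal(0.0, 1.0).log_prob(zs)           # log p(z)
      - qz_x_stop.log_prob(zs)).sum(-1)       # - log q(z|x), params detached
wk = torch.softmax(lw, dim=0).detach()        # normalized weights, no grad
loss = -(wk * lw).sum()
loss.backward()                               # gradients reach mu only via zs
```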

elbo(data)

Computes multimodal ELBO E_{p(x)}[ELBO]

Parameters:

data – dict with the keys: lpx_z (recon losses) and kld (kl divergences)

Returns:

dict with loss, kl divergence, reconstruction loss and kld

Return type:

dict

iwae(data)

Computes multimodal IWAE

Parameters:

data – dict with the keys: lpx_z (recon losses) and kld (kl divergences)

Returns:

dict with loss, kl divergence, reconstruction loss and kld

Return type:

dict

class multimodal_compare.models.objectives.ReconLoss

Class that stores reconstruction loss functions

static bce(output, target, bs)

Binary Cross-Entropy loss

Parameters:
  • output (torch.distributions) – model output distribution

  • target (torch.tensor) – ground truth tensor

  • bs (int) – batch size

Returns:

calculated loss

Return type:

torch.Tensor.float
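A sketch of such a per-batch BCE, assuming logits as input and a sum-then-divide-by-batch-size reduction (the library may reduce differently):

```python
import torch
import torch.nn.functional as F

def bce(output_logits, target, bs):
    # Binary cross-entropy summed over all elements, then averaged
    # over the batch via the explicit batch size.
    return F.binary_cross_entropy_with_logits(
        output_logits, target, reduction="sum") / bs

logits = torch.zeros(4, 10)          # sigmoid(0) = 0.5 everywhere
target = torch.ones(4, 10)
loss = bce(logits, target, bs=4)
```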

static category_ce(output, target, bs)

Categorical Cross-Entropy loss (for classification problems such as text)

Parameters:
  • output (torch.distributions) – model output distribution

  • target (torch.tensor) – ground truth tensor

  • bs (int) – batch size

Returns:

calculated loss

Return type:

torch.Tensor.float

static gaussian_nll(output, target, bs)

Calculates the Gaussian NLL with optimal sigma, as in Sigma VAE: https://github.com/orybkin/sigma-vae-pytorch
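The optimal-sigma idea is that the maximum-likelihood estimate of a shared variance is simply the MSE, which is plugged back into the Gaussian likelihood. A sketch under that reading (the clamp floor and reduction are assumptions):

```python
import math
import torch

def gaussian_nll_optimal_sigma(mu, target, bs):
    # MLE of a shared variance is the MSE; substitute it into the NLL.
    sigma2 = ((mu - target) ** 2).mean().clamp(min=1e-6)
    nll = 0.5 * (torch.log(sigma2) + (mu - target) ** 2 / sigma2
                 + math.log(2 * math.pi))
    return nll.sum() / bs

mu = torch.zeros(2, 4)
target = torch.ones(2, 4)
loss = gaussian_nll_optimal_sigma(mu, target, bs=2)
```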

static l1(output, target, bs)

L1 loss

Parameters:
  • output (torch.distributions) – model output distribution

  • target (torch.tensor) – ground truth tensor

  • bs (int) – batch size

Returns:

calculated loss

Return type:

torch.Tensor.float

static lprob(output, target, bs)

Log-likelihood loss

Parameters:
  • output (torch.distributions) – model output distribution

  • target (torch.tensor) – ground truth tensor

  • bs (int) – batch size

Returns:

calculated loss

Return type:

torch.Tensor.float
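A plausible sketch of a log-likelihood loss over a `torch.distributions` output, assuming a sum over dimensions and division by the batch size:

```python
import torch
from torch.distributions import Normal

def lprob(output_dist, target, bs):
    # Negative log-likelihood of the target under the output
    # distribution, summed and averaged over the batch.
    return -output_dist.log_prob(target).sum() / bs

dist = Normal(torch.zeros(2, 3), torch.ones(2, 3))
loss = lprob(dist, torch.zeros(2, 3), bs=2)
```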

static mse(output, target, bs)

Mean squared error (squared L2 norm) loss

Parameters:
  • output (torch.distributions) – model output distribution

  • target (torch.tensor) – ground truth tensor

  • bs (int) – batch size

Returns:

calculated loss

Return type:

torch.Tensor.float

class multimodal_compare.models.objectives.UnimodalObjective(obj: str, beta=1)

Common class for unimodal objectives (used in unimodal VAEs only)

calculate_loss(px_z, target, qz_x, prior_dist, pz_params, zs, K=1)

Calculates the loss using self.objective

Parameters:
  • px_z (torch.distributions) – decoder distribution

  • target (torch.tensor) – ground truth

  • qz_x (torch.distribution) – posterior distribution

  • prior_dist (torch.distribution) – model’s prior

  • pz_params (torch.tensor) – parameters of the prior distribution

  • zs (torch.tensor) – latent samples

  • K (int) – how many samples were drawn from the posterior

Returns:

calculated losses

Return type:

dict

dreg(data)

DREG estimate for log p_θ(x), fully vectorised. Source: https://github.com/iffsid/mmvae

Parameters:

data – dict with the keys: px_z, target, qz_x, zs, K, prior_dist

Returns:

dict with loss, reconstruction loss and kld

Return type:

dict

elbo(data)

Computes unimodal ELBO E_{p(x)}[ELBO]

Parameters:

data – dict with the keys: px_z, target, qz_x, prior_dist, K

Returns:

dict with loss, kl divergence, reconstruction loss and kld

Return type:

dict

iwae(data)

Computes an importance-weighted ELBO estimate for log p_θ(x). Source: https://github.com/iffsid/mmvae

Parameters:

data – dict with the keys: px_z, target, qz_x, zs, K

Returns:

dict with loss, reconstruction loss and kld

Return type:

dict