Objectives
- class multimodal_compare.models.objectives.BaseObjective
Base objective class shared for all loss functions
- calc_kld(dist1, dist2)
Calculates the KL divergence between two distributions (see the sketch below).
- Parameters:
dist1 (torch.dist) – distribution 1
dist2 (torch.dist) – distribution 2
- Returns:
KL divergence
- Return type:
torch.tensor
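For orientation, a minimal sketch of the same computation with plain PyTorch; the shapes and the direction of the divergence (KL(dist2 ‖ dist1)) are illustrative assumptions, not the library's exact convention:

```python
import torch
from torch.distributions import Normal, kl_divergence

# Two diagonal Gaussians over a 16-dimensional latent space (illustrative shapes).
dist1 = Normal(torch.zeros(8, 16), torch.ones(8, 16))        # e.g. the prior
dist2 = Normal(torch.randn(8, 16), 0.5 * torch.ones(8, 16))  # e.g. a posterior

# Element-wise KL(dist2 || dist1); summing over the latent dimension yields
# one divergence value per batch element, the usual quantity inside an ELBO.
kld = kl_divergence(dist2, dist1).sum(-1)  # shape: (8,)
```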
- calc_klds(latent_dists, model)
Calculates the KL-divergence between each of the given latent distributions and the posterior distribution.
- Parameters:
latent_dists (list) – list of the two distributions
model (object) – model object
- Returns:
list of klds
- Return type:
list
- static compute_microbatch_split(x, K)
Checks whether the batch needs to be broken into smaller pieces to fit in memory (see the sketch below).
- Parameters:
x (torch.tensor) – input data
K (int) – K samples will be made from each distribution
- Returns:
microbatch split
- Return type:
torch.tensor
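A rough sketch of the idea, assuming a fixed element budget; the constant and the plain-int return value are assumptions, not the library's exact behaviour:

```python
import torch

def microbatch_split_sketch(x: torch.Tensor, K: int, budget: int = 10**8) -> int:
    """How many inputs fit in one microbatch if K latent samples are drawn per
    input and roughly `budget` scalar elements are allowed in memory at once."""
    per_item = K * x[0].numel()          # elements produced per input item
    split = max(1, budget // per_item)   # items per microbatch
    return min(split, x.size(0))         # never exceed the batch size

# Usage: iterate over torch.split(x, microbatch_split_sketch(x, K)) and accumulate the loss.
```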
- elbo(lpx_z, kld, beta=1)
The most general ELBO function (see the sketch below).
- Parameters:
lpx_z (torch.tensor) – reconstruction loss(es)
kld (torch.tensor) – KL divergence(s)
beta (torch.float) – disentangling factor
- Returns:
ELBO loss
- Return type:
torch.tensor
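A minimal sketch of a beta-weighted ELBO loss, assuming lpx_z and kld are already reduced per batch element and that the returned value is the negative bound to be minimised:

```python
import torch

def elbo_sketch(lpx_z: torch.Tensor, kld: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Negative evidence lower bound:
    -(E_q[log p(x|z)] - beta * KL(q(z|x) || p(z))), averaged over the batch."""
    return -(lpx_z - beta * kld).mean()
```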
- iwae(lp_z, lpx_z, lqz_x)
The most general IWAE function (see the sketch below).
- Parameters:
lp_z (torch.tensor) – log probability of latent samples coming from the prior
lpx_z (torch.tensor) – reconstruction loss(es)
lqz_x (torch.tensor) – log probability of latent samples coming from the learned posterior
- Returns:
IWAE loss
- Return type:
torch.tensor
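A minimal sketch of the importance-weighted bound, assuming all three log-terms have shape (K, batch) and that the loss is the negative estimate of log p(x):

```python
import math
import torch

def iwae_sketch(lp_z: torch.Tensor, lpx_z: torch.Tensor, lqz_x: torch.Tensor) -> torch.Tensor:
    """Importance-weighted estimate of log p(x) from K posterior samples."""
    lw = lp_z + lpx_z - lqz_x                                   # log importance weights
    log_px = torch.logsumexp(lw, dim=0) - math.log(lw.size(0))  # average over K in log space
    return -log_px.mean()                                       # negative bound as a loss
```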
- static normalize(target, data=None)
Normalizes data to the range [0, 1] (see the sketch below).
- Parameters:
target (torch.tensor) – target data
data (torch.tensor) – output data (optional)
- Returns:
normalized data
- Return type:
list
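A sketch of min-max normalization consistent with the listed signature; scaling the optional data with the statistics of target is an assumption:

```python
from typing import List, Optional

import torch

def normalize_sketch(target: torch.Tensor, data: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
    """Min-max normalization to [0, 1]; `data`, if given, is scaled with the
    statistics of `target` so the two stay comparable (an assumption here)."""
    t_min, t_max = target.min(), target.max()
    scale = (t_max - t_min).clamp_min(1e-8)   # avoid division by zero
    out = [(target - t_min) / scale]
    if data is not None:
        out.append((data - t_min) / scale)
    return out
```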
- recon_loss_fn(output, target, K=1)
Calculate reconstruction loss
- Parameters:
output (torch.tensor, torch.dist or list) – output data
target (torch.tensor) – Target data
K (int) – K samples from posterior distribution
- Returns:
computed loss
- Return type:
torch.tensor
- static reshape_for_loss(output, target, K=1)
Reshapes output and target to calculate reconstruction loss
- Parameters:
output (torch.dist) – output likelihood
target (torch.dist) – target likelihood
ltype (str) – reconstruction loss
K (int) – K samples from posterior distribution
- Returns:
reshaped data
- Return type:
tuple(torch.tensor, torch.tensor, str)
- set_ltype(ltype)
Checks the objective setup through a set of assertions.
- weighted_group_kld(latent_dists, model, weights)
Calculates the weighted group KL-divergence (see the sketch below).
- Parameters:
latent_dists (list) – list of the two distributions
model (object) – model object
weights (torch.Tensor) – tensor with weights for each distribution
- Returns:
group divergence, list of klds
- Return type:
tuple
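The exact weighting scheme is not spelled out here; the following is only a plausible sketch in which each modality's KL term against a shared prior is scaled by its weight (the prior argument and the reductions are assumptions):

```python
import torch
from torch.distributions import kl_divergence

def weighted_group_kld_sketch(latent_dists, prior, weights):
    """latent_dists: list of posterior distributions (one per modality);
    prior: shared prior distribution; weights: 1-D tensor of modality weights."""
    klds = [kl_divergence(q, prior).sum(-1) for q in latent_dists]  # per-modality KLDs
    group = sum(w * kl.mean() for w, kl in zip(weights, klds))      # weighted group term
    return group, klds
```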
- class multimodal_compare.models.objectives.MultimodalObjective(obj: str, beta=1)
Common class for multimodal objectives
- _m_dreg_looser(lpx_zs, pz, pz_params, zss, qz_xs, K=1)
DREG estimate of log p_θ(x) for a multimodal VAE, fully vectorised; source: https://github.com/iffsid/mmvae. This version is the looser bound, with the average over modalities outside the log.
- calculate_loss(data)
Calculates the loss using self.objective
- Parameters:
data (dict) – dictionary with the required data for loss calculation
- Returns:
calculated losses
- Return type:
dict
- dreg(data)
Computes the DREG estimate of log p_θ(x) for a multimodal VAE; source: https://github.com/iffsid/mmvae. This version is the looser bound, with the average over modalities outside the log.
- elbo(data)
Computes the multimodal ELBO E_{p(x)}[ELBO] (see the sketch below)
- Parameters:
data – dict with the keys: lpx_z (recon losses) and kld (kl divergences)
- Returns:
dict with loss, kl divergence, reconstruction loss and kld
- Return type:
dict
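A sketch of how the documented lpx_z and kld entries could be combined into a beta-weighted multimodal ELBO loss; the reductions and the returned key names are assumptions:

```python
import torch

def multimodal_elbo_sketch(data: dict, beta: float = 1.0) -> dict:
    """data["lpx_z"]: list of per-modality reconstruction log-likelihoods;
    data["kld"]: KL divergence term(s) of the joint posterior."""
    lpx_z = sum(l.sum() for l in data["lpx_z"])  # total reconstruction term over modalities
    kld = data["kld"].sum()
    loss = -(lpx_z - beta * kld)                 # negative ELBO as the training loss
    return {"loss": loss, "reconstruction_loss": -lpx_z, "kld": kld}
```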
- iwae(data)
Computes multimodal IWAE
- Parameters:
data – dict with the keys: lpx_z (recon losses) and kld (kl divergences)
- Returns:
dict with loss, kl divergence, reconstruction loss and kld
- Return type:
dict
- class multimodal_compare.models.objectives.ReconLoss
Class that stores the reconstruction loss functions; a combined usage sketch follows the individual entries below.
- static bce(output, target, bs)
Binary Cross-Entropy loss
- Parameters:
output (torch.distributions) – model output distribution
target (torch.tensor) – ground truth tensor
bs (int) – batch size
- Returns:
calculated loss
- Return type:
torch.Tensor.float
- static category_ce(output, target, bs)
Categorical Cross-Entropy loss (for classification problems such as text)
- Parameters:
output (torch.distributions) – model output distribution
target (torch.tensor) – ground truth tensor
bs (int) – batch size
- Returns:
calculated loss
- Return type:
torch.Tensor.float
- static gaussian_nll(output, target, bs)
Calculates the Gaussian NLL with the optimal sigma, as in Sigma VAE (https://github.com/orybkin/sigma-vae-pytorch); see the sketch below.
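A sketch of the Sigma VAE recipe: the shared variance is set to its maximum-likelihood value, the MSE between the reconstruction mean and the target (the reference implementation additionally soft-clips log-sigma, which is omitted here):

```python
import math
import torch

def gaussian_nll_optimal_sigma(mu: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Gaussian negative log-likelihood with the optimal (MLE) shared variance."""
    mse = ((target - mu) ** 2).mean()            # optimal sigma^2 is the batch MSE
    log_sigma = 0.5 * torch.log(mse)
    nll = 0.5 * ((target - mu) ** 2) / mse + log_sigma + 0.5 * math.log(2 * math.pi)
    return nll.sum()
```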
- static l1(output, target, bs)
L1 loss
- Parameters:
output (torch.distributions) – model output distribution
target (torch.tensor) – ground truth tensor
bs (int) – batch size
- Returns:
calculated loss
- Return type:
torch.Tensor.float
- static lprob(output, target, bs)
Log-likelihood loss
- Parameters:
output (torch.distributions) – model output distribution
target (torch.tensor) – ground truth tensor
bs (int) – batch size
- Returns:
calculated loss
- Return type:
torch.Tensor.float
- static mse(output, target, bs)
Mean squared error (squared L2 norm) loss
- Parameters:
output (torch.distributions) – model output distribution
target (torch.tensor) – ground truth tensor
bs (int) – batch size
- Returns:
calculated loss
- Return type:
torch.Tensor.float
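For orientation, a sketch of how the losses above roughly map to plain PyTorch, assuming the output distribution's mean is used as the point reconstruction and per-sample losses are summed over feature dimensions; the library's exact reductions may differ:

```python
import torch
import torch.nn.functional as F

def recon_losses_sketch(output, target: torch.Tensor, bs: int) -> dict:
    """output: a torch.distributions object; target: ground-truth tensor in [0, 1]
    (required for the BCE term); bs: batch size used to flatten the losses."""
    recon = output.mean.reshape(bs, -1)          # point estimate from the distribution
    flat_target = target.reshape(bs, -1)
    return {
        "bce": F.binary_cross_entropy(recon.clamp(0, 1), flat_target, reduction="none").sum(-1),
        "mse": F.mse_loss(recon, flat_target, reduction="none").sum(-1),
        "l1": F.l1_loss(recon, flat_target, reduction="none").sum(-1),
        "lprob": -output.log_prob(target).reshape(bs, -1).sum(-1),  # negative log-likelihood
    }
```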
- class multimodal_compare.models.objectives.UnimodalObjective(obj: str, beta=1)
Common class for unimodal objectives (used in unimodal VAEs only)
- calculate_loss(px_z, target, qz_x, prior_dist, pz_params, zs, K=1)
Calculates the loss using self.objective (see the usage sketch below)
- Parameters:
px_z (torch.distributions) – decoder distribution
target (torch.tensor) – ground truth
qz_x (torch.distribution) – posterior distribution
prior_dist (torch.distribution) – model’s prior
zs (torch.tensor) – latent samples
K (int) – how many samples were drawn from the posterior
- Returns:
calculated losses
- Return type:
dict
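A hedged usage sketch built only from the signatures listed here; the objective name "elbo", the loss type passed to set_ltype, and passing None for pz_params are assumptions made for illustration:

```python
import torch
from torch.distributions import Normal
from multimodal_compare.models.objectives import UnimodalObjective

objective = UnimodalObjective("elbo", beta=1.0)   # assumed objective name
objective.set_ltype("lprob")                      # assumed loss-type name, see ReconLoss above

target = torch.rand(32, 1, 28, 28)                            # ground-truth batch
qz_x = Normal(torch.randn(32, 16), torch.ones(32, 16))        # encoder posterior
prior = Normal(torch.zeros(32, 16), torch.ones(32, 16))       # model prior
zs = qz_x.rsample()                                           # latent samples
px_z = Normal(torch.rand(32, 1, 28, 28), 0.1 * torch.ones(32, 1, 28, 28))  # decoder dist.

losses = objective.calculate_loss(px_z, target, qz_x, prior, None, zs, K=1)  # dict of losses
```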
- dreg(data)
DREG estimate of log p_θ(x), fully vectorised; source: https://github.com/iffsid/mmvae (see the sketch below).
- Parameters:
data – dict with the keys: px_z, target, qz_x, zs, K, prior_dist
- Returns:
dict with loss, reconstruction loss and kld
- Return type:
dict
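A sketch of the doubly-reparameterised (DReG) trick following the cited mmvae code: the log-weights are formed with the posterior parameters detached, and the gradient flowing through the latent samples is re-weighted via a backward hook (shapes and the final reduction are assumptions):

```python
import torch

def dreg_sketch(lpx_z, lp_z, lqz_x_detached, zs):
    """All log-terms have shape (K, batch); lqz_x_detached must be evaluated with
    the posterior parameters detached (stop-gradient on the encoder)."""
    lw = lpx_z + lp_z - lqz_x_detached                          # log importance weights
    with torch.no_grad():
        grad_wt = (lw - torch.logsumexp(lw, 0, keepdim=True)).exp()
        if zs.requires_grad:
            # re-weight the gradient flowing back through the latent samples
            zs.register_hook(lambda grad: grad_wt.unsqueeze(-1) * grad)
    return -(grad_wt * lw).sum(0).mean()                        # negative bound as a loss
```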
- elbo(data)
Computes unimodal ELBO E_{p(x)}[ELBO]
- Parameters:
data – dict with the keys: px_z, target, qz_x, prior_dist, K
- Returns:
dict with loss, kl divergence, reconstruction loss and kld
- Return type:
dict
- iwae(data)
Computes an importance-weighted ELBO estimate of log p_θ(x); source: https://github.com/iffsid/mmvae
- Parameters:
data – dict with the keys: px_z, target, qz_x, zs, K
- Returns:
dict with loss, reconstruction loss and kld
- Return type:
dict