multimodal-vae-comparison
Paper Results
Supplementary material
Tutorials
Add a new model
Add a new dataset
Code documentation
MultimodalVAE class
Multimodal VAE Base Class
Multimodal VAE models
Encoders
Decoders
VAE class
Objectives
DataLoader
Dataset Classes
Inference module
Evaluate on CdSprites+ dataset
Config class
multimodal-vae-comparison
Index
Index
_ | A | B | C | D | E | F | G | I | L | M | N | O | P | Q | R | S | T | U | V | W
_
_backward_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_buffers (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_define_params() (multimodal_compare.models.config_cls.Config method)
_forward_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_forward_pre_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_get_mods_config() (multimodal_compare.models.config_cls.Config method)
_is_full_backward_hook (multimodal_compare.models.decoders.Dec_CNN attribute)
(multimodal_compare.models.decoders.Dec_FNN attribute)
(multimodal_compare.models.decoders.Dec_MNIST attribute)
(multimodal_compare.models.decoders.Dec_MNIST2 attribute)
(multimodal_compare.models.decoders.Dec_PolyMNIST attribute)
(multimodal_compare.models.decoders.Dec_SVHN attribute)
(multimodal_compare.models.decoders.Dec_SVHN2 attribute)
(multimodal_compare.models.decoders.Dec_Transformer attribute)
(multimodal_compare.models.decoders.Dec_TransformerIMG attribute)
(multimodal_compare.models.decoders.Dec_TxtTransformer attribute)
(multimodal_compare.models.decoders.Dec_VideoGPT attribute)
(multimodal_compare.models.decoders.VaeDecoder attribute)
(multimodal_compare.models.encoders.Enc_CNN attribute)
(multimodal_compare.models.encoders.Enc_FNN attribute)
(multimodal_compare.models.encoders.Enc_MNIST attribute)
(multimodal_compare.models.encoders.Enc_MNIST2 attribute)
(multimodal_compare.models.encoders.Enc_PolyMNIST attribute)
(multimodal_compare.models.encoders.Enc_SVHN attribute)
(multimodal_compare.models.encoders.Enc_SVHN2 attribute)
(multimodal_compare.models.encoders.Enc_Transformer attribute)
(multimodal_compare.models.encoders.Enc_TransformerIMG attribute)
(multimodal_compare.models.encoders.Enc_TxtTransformer attribute)
(multimodal_compare.models.encoders.Enc_VideoGPT attribute)
(multimodal_compare.models.encoders.VaeComponent attribute)
(multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.mmvae_base.TorchMMVAE attribute)
(multimodal_compare.models.mmvae_models.DMVAE attribute)
(multimodal_compare.models.mmvae_models.MOE attribute)
(multimodal_compare.models.mmvae_models.MoPOE attribute)
(multimodal_compare.models.mmvae_models.POE attribute)
(multimodal_compare.models.vae.BaseVae attribute)
(multimodal_compare.models.vae.VAE attribute)
_load_config() (multimodal_compare.models.config_cls.Config method)
_load_state_dict_post_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_load_state_dict_pre_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_m_dreg_looser() (multimodal_compare.models.objectives.MultimodalObjective method)
_mod_specific_loaders() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.CUB method)
(multimodal_compare.models.datasets.FASHIONMNIST method)
(multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.POLYMNIST method)
(multimodal_compare.models.datasets.SPRITES method)
_mod_specific_savers() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.CUB method)
(multimodal_compare.models.datasets.FASHIONMNIST method)
(multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.POLYMNIST method)
(multimodal_compare.models.datasets.SPRITES method)
_modules (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_non_persistent_buffers_set (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_parameters (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
_parse_args() (multimodal_compare.models.config_cls.Config method)
_postprocess() (multimodal_compare.models.datasets.BaseDataset method)
_postprocess_actions() (multimodal_compare.models.datasets.SPRITES method)
_postprocess_all2img() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.SPRITES method)
_postprocess_attributes() (multimodal_compare.models.datasets.SPRITES method)
_postprocess_atts() (multimodal_compare.models.datasets.CELEBA method)
_postprocess_frames() (multimodal_compare.models.datasets.SPRITES method)
_postprocess_image() (multimodal_compare.models.datasets.FASHIONMNIST method)
_postprocess_images() (multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CELEBA method)
_postprocess_label() (multimodal_compare.models.datasets.FASHIONMNIST method)
_postprocess_mnist() (multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.POLYMNIST method)
_postprocess_svhn() (multimodal_compare.models.datasets.MNIST_SVHN method)
_postprocess_text() (multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CUB method)
_preprocess() (multimodal_compare.models.datasets.BaseDataset method)
_preprocess_atts() (multimodal_compare.models.datasets.CELEBA method)
_preprocess_images() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.CUB method)
_preprocess_text() (multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CUB method)
_preprocess_text_onehot() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CUB method)
_process_image() (multimodal_compare.models.datasets.FASHIONMNIST method)
_process_label() (multimodal_compare.models.datasets.FASHIONMNIST method)
_process_mnist() (multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.POLYMNIST method)
_process_svhn() (multimodal_compare.models.datasets.MNIST_SVHN method)
_setup_savedir() (multimodal_compare.models.config_cls.Config method)
_state_dict_hooks (multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.vae.VAE attribute)
A
add_vaes() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
analyse_data() (multimodal_compare.models.trainer.MultimodalVAE method)
B
BaseDataset (class in multimodal_compare.models.datasets)
BaseObjective (class in multimodal_compare.models.objectives)
BaseVae (class in multimodal_compare.models.vae)
bce() (multimodal_compare.models.objectives.ReconLoss static method)
C
calc_kld() (multimodal_compare.models.objectives.BaseObjective method)
calc_klds() (multimodal_compare.models.objectives.BaseObjective method)
calculate_cross_coherency() (in module multimodal_compare.eval.eval_cdsprites)
calculate_joint_coherency() (in module multimodal_compare.eval.eval_cdsprites)
calculate_loss() (multimodal_compare.models.objectives.MultimodalObjective method)
(multimodal_compare.models.objectives.UnimodalObjective method)
category_ce() (multimodal_compare.models.objectives.ReconLoss static method)
CDSPRITESPLUS (class in multimodal_compare.models.datasets)
CELEBA (class in multimodal_compare.models.datasets)
change_seed() (multimodal_compare.models.config_cls.Config method)
check_config() (multimodal_compare.models.trainer.MultimodalVAE method)
check_cross_sample_correct() (in module multimodal_compare.eval.eval_cdsprites)
check_indices_present() (multimodal_compare.models.datasets.MNIST_SVHN method)
check_load_testdata() (multimodal_compare.models.dataloader.DataModule method)
check_testdata_avail() (multimodal_compare.models.dataloader.DataModule method)
collate_fn() (multimodal_compare.models.dataloader.DataModule method)
compute_microbatch_split() (multimodal_compare.models.objectives.BaseObjective static method)
Config (class in multimodal_compare.models.config_cls)
configure_optimizers() (multimodal_compare.models.trainer.MultimodalVAE method)
count_same_letters() (in module multimodal_compare.eval.eval_cdsprites)
CUB (class in multimodal_compare.models.datasets)
current_datatype() (multimodal_compare.models.datasets.BaseDataset method)
D
datamod (multimodal_compare.models.trainer.MultimodalVAE property)
DataModule (class in multimodal_compare.models.dataloader)
Dec_CNN (class in multimodal_compare.models.decoders)
Dec_FNN (class in multimodal_compare.models.decoders)
Dec_MNIST (class in multimodal_compare.models.decoders)
Dec_MNIST2 (class in multimodal_compare.models.decoders)
Dec_PolyMNIST (class in multimodal_compare.models.decoders)
Dec_SVHN (class in multimodal_compare.models.decoders)
Dec_SVHN2 (class in multimodal_compare.models.decoders)
Dec_Transformer (class in multimodal_compare.models.decoders)
Dec_TransformerIMG (class in multimodal_compare.models.decoders)
Dec_TxtTransformer (class in multimodal_compare.models.decoders)
Dec_VideoGPT (class in multimodal_compare.models.decoders)
decode() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
(multimodal_compare.models.vae.BaseVae method)
DencoderFactory (class in multimodal_compare.models.vae)
DMVAE (class in multimodal_compare.models.mmvae_models)
dreg() (multimodal_compare.models.objectives.MultimodalObjective method)
(multimodal_compare.models.objectives.UnimodalObjective method)
dump_config() (multimodal_compare.models.config_cls.Config method)
E
elbo() (multimodal_compare.models.objectives.BaseObjective method)
(multimodal_compare.models.objectives.MultimodalObjective method)
(multimodal_compare.models.objectives.UnimodalObjective method)
Enc_CNN (class in multimodal_compare.models.encoders)
Enc_FNN (class in multimodal_compare.models.encoders)
Enc_MNIST (class in multimodal_compare.models.encoders)
Enc_MNIST2 (class in multimodal_compare.models.encoders)
Enc_PolyMNIST (class in multimodal_compare.models.encoders)
Enc_SVHN (class in multimodal_compare.models.encoders)
Enc_SVHN2 (class in multimodal_compare.models.encoders)
Enc_Transformer (class in multimodal_compare.models.encoders)
Enc_TransformerIMG (class in multimodal_compare.models.encoders)
Enc_TxtTransformer (class in multimodal_compare.models.encoders)
Enc_VideoGPT (class in multimodal_compare.models.encoders)
encode() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
(multimodal_compare.models.vae.BaseVae method)
eval_all() (in module multimodal_compare.eval.eval_cdsprites)
eval_cdsprites_over_seeds() (in module multimodal_compare.eval.eval_cdsprites)
eval_forward() (multimodal_compare.models.trainer.MultimodalVAE method)
eval_single_model() (in module multimodal_compare.eval.eval_cdsprites)
eval_statistics() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
eval_statistics_fn() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.SPRITES method)
eval_with_classifier() (in module multimodal_compare.eval.eval_cdsprites)
extra_hidden_layer() (in module multimodal_compare.models.decoders)
(in module multimodal_compare.models.encoders)
F
FASHIONMNIST (class in multimodal_compare.models.datasets)
feature_dims (multimodal_compare.models.datasets.CDSPRITESPLUS attribute)
(multimodal_compare.models.datasets.CELEBA attribute)
(multimodal_compare.models.datasets.CUB attribute)
(multimodal_compare.models.datasets.FASHIONMNIST attribute)
(multimodal_compare.models.datasets.MNIST_SVHN attribute)
(multimodal_compare.models.datasets.POLYMNIST attribute)
(multimodal_compare.models.datasets.SPRITES attribute)
fill_cats() (in module multimodal_compare.eval.eval_cdsprites)
find_in_list() (in module multimodal_compare.eval.eval_cdsprites)
find_version() (multimodal_compare.models.config_cls.Config method)
forward() (multimodal_compare.models.decoders.Dec_CNN method)
(multimodal_compare.models.decoders.Dec_FNN method)
(multimodal_compare.models.decoders.Dec_MNIST method)
(multimodal_compare.models.decoders.Dec_MNIST2 method)
(multimodal_compare.models.decoders.Dec_PolyMNIST method)
(multimodal_compare.models.decoders.Dec_SVHN method)
(multimodal_compare.models.decoders.Dec_SVHN2 method)
(multimodal_compare.models.decoders.Dec_Transformer method)
(multimodal_compare.models.decoders.Dec_TransformerIMG method)
(multimodal_compare.models.decoders.Dec_TxtTransformer method)
(multimodal_compare.models.decoders.Dec_VideoGPT method)
(multimodal_compare.models.encoders.Enc_CNN method)
(multimodal_compare.models.encoders.Enc_FNN method)
(multimodal_compare.models.encoders.Enc_MNIST method)
(multimodal_compare.models.encoders.Enc_MNIST2 method)
(multimodal_compare.models.encoders.Enc_PolyMNIST method)
(multimodal_compare.models.encoders.Enc_SVHN method)
(multimodal_compare.models.encoders.Enc_SVHN2 method)
(multimodal_compare.models.encoders.Enc_Transformer method)
(multimodal_compare.models.encoders.Enc_TransformerIMG method)
(multimodal_compare.models.encoders.Enc_TxtTransformer method)
(multimodal_compare.models.encoders.Enc_VideoGPT method)
(multimodal_compare.models.encoders.VaeComponent method)
(multimodal_compare.models.mmvae_base.TorchMMVAE method)
(multimodal_compare.models.mmvae_models.DMVAE method)
(multimodal_compare.models.mmvae_models.MOE method)
(multimodal_compare.models.mmvae_models.MoPOE method)
(multimodal_compare.models.mmvae_models.POE method)
(multimodal_compare.models.vae.BaseVae method)
G
gaussian_nll() (multimodal_compare.models.objectives.ReconLoss static method)
generate_samples() (multimodal_compare.models.vae.VAE method)
get_actions() (multimodal_compare.models.datasets.SPRITES method)
get_all_classifiers() (in module multimodal_compare.eval.eval_cdsprites)
get_attribute() (in module multimodal_compare.eval.eval_cdsprites)
get_attribute_from_recon() (in module multimodal_compare.eval.eval_cdsprites)
get_attributes() (multimodal_compare.models.datasets.SPRITES method)
get_base_path() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
get_config() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
get_data() (multimodal_compare.models.datasets.BaseDataset method)
get_data_raw() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.FASHIONMNIST method)
get_datamodule() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
get_dataset_class() (multimodal_compare.models.dataloader.DataModule method)
get_frames() (multimodal_compare.models.datasets.SPRITES method)
get_label_for_indices() (multimodal_compare.models.dataloader.DataModule method)
get_labels() (multimodal_compare.models.dataloader.DataModule method)
(multimodal_compare.models.datasets.BaseDataset method)
get_mean_stats() (in module multimodal_compare.eval.eval_cdsprites)
get_missing_modalities() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
get_mod_mappings() (in module multimodal_compare.eval.eval_cdsprites)
get_mod_names() (multimodal_compare.models.trainer.MultimodalVAE method)
get_model() (multimodal_compare.models.trainer.MultimodalVAE method)
get_nework_classes() (multimodal_compare.models.vae.DencoderFactory class method)
get_num_samples() (multimodal_compare.models.dataloader.DataModule method)
get_processed_recons() (multimodal_compare.models.datasets.BaseDataset method)
get_remaining_mods_data() (multimodal_compare.models.mmvae_models.DMVAE method)
get_test_data() (multimodal_compare.models.datasets.BaseDataset method)
get_vis_dir() (multimodal_compare.models.config_cls.Config method)
get_wrapped_model() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
I
image_to_text() (in module multimodal_compare.eval.eval_cdsprites)
iter_over_inputs() (multimodal_compare.models.datasets.SPRITES method)
iwae() (multimodal_compare.models.objectives.BaseObjective method)
(multimodal_compare.models.objectives.MultimodalObjective method)
(multimodal_compare.models.objectives.UnimodalObjective method)
L
l1() (multimodal_compare.models.objectives.ReconLoss static method)
labels() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CUB method)
(multimodal_compare.models.datasets.FASHIONMNIST method)
(multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.SPRITES method)
latent_factorization (multimodal_compare.models.mmvae_base.TorchMMVAE property)
load_classifier() (in module multimodal_compare.eval.eval_cdsprites)
load_images() (in module multimodal_compare.eval.eval_cdsprites)
logsumexp() (multimodal_compare.models.mmvae_models.DMVAE method)
lprob() (multimodal_compare.models.objectives.ReconLoss static method)
M
make_dataloaders() (multimodal_compare.eval.infer.MultimodalVAEInfer method)
make_masks() (multimodal_compare.models.dataloader.DataModule method)
(multimodal_compare.models.datasets.SPRITES method)
make_output_dict() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
manhattan_distance() (in module multimodal_compare.eval.eval_cdsprites)
mixture_component_selection() (multimodal_compare.models.mmvae_models.MoPOE method)
MNIST_SVHN (class in multimodal_compare.models.datasets)
modality_mixing() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
(multimodal_compare.models.mmvae_models.MoPOE method)
(multimodal_compare.models.mmvae_models.POE method)
module
multimodal_compare.eval.eval_cdsprites
multimodal_compare.eval.infer
multimodal_compare.models.config_cls
multimodal_compare.models.dataloader
multimodal_compare.models.datasets
multimodal_compare.models.decoders
multimodal_compare.models.encoders
multimodal_compare.models.mmvae_base
multimodal_compare.models.mmvae_models
multimodal_compare.models.objectives
multimodal_compare.models.trainer
multimodal_compare.models.vae
MOE (class in multimodal_compare.models.mmvae_models)
moe_fusion() (multimodal_compare.models.mmvae_models.MoPOE method)
MoPOE (class in multimodal_compare.models.mmvae_models)
mse() (multimodal_compare.models.objectives.ReconLoss static method)
multimodal_compare.eval.eval_cdsprites
module
multimodal_compare.eval.infer
module
multimodal_compare.models.config_cls
module
multimodal_compare.models.dataloader
module
multimodal_compare.models.datasets
module
multimodal_compare.models.decoders
module
multimodal_compare.models.encoders
module
multimodal_compare.models.mmvae_base
module
multimodal_compare.models.mmvae_models
module
multimodal_compare.models.objectives
module
multimodal_compare.models.trainer
module
multimodal_compare.models.vae
module
MultimodalObjective (class in multimodal_compare.models.objectives)
MultimodalVAE (class in multimodal_compare.models.trainer)
MultimodalVAEInfer (class in multimodal_compare.eval.infer)
N
normalize() (multimodal_compare.models.objectives.BaseObjective static method)
O
objective() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
(multimodal_compare.models.mmvae_models.DMVAE method)
(multimodal_compare.models.mmvae_models.MOE method)
(multimodal_compare.models.mmvae_models.MoPOE method)
(multimodal_compare.models.mmvae_models.POE method)
(multimodal_compare.models.vae.VAE method)
P
parse_params() (multimodal_compare.models.config_cls.Config method)
POE (class in multimodal_compare.models.mmvae_models)
poe_fusion() (multimodal_compare.models.mmvae_models.MoPOE method)
POLYMNIST (class in multimodal_compare.models.datasets)
predict_dataloader() (multimodal_compare.models.dataloader.DataModule method)
prepare_data_classes() (multimodal_compare.models.dataloader.DataModule method)
prepare_singlemodal() (multimodal_compare.models.dataloader.DataModule method)
prior_expert() (multimodal_compare.models.mmvae_models.POE method)
product_of_experts() (multimodal_compare.models.mmvae_base.TorchMMVAE static method)
pz_params (multimodal_compare.models.mmvae_base.TorchMMVAE property)
(multimodal_compare.models.mmvae_models.DMVAE property)
(multimodal_compare.models.mmvae_models.MOE property)
(multimodal_compare.models.mmvae_models.MoPOE property)
(multimodal_compare.models.mmvae_models.POE property)
(multimodal_compare.models.vae.VAE property)
pz_params_private (multimodal_compare.models.vae.VAE property)
Q
qz_x_params (multimodal_compare.models.vae.VAE property)
R
recon_loss_fn() (multimodal_compare.models.objectives.BaseObjective method)
ReconLoss (class in multimodal_compare.models.objectives)
reconstruct() (multimodal_compare.models.mmvae_models.MOE method)
reparameterize() (multimodal_compare.models.mmvae_models.MoPOE method)
reshape_for_loss() (multimodal_compare.models.objectives.BaseObjective static method)
reweight_weights() (multimodal_compare.models.mmvae_models.MoPOE method)
S
save_joint_samples() (multimodal_compare.models.trainer.MultimodalVAE method)
save_recons() (multimodal_compare.models.datasets.CDSPRITESPLUS method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.CUB method)
(multimodal_compare.models.datasets.FASHIONMNIST method)
(multimodal_compare.models.datasets.MNIST_SVHN method)
(multimodal_compare.models.datasets.POLYMNIST method)
(multimodal_compare.models.datasets.SPRITES method)
save_reconstructions() (multimodal_compare.models.trainer.MultimodalVAE method)
save_traversals() (multimodal_compare.models.datasets.BaseDataset method)
(multimodal_compare.models.datasets.CELEBA method)
(multimodal_compare.models.datasets.POLYMNIST method)
(multimodal_compare.models.datasets.SPRITES method)
search_att() (in module multimodal_compare.eval.eval_cdsprites)
set_likelihood_scales() (multimodal_compare.models.mmvae_base.TorchMMVAE method)
set_ltype() (multimodal_compare.models.objectives.BaseObjective method)
set_objective_fn() (multimodal_compare.models.vae.VAE method)
set_subsets() (multimodal_compare.models.mmvae_models.MoPOE method)
set_vis_image_shape() (multimodal_compare.models.datasets.CDSPRITESPLUS method)
setup() (multimodal_compare.models.dataloader.DataModule method)
SPRITES (class in multimodal_compare.models.datasets)
T
test_dataloader() (multimodal_compare.models.dataloader.DataModule method)
test_epoch_end() (multimodal_compare.models.trainer.MultimodalVAE method)
test_step() (multimodal_compare.models.trainer.MultimodalVAE method)
text_to_image() (in module multimodal_compare.eval.eval_cdsprites)
TorchMMVAE (class in multimodal_compare.models.mmvae_base)
train_dataloader() (multimodal_compare.models.dataloader.DataModule method)
training (multimodal_compare.models.decoders.Dec_CNN attribute)
(multimodal_compare.models.decoders.Dec_FNN attribute)
(multimodal_compare.models.decoders.Dec_MNIST attribute)
(multimodal_compare.models.decoders.Dec_MNIST2 attribute)
(multimodal_compare.models.decoders.Dec_PolyMNIST attribute)
(multimodal_compare.models.decoders.Dec_SVHN attribute)
(multimodal_compare.models.decoders.Dec_SVHN2 attribute)
(multimodal_compare.models.decoders.Dec_Transformer attribute)
(multimodal_compare.models.decoders.Dec_TransformerIMG attribute)
(multimodal_compare.models.decoders.Dec_TxtTransformer attribute)
(multimodal_compare.models.decoders.Dec_VideoGPT attribute)
(multimodal_compare.models.decoders.VaeDecoder attribute)
(multimodal_compare.models.encoders.Enc_CNN attribute)
(multimodal_compare.models.encoders.Enc_FNN attribute)
(multimodal_compare.models.encoders.Enc_MNIST attribute)
(multimodal_compare.models.encoders.Enc_MNIST2 attribute)
(multimodal_compare.models.encoders.Enc_PolyMNIST attribute)
(multimodal_compare.models.encoders.Enc_SVHN attribute)
(multimodal_compare.models.encoders.Enc_SVHN2 attribute)
(multimodal_compare.models.encoders.Enc_Transformer attribute)
(multimodal_compare.models.encoders.Enc_TransformerIMG attribute)
(multimodal_compare.models.encoders.Enc_TxtTransformer attribute)
(multimodal_compare.models.encoders.Enc_VideoGPT attribute)
(multimodal_compare.models.encoders.VaeComponent attribute)
(multimodal_compare.models.encoders.VaeEncoder attribute)
(multimodal_compare.models.mmvae_base.TorchMMVAE attribute)
(multimodal_compare.models.mmvae_models.DMVAE attribute)
(multimodal_compare.models.mmvae_models.MOE attribute)
(multimodal_compare.models.mmvae_models.MoPOE attribute)
(multimodal_compare.models.mmvae_models.POE attribute)
(multimodal_compare.models.vae.BaseVae attribute)
(multimodal_compare.models.vae.VAE attribute)
training_step() (multimodal_compare.models.trainer.MultimodalVAE method)
try_retrieve_atts() (in module multimodal_compare.eval.eval_cdsprites)
U
UnimodalObjective (class in multimodal_compare.models.objectives)
V
VAE (class in multimodal_compare.models.vae)
VaeComponent (class in multimodal_compare.models.encoders)
VaeDecoder (class in multimodal_compare.models.decoders)
VaeEncoder (class in multimodal_compare.models.encoders)
val_dataloader() (multimodal_compare.models.dataloader.DataModule method)
validation_epoch_end() (multimodal_compare.models.trainer.MultimodalVAE method)
validation_step() (multimodal_compare.models.trainer.MultimodalVAE method)
W
weighted_group_kld() (multimodal_compare.models.objectives.BaseObjective method)