Add a new dataset

By default, we support the proposed CdSprites+ dataset as well as MNIST-SVHN, CelebA, SPRITES, PolyMNIST, FashionMNIST and the Caltech-UCSD Birds (CUB) dataset. Here we describe how to train the models on your own data.

Supported data formats, config

In general, the preferred data formats (supported by default) are:

  • pickle (.pkl)

  • the PyTorch format (.pth, .pt)

  • the NumPy format (.npy)

  • the HDF5 format (.h5)

  • a directory containing .png or .jpg images
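As a rough illustration, the following sketch writes a toy image array to .npy and a list of captions to .pkl (two of the formats listed above) and reads them back. The file names images.npy and captions.pkl are placeholders, and the array layout your modality actually needs depends on your dataset class.

```python
import os
import pickle
import tempfile

import numpy as np


def save_toy_dataset(out_dir, images, captions):
    """Save images as .npy and captions as .pkl (hypothetical file names)."""
    img_path = os.path.join(out_dir, "images.npy")
    txt_path = os.path.join(out_dir, "captions.pkl")
    np.save(img_path, images)
    with open(txt_path, "wb") as f:
        pickle.dump(captions, f)
    return img_path, txt_path


out_dir = tempfile.mkdtemp()
images = np.zeros((2, 64, 64, 3), dtype=np.uint8)  # two blank 64x64 RGB images
captions = ["a small red bird", "a bird with a yellow belly"]
img_path, txt_path = save_toy_dataset(out_dir, images, captions)

# read both files back to check the round trip
loaded_images = np.load(img_path)
with open(txt_path, "rb") as f:
    loaded_captions = pickle.load(f)
print(loaded_images.shape, len(loaded_captions))
```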

To train with any of these, specify the path to your data in the config.yml:

batch_size: 16
epochs: 600
exp_name: cub
labels:
beta: 1
lr: 1e-3
mixing: moe
n_latents: 16
obj: elbo
optimizer: adam
pre_trained: null
seed: 2
viz_freq: 1
test_split: 0.1
dataset_name: cub
modality_1:
  decoder: CNN
  encoder: CNN
  mod_type: image
  recon_loss: bce
  path: ./data/cub/images
modality_2:
  decoder: TxtTransformer
  encoder: TxtTransformer
  mod_type: text
  recon_loss: category_ce
  path: ./data/cub/cub_captions.pkl

This is an example config file for the CUB dataset (see our README for download instructions).

As you can see, we specified the path to an image folder (./data/cub/images) and to the pickled captions (./data/cub/cub_captions.pkl). Both modalities are expected to be in the same order so that they can be matched into semantically corresponding pairs (e.g. the first image should correspond to the first caption).
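Because pairing is purely positional, one way to keep the modalities aligned is to key every sample by a shared ID and write both lists in the same sorted order. The sketch below assumes your raw data comes as two dictionaries keyed by such a hypothetical sample ID; if you use an image directory instead, you still need to make sure that the file order matches the caption order.

```python
def align_modalities(images_by_id, captions_by_id):
    """Return two lists ordered so that index i of each refers to the same sample."""
    # IDs that appear in only one modality would break the pairing
    unmatched = set(images_by_id) ^ set(captions_by_id)
    if unmatched:
        raise ValueError("IDs present in only one modality: {}".format(sorted(unmatched)))
    ids = sorted(images_by_id)
    return [images_by_id[k] for k in ids], [captions_by_id[k] for k in ids]


images_by_id = {"bird_002": "img_b", "bird_001": "img_a"}
captions_by_id = {"bird_001": "a red bird", "bird_002": "a blue bird"}
images, captions = align_modalities(images_by_id, captions_by_id)
print(images, captions)  # ['img_a', 'img_b'] ['a red bird', 'a blue bird']
```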

Adding a new dataset class

If you wish to train on your own data, you will need to create a custom dataset class in datasets.py. Any new dataset must inherit from BaseDataset so that it has the common methods used by the DataModule.

For CUB, we add it to datasets.py as follows:

 1 class CUB(BaseDataset):
 2     """Dataset class for our processed version of Caltech-UCSD birds dataset. We use the original images and text
 3     represented as sequences of one-hot-encodings for each character (incl. spaces)"""
 4     feature_dims = {"image": [64, 64, 3],
 5                     "text": [246, 27, 1]
 6                     }  # these feature_dims are also used by the encoder and decoder networks
 7
 8     def __init__(self, pth, testpth, mod_type):
 9         super().__init__(pth, testpth, mod_type)
10         self.mod_type = mod_type
11         self.text2img_size = (64,380,3)
12
13     def _preprocess_text_onehot(self):
14         """
15         General function for loading text strings and preparing them as torch one-hot encodings
16
17         :return: torch with text encodings and masks
18         :rtype: torch.tensor
19         """
20         self.has_masks = True
21         self.categorical = True
22         data = [one_hot_encode(len(f), f) for f in self.get_data_raw()]
23         data = [torch.from_numpy(np.asarray(x)) for x in data]
24         masks = lengths_to_mask(torch.tensor(np.asarray([x.shape[0] for x in data]))).unsqueeze(-1)
25         data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0.0)
26         data_and_masks = torch.cat((data, masks), dim=-1)
27         return data_and_masks
28
29     def _postprocess_text(self, data):
30         if isinstance(data, dict):
31             masks = data["masks"]
32             data = data["data"]
33             text = output_onehot2text(data)
34             if masks is not None:
35                 masks = torch.count_nonzero(masks, dim=-1)
36                 text = [x[:masks[i]] for i, x in enumerate(text)]
37         else:
38             text = output_onehot2text(data)
39         for i, phrase in enumerate(text):
40             newphr = []
41             stringcount = 0
42             for w in phrase.split(" "):
43                 stringcount += len(w) + 1
44                 if stringcount > 40:
45                     newphr.append("\n")
46                     stringcount = len(w) + 1  # the word that triggered the break starts the new line
47                 newphr.append(w)
48             text[i] = (" ".join(newphr)).replace("\n  ", "\n ")
49         return text
50
51     def labels(self):
52         """
53         No labels available for CUB
54         """
55         return None
56
57     def _preprocess_text(self):
58         d = self.get_data_raw()
59         self.has_masks = True
60         self.categorical = True
61         data = [one_hot_encode(len(f), f) for f in d]
62         data = [torch.from_numpy(np.asarray(x)) for x in data]
63         masks = lengths_to_mask(torch.tensor(np.asarray([x.shape[0] for x in data]))).unsqueeze(-1)
64         data = torch.nn.utils.rnn.pad_sequence(data, batch_first=True, padding_value=0.0)
65         data_and_masks = torch.cat((data, masks), dim=-1)
66         return data_and_masks
67
68     def _preprocess_images(self):
69         d = self.get_data_raw().reshape(-1, *[self.feature_dims["image"][i] for i in [2,0,1]])
70         data = torch.tensor(d)
71         return data
72
73     def _mod_specific_loaders(self):
74         return {"image": self._preprocess_images, "text": self._preprocess_text}
75
76     def _mod_specific_savers(self):
77         return {"image": self._postprocess_images, "text": self._postprocess_text}
78
79     def save_recons(self, data, recons, path, mod_names):
80         output_processed = self._postprocess_all2img(recons)
81         outs = add_recon_title(output_processed, "output\n{}".format(self.mod_type), (0, 170, 0))
82         input_processed = []
83         for key, d in data.items():
84             output = self._mod_specific_savers()[mod_names[key]](d)
85             images = turn_text2image(output, img_size=self.text2img_size) if mod_names[key] == "text" \
86                 else np.reshape(output,(-1,*self.feature_dims["image"]))
87             images = add_recon_title(images, "input\n{}".format(mod_names[key]), (0, 0, 255))
88             input_processed.append(np.vstack(images))
89             input_processed.append(np.ones((np.vstack(images).shape[0], 2, 3))*125)
90         inputs = np.hstack(input_processed).astype("uint8")
91         final = np.hstack((inputs, np.vstack(outs).astype("uint8")))
92         cv2.imwrite(path, cv2.cvtColor(final, cv2.COLOR_BGR2RGB))

Even though the dataset is multimodal, a separate instance of the class is created for each modality. The constructor therefore takes three arguments: the path to the modality data (str), the path to the test data, if available (used for evaluation after training), and the modality type mod_type (str). The modality type is any string that you assign to a modality to distinguish it from the others: for CUB we chose “image” for images and “text” for text, for MNIST_SVHN we use “mnist” and “svhn”. You specify mod_type in the config. You also need to specify the expected shape of the data for each mod_type in the class attribute “feature_dims”. It is used by the dataset class to postprocess the data (e.g. reconstructions produced by the model), and by the encoder and decoder networks to adjust the sizes of their layers.
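The one-instance-per-modality pattern can be sketched with a self-contained toy class. ToyDataset is a hypothetical stand-in, not the toolkit's BaseDataset, and the test_path key is an assumption (the actual config key for test data is not shown here); the sketch only shows how the modality entries of a parsed config map onto the (pth, testpth, mod_type) constructor arguments and how the shared feature_dims attribute is looked up per mod_type.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset class; not the toolkit's BaseDataset."""

    # class attribute shared by all instances, keyed by mod_type
    feature_dims = {"image": [64, 64, 3], "text": [246, 27, 1]}

    def __init__(self, pth, testpth, mod_type):
        self.pth = pth
        self.testpth = testpth
        self.mod_type = mod_type

    def shape(self):
        return self.feature_dims[self.mod_type]


# a parsed config.yml, reduced to the keys used here
config = {
    "modality_1": {"mod_type": "image", "path": "./data/cub/images"},
    "modality_2": {"mod_type": "text", "path": "./data/cub/cub_captions.pkl"},
}

# one dataset instance per modality, all of the same class
datasets = [
    ToyDataset(m["path"], m.get("test_path"), m["mod_type"])  # "test_path" is a hypothetical key
    for key, m in sorted(config.items())
    if key.startswith("modality_")
]
print([(d.mod_type, d.shape()) for d in datasets])
# [('image', [64, 64, 3]), ('text', [246, 27, 1])]
```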

Next, you need methods that prepare each modality for training (here _preprocess_text and _preprocess_images). _preprocess_images is taken over from CdSprites+, since the images have the same format, so only _preprocess_text is specific to CUB. Data loading itself is handled automatically by BaseDataset, so these methods only reshape the data, convert it to tensors, etc. They must return tensors of the same length (number of samples) so that the modalities can be paired. Note: for sequential data (like the text here), we create boolean masks and concatenate them with the data along the last dimension. The collate function then handles them automatically.
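To make the mask layout concrete, here is a NumPy sketch of the same idea as the one-hot encoding, padding and torch.cat steps in _preprocess_text. The 27-symbol alphabet (a to z plus space) is an assumption for illustration, not necessarily the toolkit's vocabulary, and the mask is stored as 0/1 floats so that it can be concatenated with the one-hot data.

```python
import numpy as np

ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # 27 symbols, hypothetical vocabulary


def one_hot(text):
    """Encode a string as a (length, 27) one-hot matrix."""
    idx = [ALPHABET.index(c) for c in text.lower()]
    out = np.zeros((len(idx), len(ALPHABET)), dtype=np.float32)
    out[np.arange(len(idx)), idx] = 1.0
    return out


def pad_with_masks(strings):
    """Pad one-hot sequences to a common length and append a mask channel."""
    seqs = [one_hot(s) for s in strings]
    max_len = max(len(s) for s in seqs)
    batch = np.zeros((len(seqs), max_len, len(ALPHABET)), dtype=np.float32)
    masks = np.zeros((len(seqs), max_len, 1), dtype=np.float32)
    for i, s in enumerate(seqs):
        batch[i, :len(s)] = s
        masks[i, :len(s)] = 1.0  # 1 marks real characters, 0 marks padding
    # the mask becomes the last channel of the last dimension
    return np.concatenate((batch, masks), axis=-1)


data = pad_with_masks(["red bird", "a bird"])
print(data.shape)  # (2, 8, 28)
print(data[1, :, -1])  # [1. 1. 1. 1. 1. 1. 0. 0.]
```

Downstream code can then split the last channel off again to recover the mask for each sequence.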

We also need to map the data processing functions to the modality types, i.e. define _mod_specific_loaders() and _mod_specific_savers():

def _mod_specific_loaders(self):
    return {"image": self._preprocess_images, "text": self._preprocess_text}

def _mod_specific_savers(self):
    return {"image": self._postprocess_images, "text": self._postprocess_text}

Here we simply assign the methods above to the mod_type strings used in the config. Once this is done, the dataset class should be ready and you can launch training.
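The effect of this mapping can be sketched with a self-contained stand-in. ToyModalityDataset and its get_processed_data() method are hypothetical names used only for illustration; the sketch shows how the mod_type string passed to the constructor selects which preprocessing method runs.

```python
class ToyModalityDataset:
    """Hypothetical stand-in showing how mod_type selects the preprocessing method."""

    def __init__(self, raw, mod_type):
        self.raw = raw
        self.mod_type = mod_type

    def _preprocess_images(self):
        return [x / 255.0 for x in self.raw]  # scale pixel values to [0, 1]

    def _preprocess_text(self):
        return [s.lower() for s in self.raw]

    def _mod_specific_loaders(self):
        return {"image": self._preprocess_images, "text": self._preprocess_text}

    def get_processed_data(self):
        # an unknown mod_type raises KeyError, which surfaces config typos early
        return self._mod_specific_loaders()[self.mod_type]()


print(ToyModalityDataset(["A Red Bird"], "text").get_processed_data())  # ['a red bird']
print(ToyModalityDataset([255, 0], "image").get_processed_data())  # [1.0, 0.0]
```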

Finally, you can configure how the outputs are saved for visualization. This is data-dependent; the save_recons() method in the example places images and text next to each other in one image. The _postprocess_all2img() method renders the strings into images of size self.text2img_size (defined in __init__, see line 11).
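The layout logic of save_recons() boils down to stacking columns with np.hstack, which requires every column to have the same height. A minimal NumPy sketch of that step, with a gray separator (value 125, width 2) after each input column as in the example; compose_grid is a hypothetical helper, and writing the result to disk (e.g. with cv2.imwrite) is left out.

```python
import numpy as np


def compose_grid(input_columns, output_column, sep_width=2, sep_value=125):
    """Place input columns side by side, each followed by a gray separator, then the output."""
    parts = []
    for col in input_columns:
        parts.append(col)
        parts.append(np.ones((col.shape[0], sep_width, 3)) * sep_value)
    inputs = np.hstack(parts).astype("uint8")
    return np.hstack((inputs, output_column.astype("uint8")))


image_col = np.zeros((64, 64, 3), dtype=np.uint8)       # one 64x64x3 image
text_col = np.full((64, 380, 3), 255, dtype=np.uint8)   # text rendered at text2img_size (64, 380, 3)
output_col = np.zeros((64, 64, 3), dtype=np.uint8)
grid = compose_grid([image_col, text_col], output_col)
print(grid.shape)  # (64, 512, 3)
```

Here the 64-pixel height of text2img_size matches the 64x64 images, which is what lets the text and image columns be concatenated horizontally.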

Different data formats

If you want to train on an unsupported data format, you can file an issue on our GitHub repository. Alternatively, you can add support yourself; this only requires adjusting one function in utils.py:

def load_data(path):
    """
    Returns loaded data based on path suffix
    :param path: Path to data
    :type path: str
    :return: loaded data
    :rtype: object
    """
    # relative paths are resolved against get_root_folder()
    if path.startswith('.'):
        path = os.path.join(get_root_folder(), path)
    assert os.path.exists(path), "Path does not exist: {}".format(path)
    if os.path.isdir(path):
        return load_images(path)  # directory with .png/.jpg images
    suffix = pathlib.Path(path).suffix
    if suffix in [".pt", ".pth"]:
        return torch.load(path)
    if suffix == ".pkl":
        return load_pickle(path)
    if suffix == ".h5":
        return h5py.File(path, 'r')
    if suffix == ".npy":
        return np.load(path)
    # add a new branch above this line to support another format
    raise ValueError("Unrecognized dataset format. Supported types are: .pkl, .pt, .pth, .npy, .h5 or a directory with images")
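For instance, to accept plain-text captions (one per line) or JSON files, you could add branches like the ones below before the final raise. This is a standalone sketch: the .txt and .json handling and the function name load_extra_formats are assumptions, and whatever you return must match what your dataset's preprocessing methods expect from get_data_raw().

```python
import json
import os
import pathlib
import tempfile


def load_extra_formats(path):
    """Hypothetical extra branches you could add to load_data in utils.py."""
    suffix = pathlib.Path(path).suffix
    if suffix == ".json":
        with open(path) as f:
            return json.load(f)
    if suffix == ".txt":
        # one sample per line, e.g. one caption per line
        with open(path) as f:
            return [line.rstrip("\n") for line in f]
    raise ValueError("Unrecognized dataset format: {}".format(suffix))


tmp = tempfile.mkdtemp()
txt = os.path.join(tmp, "captions.txt")
with open(txt, "w") as f:
    f.write("a red bird\na blue bird\n")
print(load_extra_formats(txt))  # ['a red bird', 'a blue bird']
```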

Note that by default, we provide encoders and decoders for images (preferably at 32x32x3 or 64x64x3 resolution, or 28x28x1 for MNIST), text data (arbitrary strings, which we encode at the character level) and sequential data (e.g. actions, suitable for a Transformer network). If you add a new data structure or image resolution, you will also need to add or adjust the encoder and decoder networks; you can then specify them in the config file.