Machine Learning – Variational Autoencoder and Latent Space Dimensions

autoencoders, generative-models, machine-learning, neural-networks, normal-distribution

I've done some experiments to understand the influence of the dimension of the latent space in a VAE, and it seems that the higher the dimension, the harder it is to generate realistic images. I think I have an intuition about the reason, and I wanted to get your opinion or any other theoretical insight about it.

First, what I've noticed:

  • After training a deep convolutional VAE with a large latent space (8x8x1024) on MNIST, the reconstruction works very well. Moreover, when I give any sample $x$ to my encoder, the output mean $\mu(x)$ is close to 0 and the output std $\sigma(x)$ is close to 1. Both the reconstruction loss and the latent loss seem low.
  • However, if I give random samples from $\mathcal{N}(0,I)$ to my decoder, the output is some random white strokes on a black background (like MNIST samples, but not looking like digits).
  • If I give an image $x$ to my encoder, it will output a mean $\mu(x)$ (close to 0), and if I then give my decoder random samples from $\mathcal{N}(\mu(x),I)$, the outputs are images representing the same digit as the input (both realistic and different from the input). A rough sketch of these two sampling experiments follows the list.
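
Here is roughly what I did, as a minimal sketch; `encoder`, `decoder` and `x` below are hypothetical placeholder stand-ins for my trained VAE and a real MNIST image, not the actual code.

```python
import numpy as np

latent_shape = (8, 8, 1024)
rng = np.random.default_rng(0)

# Placeholder stand-ins (hypothetical): replace with the trained encoder/decoder
# and a real MNIST image to reproduce the observations above.
def encoder(x):                      # -> (mu, sigma), each of shape latent_shape
    return np.zeros(latent_shape), np.ones(latent_shape)

def decoder(z):                      # -> a 28x28 image
    return rng.random((28, 28))

x = rng.random((28, 28))             # placeholder for one MNIST digit

# Experiment 1: decode a sample from the prior N(0, I) -> random strokes, no digit
z_prior = rng.normal(loc=0.0, scale=1.0, size=latent_shape)
img_from_prior = decoder(z_prior)

# Experiment 2: decode a sample from N(mu(x), I) -> realistic digit of the same class
mu, sigma = encoder(x)
z_near_x = rng.normal(loc=mu, scale=1.0, size=latent_shape)
img_near_x = decoder(z_near_x)
```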

What I conclude is that:

  • The VAE has learned many Gaussian distributions of realistic images, whose centers are close to 0 but not exactly 0. Thus, the distribution of realistic images is a mixture of Gaussians $\mathcal{D} = \sum_{x \in \mathcal{X}} \alpha_x \mathcal{N}(\mu(x),I)$.
  • The practical support of $\mathcal{N}(0,I)$ does not overlap with the practical support of $\mathcal{D}$ (except on a set of measure zero). By practical support, I mean the region where most points are actually generated. For a high-dimensional Gaussian, this region is a thin spherical shell, like a soap bubble (a small numerical check follows this list).
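
The numerical check I mean (my own sketch): draw standard Gaussian samples in increasing dimension $d$ and look at their norms, which concentrate in a thin shell around $\sqrt{d}$.

```python
import numpy as np

rng = np.random.default_rng(0)

for d in (2, 64, 8 * 8 * 1024):           # the last value matches my 8x8x1024 latent space
    z = rng.standard_normal((200, d))      # 200 samples from N(0, I_d)
    norms = np.linalg.norm(z, axis=1)
    print(f"d={d:6d}  mean ||z|| = {norms.mean():8.2f}  "
          f"sqrt(d) = {np.sqrt(d):8.2f}  std of ||z|| = {norms.std():.2f}")
# The norm stays within a shell of roughly constant width (~1/sqrt(2)) around sqrt(d),
# so relative to its radius the "bubble" becomes extremely thin as d grows.
```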

So here is a visualization of what would happen with a high-dimensional latent space:

[Figure: high-dimensional Gaussians]

The red bubble would be the practical support of $\mathcal{N}(0,I)$, while the union of the black bubbles would be the practical support of $\mathcal{D}$. Only the black bubbles contain realistic images; the red bubble contains almost none. The higher the dimension, the thinner the bubbles and the smaller their overlap.

Is this intuition correct? Is there any other reason why high-dimensional latent spaces do not work correctly?

Best Answer

You seem to have misunderstood your architecture and are, quite simply, overfitting your data.

It looks like your interpretation of the latent space is that it represents a manifold of realistic-looking images. That is unlikely in the best case, and impossible if your decoder applies any non-trivial transformation (anything beyond, say, an affine map) to the sampled latent vectors.

Autoencoders (or rather their encoder component) are, in general, compression algorithms. This means that they approximate 'real' data with a smaller set of more abstract features.

For example, the string '33333333000000000669111222222' could be losslessly compressed by a very simplistic algorithm to '8:3/9:0/2:6/1:9/3:1/6:2' (occurrences:number, preserving order). If your criterion were the length of the text, the encoding is six characters shorter (23 vs. 29) - not a huge improvement, but an improvement nonetheless.
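
A minimal sketch of that toy run-length encoding (my own illustration, not a standard library routine):

```python
from itertools import groupby

def rle_encode(s: str) -> str:
    """Encode runs of repeated characters as 'count:char' pairs joined by '/'."""
    return "/".join(f"{len(list(group))}:{char}" for char, group in groupby(s))

original = "33333333000000000669111222222"
encoded = rle_encode(original)
print(encoded)                       # 8:3/9:0/2:6/1:9/3:1/6:2
print(len(original), len(encoded))   # 29 23  -> six characters shorter
```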

What happened is that we've introduced an abstract, higher-level feature - 'number of repetitions' - that helps us express the original data more tersely. You could compress the output further; for example, noticing that every even position is just a separator, you could encode the separators as single-bit padding rather than ASCII characters.

Autoencoders do exactly that, except they get to pick the features themselves, and variational autoencoders enforce that the final level of coding (at least) is fuzzy in a way that can be manipulated.

So what your model is doing is describing each input image using 8x8x1024 = 65,536 features. And in a variational autoencoder, each feature is effectively a sliding scale between two distinct versions of an attribute, e.g. male/female for faces, or wide/thin brushstrokes for MNIST digits.
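
To make the 'sliding scale' concrete, here is a rough latent-traversal sketch; `mu` and `decoder` are hypothetical stand-ins for a code produced by the trained encoder and for the trained decoder network.

```python
import numpy as np

latent_shape = (8, 8, 1024)
rng = np.random.default_rng(0)

# Hypothetical stand-ins; replace with the real encoder output and decoder network.
mu = rng.normal(size=latent_shape)
def decoder(z):
    return rng.random((28, 28))

dim_index = (0, 0, 0)                       # pick one of the 65,536 latent coordinates
frames = []
for t in np.linspace(-3.0, 3.0, num=9):     # slide that single coordinate across +/- 3 std
    z = mu.copy()
    z[dim_index] = t
    frames.append(decoder(z))
# With a small, well-regularized latent space, each frame changes one interpretable
# attribute; with 65,536 coordinates, most individual sliders do nothing visible.
```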

Can you think of just a hundred ways to describe the differences between two realistic pictures in a meaningful way? Possible, I suppose, but they'll get increasingly forced as you try to go on.

With so much room to spare, the optimizer can comfortably encode each distinct training image's features in a non-overlapping slice of the latent space rather than learning the features of the training data taken globally.

So, when you feed it a validation picture, its encoding lands somewhere between islands of locally applicable feature encodings, and the result is entirely incoherent.
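
One way to see those islands directly is to interpolate between the codes of two training images and decode the intermediate points; a sketch under the same kind of hypothetical stand-ins as in the traversal sketch above:

```python
import numpy as np

latent_shape = (8, 8, 1024)
rng = np.random.default_rng(0)

# Hypothetical stand-ins; in practice these are the trained VAE halves and two
# real training images x1, x2.
def encoder(x):
    return rng.normal(size=latent_shape), np.ones(latent_shape)
def decoder(z):
    return rng.random((28, 28))
x1, x2 = rng.random((28, 28)), rng.random((28, 28))

mu1, _ = encoder(x1)
mu2, _ = encoder(x2)

frames = []
for alpha in np.linspace(0.0, 1.0, num=11):
    z = (1.0 - alpha) * mu1 + alpha * mu2    # straight line between the two codes
    frames.append(decoder(z))
# If the optimizer really has carved the latent space into per-image islands, the
# endpoints decode fine but the intermediate frames degenerate into incoherent strokes.
```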
