Solved – Learning normal distribution with VAE

autoencoders, generative-models, normal distribution, unsupervised learning

I am trying to use a Variational Autoencoder to learn a multivariate normal distribution. I know that from a practical point of view this is pointless, since we can sample from the normal distribution directly. However, I wanted to try this before moving on to more interesting applications.

I am using the algorithm presented here. Instead of training on the MNIST data, I sample from NumPy's multivariate normal and map the result to $[0,1]$ using a sigmoid. I use this data to train the VAE.

After training, I sample from $\mathcal{N}(0,I)$ in the latent space and use the decoder to generate data. I expected that if I applied an inverse sigmoid to the generated data, I would get normally distributed data with the same mean and covariance as the dataset I used for training. I compare the training and the generated data using a scatter plot, and I cannot get them to match.
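For concreteness, the data preparation and the comparison described above could be sketched roughly as follows. This is a minimal NumPy sketch, not the asker's actual code: the mean, covariance, sample size, and the helper names `sigmoid` and `logit` are illustrative assumptions, and the training data stands in for the decoder output so that the snippet runs without a trained model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical target distribution; these values are illustrative only.
mean = np.array([1.0, -0.5])
cov = np.array([[1.0, 0.6],
                [0.6, 2.0]])

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def logit(p, eps=1e-7):
    # Inverse sigmoid; clip so that values at exactly 0 or 1 do not give inf.
    p = np.clip(p, eps, 1.0 - eps)
    return np.log(p) - np.log1p(-p)

# Training data: Gaussian samples squashed into (0, 1).
x_raw = rng.multivariate_normal(mean, cov, size=10_000)
x_train = sigmoid(x_raw)

# Stand-in for the decoder output. With a trained VAE this would be
# decoder(z) for z ~ N(0, I); here the training data is reused (assumption).
x_generated = x_train

# Map back to the original space and compare moments, not just scatter plots.
x_back = logit(x_generated)
emp_mean = x_back.mean(axis=0)
emp_cov = np.cov(x_back, rowvar=False)
print("mean:", emp_mean)
print("cov:\n", emp_cov)
```

Comparing the empirical mean and covariance numerically makes a mismatch easier to quantify than a scatter plot alone.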

My main question is whether this scheme can work (which means that I am just doing something wrong in the program) or if it is not correct to use VAEs this way at all (I just learnt about VAEs yesterday, so I am not that sure!).

Moreover, the linked implementation uses a Bernoulli distribution to compute the reconstruction loss. I tried that as well as a Gaussian:

reconstr_loss = 0.5 * tf.reduce_sum((self.x - self.x_reconstr_mean)*(self.x - self.x_reconstr_mean) + np.log(2*np.pi*self.sigma_hyper * self.sigma_hyper), 1)

where sigma_hyper is just a hyperparameter (I set it to 1). Neither approach worked.
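As a reference point, the Gaussian negative log-likelihood for a fixed standard deviation can be written in plain NumPy as below. This is a sketch under the assumption of an isotropic decoder distribution $N(\text{x\_mean}, \sigma^2 I)$; the function name `gaussian_nll` is made up for illustration. With `sigma = 1` it agrees with the expression above, while for any other value the squared error also has to be divided by `sigma ** 2`.

```python
import numpy as np

def gaussian_nll(x, x_mean, sigma=1.0):
    """Per-example negative log-likelihood of x under N(x_mean, sigma^2 I)."""
    # Squared error scaled by the variance; the scaling only matters if sigma != 1.
    sq_err = (x - x_mean) ** 2 / sigma ** 2
    # Normalisation constant of a univariate Gaussian, added once per dimension.
    log_norm = np.log(2.0 * np.pi * sigma ** 2)
    return 0.5 * np.sum(sq_err + log_norm, axis=1)

# Example: a batch of 2 three-dimensional points reconstructed perfectly.
x = np.zeros((2, 3))
print(gaussian_nll(x, x))
```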

Thanks a lot!

Edit: Some of the scatter plots. Left is a 2D multivariate normal (used for training; the axes correspond to the two random variables) and right is the data generated by the decoder. Both the mean and the covariance are apparently wrong in the generated data.

Best Answer

It's totally doable, thanks to the connection between probabilistic PCA and linear VAEs.

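You can show that the decoder weights $W$ of a linear VAE, with decoder $x = Wz + \epsilon$, $z \sim N(0, I)$ and $\epsilon \sim N(0, \sigma^2 I)$, can be used to simulate samples from the following normal distribution (for centered data): $$ x \sim N(0, WW^\top + \sigma^2 I) $$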

See the following reference for more details: https://arxiv.org/abs/1911.02469
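A quick numerical check of this claim might look like the sketch below. It assumes a column-vector convention in which the decoder is $x = Wz + \epsilon$ with $W$ of shape (data_dim, latent_dim) and noise variance $\sigma^2$, so the implied covariance is $WW^\top + \sigma^2 I$. The particular values of `W` and `sigma` are made up and stand in for the parameters of a trained linear VAE.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical decoder parameters standing in for a trained linear VAE.
W = np.array([[1.0, 0.0],
              [0.5, 1.2]])   # shape (data_dim, latent_dim)
sigma = 0.3                  # decoder noise standard deviation

n = 200_000
z = rng.standard_normal((n, W.shape[1]))               # z ~ N(0, I)
noise = sigma * rng.standard_normal((n, W.shape[0]))   # eps ~ N(0, sigma^2 I)
x = z @ W.T + noise                                    # x = W z + eps, row-wise

model_cov = W @ W.T + sigma ** 2 * np.eye(W.shape[0])
sample_cov = np.cov(x, rowvar=False)
print("model covariance:\n", model_cov)
print("sample covariance:\n", sample_cov)
```

Note that the decoder noise is part of the model: pushing latent samples through the decoder mean alone yields a covariance of $WW^\top$, which is missing the $\sigma^2 I$ term.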
