Solved – Should reconstruction loss be computed as sum or average over input for variational autoencoders

autoencoders, keras, kullback-leibler, loss-functions, variational-bayes

I am following this variational autoencoder tutorial: https://keras.io/examples/generative/vae/. I have included the loss computation part of the code below.

I know that a VAE's loss function consists of the reconstruction loss, which compares the original image and the reconstruction, plus the KL loss. However, I'm a bit confused about whether the reconstruction loss should be over the entire image (sum of squared differences) or per pixel (average squared difference). My understanding is that the reconstruction loss should be per pixel (MSE), but the example code I am following multiplies the per-pixel loss (binary cross-entropy here) by 28 x 28, the MNIST image dimensions. Is that correct? Furthermore, my assumption is that this would make the reconstruction loss term significantly larger than the KL loss, and I'm not sure we want that.

I tried removing the multiplication by (28×28), but this resulted in extremely poor reconstructions. Essentially all the reconstructions looked the same regardless of the input. Can I use a lambda parameter to capture the trade-off between the KL divergence and the reconstruction loss, or is that incorrect because the loss has a precise derivation (as opposed to just adding a regularization penalty)?

# Per-pixel binary cross-entropy, averaged over pixels and over the batch
reconstruction_loss = tf.reduce_mean(
    keras.losses.binary_crossentropy(data, reconstruction)
)
# Rescale by the number of pixels (28 x 28 = 784)
reconstruction_loss *= 28 * 28
# Closed-form KL of the diagonal Gaussian posterior against a standard normal prior
kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
kl_loss = tf.reduce_mean(kl_loss)
kl_loss *= -0.5
total_loss = reconstruction_loss + kl_loss

Best Answer

To go directly to the answer: the loss does have a precise derivation (but that doesn't necessarily mean you can't change it).

It's important to remember that variational autoencoders are, at their core, a method for doing variational inference over some latent variables that we assume generate the data. In this framework we aim to minimise the KL divergence between some approximate posterior over the latent variables and the true posterior, which we can alternatively do by maximising the evidence lower bound (ELBO); see the VAE paper for details. This gives us the VAE objective:

$$ \mathcal{L}(\theta,\phi) = \underbrace{\mathbb{E}_{q_\phi}[\log p_\theta(x|z)]}_{\text{Reconstruction Loss}} - \underbrace{D_{KL}(q_\phi(z)||p(z))}_{\text{KL Regulariser}} $$
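For completeness, the equivalence between the two views comes from the standard decomposition of the marginal log-likelihood, written here in the same notation:

$$ \log p_\theta(x) = \mathcal{L}(\theta,\phi) + D_{KL}(q_\phi(z)||p_\theta(z|x)), $$

and since $\log p_\theta(x)$ does not depend on $\phi$, maximising the ELBO with respect to $\phi$ is the same as minimising the KL divergence to the true posterior.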

Now, the reconstruction loss is the expected log-likelihood of the data given the latent variables. For an image made up of many pixels, the total log-likelihood is the sum of the log-likelihoods of all the pixels (assuming independence), not the average log-likelihood per pixel, which is why the example scales the per-pixel mean back up by 28 x 28.
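As a quick numerical sketch of the sum-versus-mean point (not part of the tutorial; the toy tensors, shapes and seed below are made up purely for illustration), multiplying Keras's per-pixel mean binary cross-entropy by 28 x 28 recovers the per-image sum of pixel log-likelihoods:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Toy stand-ins for a batch of flattened 28 x 28 images; values kept away
# from 0 and 1 so Keras's internal clipping doesn't matter.
rng = np.random.default_rng(0)
data = tf.constant(rng.uniform(0.01, 0.99, size=(8, 784)), dtype=tf.float32)
reconstruction = tf.constant(rng.uniform(0.01, 0.99, size=(8, 784)), dtype=tf.float32)

# What the tutorial computes: per-pixel mean cross-entropy, scaled by 784.
tutorial_loss = tf.reduce_mean(
    keras.losses.binary_crossentropy(data, reconstruction)
) * 28 * 28

# What the derivation asks for: the sum of per-pixel negative log-likelihoods
# for each image, averaged over the batch.
elementwise = -(data * tf.math.log(reconstruction)
                + (1.0 - data) * tf.math.log(1.0 - reconstruction))
summed_loss = tf.reduce_mean(tf.reduce_sum(elementwise, axis=-1))

print(float(tutorial_loss), float(summed_loss))  # the two numbers agree

So the factor of 28 x 28 is not an arbitrary rescaling; it converts a per-pixel average back into the per-image log-likelihood that the objective calls for.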

The question of whether you can add an extra parameter is an interesting one. DeepMind, for example, introduced the $\beta$-VAE, which does exactly this, albeit for a slightly different purpose: they show that this extra parameter can lead to a more disentangled latent space that allows for more interpretable variables. How principled this change of objective is remains up for debate, but it does work. That being said, it is very easy to change the KL regulariser term in a principled way by simply changing your prior ($p(z)$) on the latent variables; the original prior is a very boring standard normal distribution, so just swapping in something else will change the loss function. You might even be able, though I haven't checked this myself, to specify a new prior ($p'(z)$) such that:

$$ D_{KL}(q_\phi(z)||p'(z)) = \lambda \, D_{KL}(q_\phi(z)||p(z)), $$

which will do exactly what you want.
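If you do want the explicit trade-off weight instead, a minimal sketch in the style of the tutorial's loss code might look like the following. Here `beta` is the hypothetical weight (not part of the original tutorial), and `data`, `reconstruction`, `z_mean` and `z_log_var` are assumed to come from the encoder and decoder exactly as in the tutorial:

import tensorflow as tf
from tensorflow import keras

def weighted_vae_loss(data, reconstruction, z_mean, z_log_var, beta=1.0):
    # Per-image sum of pixel cross-entropies, averaged over the batch
    reconstruction_loss = tf.reduce_mean(
        keras.losses.binary_crossentropy(data, reconstruction)
    ) * 28 * 28
    # Closed-form KL against a standard normal prior, as in the tutorial
    kl_loss = -0.5 * tf.reduce_mean(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    )
    # beta > 1 weights the KL regulariser more heavily (the beta-VAE setting),
    # beta < 1 favours reconstruction quality
    return reconstruction_loss + beta * kl_loss

With beta = 1 this reduces to the tutorial's original total_loss.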

So basically the answer is yes: feel free to change the loss function if it helps you do the task you want, just be aware of how what you're doing differs from the original case so you don't make any claims you shouldn't.