VAE in PyTorch – Handling Extremely Negative Training Loss

autoencoders, bayesian, loss-functions, machine-learning

So we can derive the loss function for the VAE following something like this: https://arxiv.org/pdf/1907.08956v1.pdf

But when I implement the loss function in PyTorch using the negative log-likelihood from that PDF, with MSE as the reconstruction error, I get an extremely large negative training loss. What am I doing wrong?

The training loss does actually start out positive, but then it immediately races off to extremely negative values, roughly exponentially. In code, my loss function is:

MSELoss_criterion = nn.MSELoss()
MSE_loss = MSELoss_criterion(y_hat, tgts)  # reconstruction term

# closed-form KL divergence to a standard normal prior, summed over the latent dim
KLDiv_loss = -0.5 * torch.sum(1 + log_var_q - mu_q**2 - log_var_q.exp(), dim=2)
KLDiv_loss = torch.mean(KLDiv_loss)  # average over the batch
return -MSE_loss + KLDiv_loss
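
As far as I can tell, the KL term here is the standard closed form of the KL divergence between the approximate posterior $\mathcal{N}(\mu_q, \sigma_q^2)$ (with log_var_q $= \log \sigma_q^2$) and a standard normal prior, summed over the latent dimension:

$$D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu_q, \sigma_q^2)\,\big\|\,\mathcal{N}(0, 1)\right) = -\frac{1}{2}\sum_j\left(1 + \log \sigma_{q,j}^2 - \mu_{q,j}^2 - \sigma_{q,j}^2\right)$$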

Best Answer

The problem is return -MSE_loss + KLDiv_loss. You don't want to minimize -MSE_loss, because you can always make $-(x-c)^2$ smaller by choosing $x$ farther from $c$. If $c$ is your target, the optimizer is therefore rewarded for pushing the reconstruction farther from it, and the objective is unbounded below; that is exactly the runaway negative loss you are seeing.
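
To see this concretely, here is a tiny sketch with a hypothetical target of zeros: the negated MSE can be driven arbitrarily negative simply by scaling the predictions up, with no minimum for the optimizer to settle at.

import torch
import torch.nn as nn

# Hypothetical data: a fixed target of zeros, predictions drifting away from it.
mse = nn.MSELoss()
tgt = torch.zeros(4)
for scale in [1.0, 10.0, 100.0]:
    pred = torch.full((4,), scale)
    print(scale, (-mse(pred, tgt)).item())  # prints -1.0, -100.0, -10000.0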

Use return MSE_loss + KLDiv_loss instead. You can show this is correct by starting from a Gaussian likelihood for your target tgts and working through the algebra to the negative log-likelihood, of which the MSE is a rescaling (up to an additive constant).
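
Concretely, for a Gaussian likelihood with mean $\hat{y}$ and fixed variance $\sigma^2$,

$$-\log p(t \mid \hat{y}) = \frac{(t - \hat{y})^2}{2\sigma^2} + \frac{1}{2}\log(2\pi\sigma^2),$$

so minimizing the negative log-likelihood is minimizing the squared error up to a positive scale factor and an additive constant.

Putting it together, here is a minimal sketch of the corrected objective, reusing the tensor names from your snippet (y_hat, tgts, mu_q, log_var_q); the beta weight on the KL term is my addition, with beta = 1 recovering your original objective with the sign fixed:

import torch
import torch.nn.functional as F

def vae_loss(y_hat, tgts, mu_q, log_var_q, beta=1.0):
    # Reconstruction term: plain MSE, *not* negated.
    recon_loss = F.mse_loss(y_hat, tgts)

    # Closed-form KL divergence between N(mu_q, sigma_q^2) and N(0, 1),
    # summed over the latent dimension (dim=2, as in your snippet),
    # then averaged over the remaining dimensions.
    kl_loss = -0.5 * torch.sum(1 + log_var_q - mu_q**2 - log_var_q.exp(), dim=2)
    kl_loss = torch.mean(kl_loss)

    # Both terms are minimized together.
    return recon_loss + beta * kl_loss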