So we can derive the loss function for the VAE following something like this: https://arxiv.org/pdf/1907.08956v1.pdf
But when I go to implement the loss function in pytorch using the negative log-likelihood from that PDF, with MSE as the reconstruction error, I get an extremely large negative training loss. What am I doing wrong?
The training loss does start out positive, but then it immediately diverges to extremely large negative values. In code my loss function is:
MSELoss_criterion = nn.MSELoss()
MSE_loss = MSELoss_criterion(y_hat, tgts)
KLDiv_loss = -0.5 * torch.sum(1 + log_var_q - mu_q**2 - log_var_q.exp(), dim=2)
KLDiv_loss = torch.mean(KLDiv_loss)
return -MSE_loss + KLDiv_loss
Best Answer
The problem is
return -MSE_loss + KLDiv_loss
. You don't want to minimize -MSE_loss
because you can always make $-(x-c)^2$ smaller by choosing $x$ farther from $c$. If $c$ is your target, this means your model is getting farther from your goal. Use
return MSE_loss + KLDiv_Loss
instead. You can show that this is correct by starting from a Gaussian likelihood for your target tgts
and working through the algebra to obtain the negative log-likelihood, which turns out to be MSE up to a rescaling and an additive constant.
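That algebra can be checked numerically. Here is a minimal NumPy sketch (the arrays `y_hat` and `tgts` are random stand-ins, and a fixed unit variance is assumed) showing that the Gaussian negative log-likelihood is just half the MSE plus a constant, so minimizing one minimizes the other:

```python
import numpy as np

# For a Gaussian likelihood N(tgts; y_hat, sigma^2) with fixed sigma = 1,
# the per-element negative log-likelihood is
#   0.5*log(2*pi) + 0.5*(tgts - y_hat)**2,
# i.e. an additive constant plus half the squared error.
rng = np.random.default_rng(0)
y_hat = rng.normal(size=1000)   # hypothetical reconstructions
tgts = rng.normal(size=1000)    # hypothetical targets

mse = np.mean((y_hat - tgts) ** 2)
nll = np.mean(0.5 * np.log(2 * np.pi) + 0.5 * (y_hat - tgts) ** 2)

# NLL differs from 0.5*MSE only by the constant 0.5*log(2*pi),
# so minimizing NLL is the same as minimizing +MSE, never -MSE.
print(nll - 0.5 * mse)  # ≈ 0.5*log(2*pi) ≈ 0.9189
```

Since the constant does not depend on the model output, gradient descent on the NLL and on MSE move the parameters in the same direction.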