Autoencoders – Choosing the Best Loss Function

Tags: autoencoders, cross-entropy, mse, tensorflow

I am experimenting a bit with autoencoders, and I used TensorFlow to build a model that tries to reconstruct the MNIST dataset.

My network is very simple: X, e1, e2, d1, Y, where e1 and e2 are encoding layers and d1 and Y are decoding layers (Y is the reconstructed output).

X has 784 units, e1 has 100, e2 has 50, d1 has 100 again and Y 784 again.

I am using sigmoids as activation functions for layers e1, e2, d1 and Y. Inputs are in [0,1] and so should be the outputs.
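
For concreteness, the setup described above could be sketched roughly as follows. This is a NumPy forward pass rather than the original TensorFlow code, and the weight initialisation, the random stand-in batch, and all function names are assumptions made for illustration only.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Layer widths from the description: X -> e1 -> e2 -> d1 -> Y
LAYER_SIZES = [784, 100, 50, 100, 784]

def init_params(rng, sizes=LAYER_SIZES, scale=0.1):
    """Small random weights and zero biases (the initialisation is an assumption)."""
    return [(rng.normal(0.0, scale, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def forward(params, x):
    """Apply a sigmoid after every layer, so the outputs stay in (0, 1)."""
    h = x
    for w, b in params:
        h = sigmoid(h @ w + b)
    return h

def mse_loss(x, y):
    """Mean squared error between input x and reconstruction y."""
    return np.mean((x - y) ** 2)

def cross_entropy_loss(x, y, eps=1e-7):
    """Per-pixel binary cross-entropy, treating x as the target."""
    y = np.clip(y, eps, 1.0 - eps)  # avoid log(0)
    return np.mean(-(x * np.log(y) + (1.0 - x) * np.log(1.0 - y)))

rng = np.random.default_rng(0)
params = init_params(rng)
x = rng.uniform(0.0, 1.0, size=(32, 784))  # stand-in batch, not real MNIST
y = forward(params, x)
print(y.shape, mse_loss(x, y), cross_entropy_loss(x, y))
```

The two loss functions at the bottom are the two candidates I am comparing; everything else is the same in both experiments.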

I tried using cross entropy as the loss function, but the output was always a blob, and I noticed that the weights from X to e1 would always converge to a zero-valued matrix.

On the other hand, using mean squared error as the loss function produces a decent result, and I am now able to reconstruct the inputs.

Why is that so? I thought I could interpret the values as probabilities, and therefore use cross entropy, but obviously I am doing something wrong.

Best Answer

I think the best answer to this is that the cross-entropy loss function is just not well-suited to this particular task.

In taking this approach, you are essentially saying the true MNIST data is binary, and your pixel intensities represent the probability that each pixel is 'on.' But we know this is not actually the case. The incorrectness of this implicit assumption is then causing us issues.

We can also look at the cost function and see why it might be inappropriate. Let's say our target pixel value is 0.8. If we plot the MSE loss and the cross-entropy loss $- [ (\text{target}) \log (\text{prediction}) + (1 - \text{target}) \log (1 - \text{prediction}) ]$ (normalising the latter so that its minimum is at zero), we get:

[Figure: cross-entropy vs. MSE loss]

We can see that the cross-entropy loss is asymmetric. Why would we want this? Is it really worse to predict 0.9 for this 0.8 pixel than it is to predict 0.7? I would say it's maybe better, if anything.
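
To put numbers on that asymmetry, here is a small sketch in plain Python that evaluates both losses for a target of 0.8 at predictions of 0.7 and 0.9. The function names are my own, and the normalisation simply subtracts the cross-entropy value at prediction = target, which is where the cross-entropy is smallest.

```python
import math

def mse(target, prediction):
    """Squared error for a single pixel."""
    return (target - prediction) ** 2

def cross_entropy(target, prediction):
    """Binary cross-entropy for a single pixel (natural log)."""
    return -(target * math.log(prediction)
             + (1.0 - target) * math.log(1.0 - prediction))

def normalised_cross_entropy(target, prediction):
    """Cross-entropy shifted so that its minimum (at prediction == target) is zero."""
    return cross_entropy(target, prediction) - cross_entropy(target, target)

target = 0.8
for prediction in (0.7, 0.9):
    print(prediction,
          mse(target, prediction),
          normalised_cross_entropy(target, prediction))
```

The MSE is 0.01 in both cases (up to floating-point rounding), while the normalised cross-entropy is roughly 0.026 for a prediction of 0.7 and roughly 0.044 for 0.9, so overshooting by 0.1 is penalised noticeably more than undershooting by the same amount.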

We could probably go into more detail and figure out why this leads to the specific blobs that you are seeing. I'd hazard a guess that it is because pixel intensities are above 0.5 on average in the region where you are seeing the blob. But in general this is a case of the implicit modelling assumptions you have made being inappropriate for the data.

Hope that helps!