Neural Networks – Addressing Distorted Validation Loss with Batch Normalization in Convolutional Autoencoders

autoencoders, batch-normalization, keras, neural-networks

I have implemented a variational autoencoder with convolutional layers in Keras. I have around 40,000 training images and 4,000 validation images; the images are heat maps. The encoder and decoder are symmetric. In total I have three convolutional layers (32, 64, and 128 feature maps, each with stride 2). After each layer I apply batch normalization, placed after the ReLU activation.
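For reference, here is a minimal Keras sketch of the kind of encoder described above (three strided Conv2D blocks with 32, 64, and 128 filters, and BatchNormalization placed after the ReLU). The input shape, kernel size, and latent dimension are assumed values, not taken from the original model:

```python
from tensorflow import keras
from tensorflow.keras import layers

def build_encoder(input_shape=(128, 128, 1), latent_dim=32):
    """Encoder sketch: three strided conv blocks (32/64/128 filters),
    each followed by ReLU and then batch normalization.
    input_shape, kernel_size and latent_dim are assumptions."""
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, kernel_size=3, strides=2, padding="same")(x)
        x = layers.Activation("relu")(x)
        x = layers.BatchNormalization()(x)  # BN after the ReLU, as in the question
    x = layers.Flatten()(x)
    # VAE head: mean and log-variance of the approximate posterior
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    return keras.Model(inputs, [z_mean, z_log_var], name="encoder")
```

The decoder would mirror this with Conv2DTranspose layers in the reverse order (128, 64, 32 filters, stride 2).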

The problem is that without batch normalization the training and validation losses decrease as expected and are smooth, but when I insert batch normalization I either get one huge peak in the validation loss (see left image) or the validation loss becomes very bumpy (see right image). I have experimented with a momentum of 0.99 and 0.9 for the batch normalization layers. With a momentum of 0.9, only the behaviour shown in the left image appears.

[Plots of training and validation loss: the left panel shows a single large spike in the validation loss, the right panel shows a very bumpy validation loss.]
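The momentum values mentioned above correspond to the `momentum` argument of Keras' `BatchNormalization` layer, which controls how quickly the moving mean and variance used at inference time are updated. A short sketch of the two settings tried:

```python
from tensorflow.keras import layers

# The moving statistics are updated roughly as:
#   moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
# so a lower momentum makes the inference-time statistics track
# recent batches more closely.
bn_default = layers.BatchNormalization(momentum=0.99)  # Keras default
bn_faster  = layers.BatchNormalization(momentum=0.9)   # value tried in the question
```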

What can I do about this? Should I drop batch normalization entirely? As mentioned, without batch normalization the validation loss behaves like the training loss, but I had the impression that batch normalization is used almost everywhere these days.

Best Answer

It is common practice to avoid batch normalization when training VAEs, since the additional stochasticity from normalizing over mini-batches can aggravate instability on top of the stochasticity already introduced by sampling the latent variable.

Source: https://www.tensorflow.org/beta/tutorials/generative/cvae