Solved – How and why does Batch Normalization use moving averages to track the accuracy of the model as it trains

batch normalization, conv-neural-network, machine learning, neural networks

I was reading the batch normalization (BN) paper (1) and didn't understand the need to use moving averages to track the accuracy of the model. Even if I accept that it is the right thing to do, I don't understand what exactly they are doing.

To my understanding (which might be wrong), the paper says that it uses the population statistics rather than the mini-batch statistics once the model has finished training. After some discussion of unbiased estimates (which seems tangential to me, and I don't understand why the paper brings it up), they say:

Using moving averages instead, we track the accuracy of the model as
it trains.

That is the part that is confusing to me. Why do they do moving averages to estimate the accuracy of the model and over what data set?

Usually, to estimate how well a model generalizes, people just track its validation error (and potentially stop gradient descent early to regularize). However, it seems that batch normalization is doing something completely different. Can someone clarify what it is doing and why it differs?


1: Ioffe S. and Szegedy C. (2015),
"Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift",
Proceedings of the 32nd International Conference on Machine Learning, Lille, France, 2015.
Journal of Machine Learning Research: W&CP volume 37

Best Answer

When using batch normalization, the first thing to understand is that it works in two different ways, depending on whether the model is training or testing.

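  1. In training, we calculate the mean and variance of each mini-batch and use them to normalize that batch.

  2. At inference, we do not use the statistics of the current batch; we apply pre-calculated statistics that were estimated from the mini-batches seen during training.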

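So for the second case, how are these pre-calculated statistics obtained?

This is where the moving average comes in. During training, each mini-batch's mean and variance are folded into running estimates. Because these estimates are kept up to date at every training step, the model can be evaluated in inference mode at any point during training, without a separate pass over the training set to compute population statistics. The update applied after each training mini-batch is: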

running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
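
To make the two modes concrete, here is a minimal NumPy sketch of the update above. The class name `BatchNorm1D`, the `training` flag, and the default values of `momentum` and `eps` are assumptions chosen for illustration, and the learnable scale and shift parameters are left out. It is not the implementation used by any particular library, and some libraries define `momentum` the other way round (as the weight on the new sample).

```python
import numpy as np


class BatchNorm1D:
    """Illustrative batch-norm layer without learnable scale/shift."""

    def __init__(self, num_features, momentum=0.9, eps=1e-5):
        self.momentum = momentum  # weight given to the old running value
        self.eps = eps            # avoids division by zero
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def forward(self, x, training):
        if training:
            # Training: normalize with the statistics of this mini-batch ...
            sample_mean = x.mean(axis=0)
            sample_var = x.var(axis=0)
            # ... and fold them into the running estimates.
            self.running_mean = (self.momentum * self.running_mean
                                 + (1 - self.momentum) * sample_mean)
            self.running_var = (self.momentum * self.running_var
                                + (1 - self.momentum) * sample_var)
            mean, var = sample_mean, sample_var
        else:
            # Inference: use the running estimates; do not update them.
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)


# Simulated training on synthetic data with mean 5 and variance 4.
rng = np.random.default_rng(0)
bn = BatchNorm1D(num_features=3)
for _ in range(500):
    batch = rng.normal(loc=5.0, scale=2.0, size=(64, 3))
    bn.forward(batch, training=True)

# At any point we can switch to inference mode, even for a single example.
single_example = np.full((1, 3), 5.0)
normalized = bn.forward(single_example, training=False)
```

After the loop, `bn.running_mean` and `bn.running_var` should be close to the mean and variance of the synthetic data, so a single example can be normalized sensibly even though a batch of size one has no meaningful statistics of its own.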