Solved – How does batch normalization compute the population statistics after training

batch-normalization, conv-neural-network, deep-learning, machine-learning, neural-networks

I was reading the batch normalization (BN) paper (1) and it said:

For this, once the network has been trained, we use
the normalization $$\hat{x} = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}}$$ using the population,
rather than mini-batch, statistics.

My question is: how does it compute these population statistics, and over which set (train, validation, or test)? I thought I knew what that meant, but after some time I realized I am not sure how it is calculated. I assume it tries to estimate the true mean and variance, though I am not sure how it does that. What I would probably do is compute the mean and variance over the whole data set and use those moments for inference.

However, what made me suspect that I am wrong is their discussion about unbiased variance estimate later in that same section:

We use the unbiased variance estimate $Var[x] = \frac{m}{m-1} \cdot E_{\mathcal{B}}[\sigma^2_{\mathcal{B}}]$ where the expectation is over training mini-batches of size $m$ and $\sigma^2_{\mathcal{B}}$ are their sample variances.

Since we are talking about population statistics, this comment in the paper felt like it came out of nowhere (to me), and I wasn't sure what they were getting at. Are they just (randomly) clarifying that they use unbiased estimates during training, or are they using an unbiased estimate to compute the population statistics?


1: Ioffe, S. and Szegedy, C. (2015), "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", Proceedings of the 32nd International Conference on Machine Learning, Lille, France. Journal of Machine Learning Research: W&CP volume 37.

Best Answer

Typically, the population statistics are taken from the training set. If you included the test set, then at test time you would be using information you technically shouldn't have access to (information about the whole dataset). For the same reason, the validation set shouldn't be used to compute those statistics.

Keep in mind that, because batch normalization is applied not only at the input layer, the population statistics vary from epoch to epoch as the network learns and changes its parameters (and therefore its outputs at each layer).

Therefore, the common way to compute these statistics is to keep (exponentially decaying) moving averages during training. This smooths out the stochastic variation due to mini-batch training while staying up to date with the current state of learning. You can see an example of this in the Torch code for batch norm: https://github.com/torch/nn/blob/master/lib/THNN/generic/BatchNormalization.c#L22
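As a minimal sketch (not the Torch code itself, and the names `momentum` and `update_running_stats` are illustrative), the running averages can be updated once per mini-batch like this:

```python
import numpy as np

def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    """Update exponential moving averages of a layer's input statistics.

    running_mean, running_var: current estimates, one entry per feature
    batch: mini-batch activations, shape (m, num_features)
    momentum: weight given to the new mini-batch statistics
    """
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)  # biased sample variance, as used for normalization
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var
```

At inference time, `running_mean` and `running_var` stand in for $E[x]$ and $Var[x]$ in the normalization formula, so no extra pass over the training set is needed.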

The paper mentions that they use moving averages instead of just keeping the last computed statistics:

Using moving averages instead, we can track the accuracy of a model as it trains.

For your second question, they are saying that they use that unbiased estimate to estimate the population variance (for future inference).
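As an illustration of that correction (a sketch, not the paper's code; the function name is made up), it simply rescales the average of the biased per-batch variances by $\frac{m}{m-1}$:

```python
import numpy as np

def population_variance_estimate(minibatch_vars, m):
    """Unbiased population-variance estimate from biased per-batch sample
    variances: Var[x] = m/(m-1) * E_B[sigma_B^2], per the BN paper.

    minibatch_vars: array of biased variances, one per mini-batch
    m: mini-batch size
    """
    return (m / (m - 1)) * np.mean(minibatch_vars, axis=0)
```

The biased variances (dividing by $m$) are what the BN layer uses during training; the $\frac{m}{m-1}$ factor undoes that bias when producing the fixed statistics used at inference.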