Solved – Why take the gradient of the moments (mean and variance) when using Batch Normalization in a Neural Network

conv-neural-network, machine-learning, neural-networks

When doing Batch Normalization (BN), it makes sense to me to treat the BN transform as a layer that we need to backpropagate through, and thus to have derivatives with which to update its parameters for each layer, the scale $\gamma^{(k)}$ and shift $\beta^{(k)}$.

However, what does not make sense to me is why we would also need derivatives of the moments (mean $\mu$ and variance $\sigma^2$) for gradient descent. This is clearly reflected in the paper (though not explained clearly) on page 4, where they even give derivatives with respect to these two quantities:

$$ \frac{\partial \ell}{\partial \sigma^2_{\mathcal{B}}} = \sum^{m}_{i=1} \frac{\partial \ell}{\partial \hat x_i} \cdot (x_i - \mu_{\mathcal{B}}) \cdot \frac{-1}{2} \left( \sigma^2_{\mathcal{B}} + \epsilon \right)^{-\frac{3}{2}} $$

$$ \frac{\partial \ell}{\partial \mu_{\mathcal{B}}} = \left( \sum^{m}_{i=1} \frac{\partial \ell}{\partial \hat x_i} \cdot \frac{-1}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}} \right) + \frac{\partial \ell}{\partial \sigma^2_{\mathcal{B}}} \cdot \frac{\sum^m_{i=1} -2 (x_i - \mu_{\mathcal{B}})}{m} $$

What specifically confuses me is that I thought the means and standard deviations were constant during a specific epoch of training (so their derivatives should be zero). I thought that because the original paper says:

we make the second simplification: since we use mini-batches in stochastic gradient training, each mini-batch produces estimates of the mean and variance of each activation.

Furthermore, in their pseudocode they even compute the moments (mean and variance) from the current mini-batch:

[Algorithm 1 from the paper: the Batch Normalizing Transform, which computes the mini-batch mean $\mu_{\mathcal{B}}$ and variance $\sigma^2_{\mathcal{B}}$, normalizes the activations, and then scales and shifts them.]

which makes it even less clear to me why there would be any derivatives with respect to such quantities. (A sketch of my understanding of that forward computation follows.)
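To make that concrete, here is a minimal NumPy sketch of what I understand Algorithm 1 to compute during training; the function and variable names (batchnorm_forward, mu, sigma2, x_hat) are mine, not the paper's:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-time Batch Normalizing Transform (sketch of Algorithm 1).

    x: mini-batch of activations, shape (m, d); gamma, beta: shape (d,).
    The moments mu and sigma2 are recomputed from x on every forward pass;
    they are functions of the current mini-batch, not stored parameters.
    """
    mu = x.mean(axis=0)                       # mini-batch mean
    sigma2 = x.var(axis=0)                    # mini-batch variance (divides by m, as in the paper)
    x_hat = (x - mu) / np.sqrt(sigma2 + eps)  # normalize
    y = gamma * x_hat + beta                  # scale and shift
    cache = (x, x_hat, mu, sigma2, gamma, eps)
    return y, cache
```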

Furthermore, they seem to use the population mean and variance during inference, which makes me further suspect that the moments (mean and variance) should not be variables at all. Are they parameters, variables, or constants? Does anyone know?
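For completeness, my understanding of the inference-time computation is roughly the following sketch, where running_mu and running_sigma2 are hypothetical names for the population estimates collected during training:

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, running_mu, running_sigma2, eps=1e-5):
    # At inference the moments are fixed population estimates,
    # so no gradient would ever flow through them here.
    x_hat = (x - running_mu) / np.sqrt(running_sigma2 + eps)
    return gamma * x_hat + beta
```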


2: Ioffe, S. and Szegedy, C. (2015), "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", Proceedings of the 32nd International Conference on Machine Learning, Lille, France. Journal of Machine Learning Research: W&CP volume 37.

Best Answer

Derivatives of the moments are used for backpropagation.

Alongside the two derivatives of the moments on page 4, the paper also gives the derivative with respect to the input, which makes use of the derivatives of the moments, $$\frac{\partial \ell}{\partial x_i}=\frac{\partial \ell}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma^2+\epsilon}}+\frac{\partial \ell}{\partial \sigma^2}\cdot\frac{2(x_i-\mu)}{m}+\frac{\partial \ell}{\partial \mu}\cdot\frac{1}{m},$$ and it is this quantity that is used, via the chain rule, to compute the derivatives of the parameters in the previous layers.
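To see where the derivatives of the moments actually go, here is a minimal NumPy sketch of the backward pass written directly from the three formulas above; the function name batchnorm_backward and the cache layout are my own choices, assuming a forward pass that saved x, x_hat, mu, sigma2, gamma and eps:

```python
import numpy as np

def batchnorm_backward(d_y, cache):
    """Backward pass through the BN transform, following the paper's formulas.

    d_y: gradient of the loss w.r.t. the BN output y, shape (m, d).
    Returns gradients w.r.t. the input x and the parameters gamma, beta.
    """
    x, x_hat, mu, sigma2, gamma, eps = cache
    m = x.shape[0]

    # dl/dx_hat, since y = gamma * x_hat + beta
    d_xhat = d_y * gamma

    # dl/dsigma2 and dl/dmu: intermediate quantities, never used to update anything
    d_sigma2 = np.sum(d_xhat * (x - mu), axis=0) * -0.5 * (sigma2 + eps) ** -1.5
    d_mu = (np.sum(d_xhat * -1.0 / np.sqrt(sigma2 + eps), axis=0)
            + d_sigma2 * np.sum(-2.0 * (x - mu), axis=0) / m)

    # dl/dx_i: this is what gets passed back to the previous layer
    d_x = (d_xhat / np.sqrt(sigma2 + eps)
           + d_sigma2 * 2.0 * (x - mu) / m
           + d_mu / m)

    # gradients of the learnable parameters
    d_gamma = np.sum(d_y * x_hat, axis=0)
    d_beta = np.sum(d_y, axis=0)
    return d_x, d_gamma, d_beta
```

Note that d_sigma2 and d_mu are only used to assemble d_x; nothing is ever updated with them by gradient descent.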

IMO, the moments are treated as neither parameters nor constants; they can be thought of as intermediate results of computing the output of the layer.