Solved – Batch Normalization shift/scale parameters defeat the point

batch normalization, gradient descent, machine learning, neural networks

According to the paper introducing Batch Normalization, the actual BN transform is given as:


  • Input: Values of $x$ over a mini-batch $\mathcal B = \{x_{1,\ldots,m}\}$; parameters to be learned $\gamma,\beta$.
  • Output: $\{y_i = \mathrm{BN}_{\gamma,\beta}(x_i)\}$.

$\mu_{\mathcal B} \leftarrow \frac1m \sum_{i = 1}^m x_i$

$\sigma^2_{\mathcal B} \leftarrow \frac1m \sum_{i=1}^m (x_i - \mu_{\mathcal B})^2$

$\hat x_i \leftarrow \frac{x_i - \mu_{\mathcal B}}{\sqrt{\sigma_{\mathcal B}^2 + \epsilon}}$

$y_i \leftarrow \gamma \hat x_i + \beta \equiv \mathrm{BN}_{\gamma,\beta}(x_i)$


(Here, $\epsilon$ is a small constant added for numerical stability. The above is an almost exact copy of Algorithm 1 in Section 3 of the paper linked above.)
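For concreteness, here is a minimal NumPy sketch of that transform for a single activation over one mini-batch (the function name and the toy values are mine, not from the paper):

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Batch-norm transform of Algorithm 1 for one activation over a mini-batch."""
    mu = x.mean()                            # mini-batch mean
    var = x.var()                            # mini-batch variance (1/m normalization)
    x_hat = (x - mu) / np.sqrt(var + eps)    # normalize
    return gamma * x_hat + beta              # scale and shift

# toy mini-batch of m = 4 values
x = np.array([0.5, 2.0, -1.0, 3.5])
print(batch_norm_forward(x, gamma=1.0, beta=0.0))
```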

Now, $\gamma,\beta$ are learned parameters, as far as I can tell at the level of each mini-batch. In particular, for a fixed mini-batch they can take any value. It seems to me that this makes shifting by the mean and scaling by the standard deviation pointless. The resulting output values are given by
$$
y_i = \frac{\gamma}{\sqrt{\sigma_{\mathcal B}^2 + \epsilon}}x_i + \beta - \frac{\gamma}{\sqrt{\sigma_{\mathcal B}^2 + \epsilon}}\mu_{\mathcal B}.
$$
Hence, if we define
$$
\gamma' = \frac{\gamma}{{\sqrt{\sigma_{\mathcal B}^2 + \epsilon}}},
$$
$$
\beta' = \beta - \frac{\gamma}{\sqrt{\sigma_{\mathcal B}^2 + \epsilon}}\mu_{\mathcal B},
$$
we might just as well have defined and learned values for $\gamma',\beta'$ directly and then returned $y_i = \gamma'x_i + \beta'$.
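To make this concrete, here is a small numerical check (arbitrary values, my own sketch) showing that for one fixed mini-batch the two parameterizations give identical outputs; note that $\gamma',\beta'$ then depend on that particular batch's $\mu_{\mathcal B},\sigma^2_{\mathcal B}$:

```python
import numpy as np

x = np.array([0.5, 2.0, -1.0, 3.5])      # one fixed mini-batch
gamma, beta, eps = 1.3, -0.7, 1e-5

mu, var = x.mean(), x.var()
y = gamma * (x - mu) / np.sqrt(var + eps) + beta   # BN output

gamma_p = gamma / np.sqrt(var + eps)               # gamma'
beta_p = beta - gamma_p * mu                       # beta'
y_alt = gamma_p * x + beta_p                       # folded affine form

print(np.allclose(y, y_alt))  # True for this batch; gamma', beta' depend on this batch's mu, var
```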

I presume that I misunderstand — can someone explain where I went wrong?

Best Answer

I'm by no means an expert on this topic, but here are my thoughts.

I think the main point of batch normalization is that the training of each network layer is unaffected by changes in the scale of the preceding layers. The authors write

$$BN(Wu) = BN((aW)u)$$

for any scalar $a$. Thus the backpropagated error is unaffected as well, so there is no scale explosion even when the learning rate is high.

If you simply apply

$$BN(x_i) = y_i = \gamma_i x_i + \beta_i, \text{ where } x = Wu,$$

the above no longer holds.
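
Here is a small numerical illustration of that difference (my own sketch; the layer sizes and the scale $a$ are arbitrary): normalizing by the mini-batch statistics cancels a rescaling of the weights, while a plain affine map $\gamma x + \beta$ does not.

```python
import numpy as np

def bn(x, gamma=1.0, beta=0.0, eps=1e-5):
    # normalize each column over the mini-batch, then scale and shift
    mu, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
u = rng.normal(size=(128, 10))    # mini-batch of inputs
W = rng.normal(size=(10, 4))      # layer weights
a = 100.0                         # arbitrary rescaling of the weights

# BN(Wu) == BN((aW)u): the scale a cancels in the normalization
print(np.allclose(bn(u @ W), bn(u @ (a * W)), atol=1e-4))    # True

# a plain affine map without normalization keeps the factor a
affine = lambda x: 1.0 * x + 0.0
print(np.allclose(affine(u @ W), affine(u @ (a * W))))       # False
```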

The parameters $\gamma$ and $\beta$ are a must-have, because otherwise the normalized outputs $\hat{x}_i$ would mostly be close to $0$, which would hinder the network's ability to fully utilize nonlinear transformations (the authors give the example of the sigmoid function, which is close to an identity transformation in the proximity of $0$).
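As a toy illustration of that last point (my own, not from the paper): normalized activations sit near $0$, where the sigmoid is roughly affine, so without $\gamma,\beta$ the nonlinearity would contribute little; a learned scale and shift can move the inputs back into the saturating regime.

```python
import numpy as np

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

x_hat = np.linspace(-1.0, 1.0, 5)       # normalized activations, clustered around 0
print(sigmoid(x_hat))                   # roughly 0.5 + x_hat/4: nearly affine

gamma, beta = 4.0, 2.0                  # a learned scale/shift can leave that regime
print(sigmoid(gamma * x_hat + beta))    # clearly nonlinear / saturating
```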

Finally, one can argue that the input $y = Wx + b$ to the BN layer, as a combination of many factors, may resemble a Gaussian distribution (CLT). So by studentizing it we keep the inputs stable and, hopefully, approximately following a standard normal distribution.