Solved – Why doesn't Batch Normalization completely prevent a network from being able to train?

batch-normalization, machine-learning, neural-networks

I do understand that BN does work; I just don't understand how "fixing" (changing) the distribution of every mini-batch doesn't throw everything completely out of whack.

For example, let's say you were (for some reason) training a network to match a letter to a number grade. You have one feature – the numeric value.

The batch input to one of your non-linear layers is (90, A), (80, B), (70, C). It gets normalized first, and the B data point becomes (0, B) since 80 is the batch mean – ignoring the learned scale/bias parameters, since that's basically just another linear layer.

Your second training batch input to that layer is (60, D), (60, D), (60, D). These get normalized as well.

Now you have 3 data points that say (0, D) – but your network just spent time learning that a value of "0" at this point in the network should tend towards "B".
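A minimal NumPy sketch of this scenario (assuming plain per-batch standardization and ignoring the learned scale/bias parameters, as above):

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Normalize a 1-D mini-batch to zero mean and unit variance.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

batch_1 = np.array([90.0, 80.0, 70.0])  # grades A, B, C
batch_2 = np.array([60.0, 60.0, 60.0])  # grades D, D, D

print(batch_norm(batch_1))  # [ 1.22  0.   -1.22]  -> the "B" example lands at ~0
print(batch_norm(batch_2))  # [ 0.  0.  0. ]       -> every "D" example also lands at ~0
```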

How does this sort of transformation not break down training in its entirety?

Even though you likely normalized your features beforehand, aren't mini-batches going to end up having significantly different means and variances that throw everything off?

A feature being equal to 0.1 in one batch might mean something completely different than that feature being equal to 0.1 in another.

What's my fundamental misunderstanding here? Apologies if this is a silly question; I haven't been able to find any answer to this on the web, since I'm assuming it's more of my failure to understand statistics than BN specifically.

Best Answer

Batch normalisation is designed specifically to correct for batch-wise effects. You mention grades which, if they are exam grades, are a classic example of where it may have benefit. Each year a set of students sits the same questions, but those questions differ from year to year. If we assume that from year to year (ignoring long-term variation due to education reforms etc.; that is another ball game) the student population has a consistent IQ, then we expect differences in raw test scores from year to year to reflect differences in question difficulty. Note that this is an assumption and as such open to challenge, but it is this assumption that leads to using batch normalisation.

Batch normalisation allows us to correct for year-to-year differences in test difficulty.
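As a rough sketch of that analogy (the years and scores below are made-up numbers, with the second exam assumed to be harder overall), standardising each year's scores separately removes the year-to-year difficulty offset:

```python
import numpy as np

# Hypothetical raw exam scores; the 2020 exam is assumed to have been harder overall.
scores = {
    "2019": np.array([90.0, 80.0, 70.0, 60.0]),
    "2020": np.array([70.0, 60.0, 50.0, 40.0]),
}

def standardise(x, eps=1e-5):
    # Correct each year's scores for that year's overall difficulty.
    return (x - x.mean()) / (x.std() + eps)

for year, raw in scores.items():
    print(year, standardise(raw))
# Both years map to the same standardised values ([ 1.34  0.45 -0.45 -1.34]),
# so the particular exam's difficulty no longer dominates the comparison.
```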

If the batches are distinct, sufficiently large to provide reliable estimates of the batch effects, and categorical rather than arbitrary divisions of a continuous process, then the model is trained on batch-corrected data, and as long as the methodology is consistent it can be applied to new data. We apply the model under the implicit assumption that batch-wise effects continue to crop up in new years.
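In the neural-network setting, one common way to keep the methodology consistent is to accumulate running estimates of the batch statistics during training and reuse them on new data. A minimal sketch, not tied to any particular framework's API, with an arbitrary momentum value:

```python
import numpy as np

class SimpleBatchNorm:
    """Toy 1-D batch norm that keeps running statistics for use on new data."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def train_step(self, x):
        mu, var = x.mean(), x.var()
        # Accumulate running estimates so that new data is corrected consistently.
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        return (x - mu) / np.sqrt(var + self.eps)

    def apply_to_new_data(self, x):
        # New data is normalised with the stored statistics, not its own batch's.
        return (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
```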

Batch normalisation is not a great solution for subsampling a continuous population: while it may often converge over many iterations to typical behaviour, there will be many local inconsistencies. If the data does not fall into clear batches, you would be better off using methods designed for segmenting continuous processes (moving averages, detrending, splines, etc.).
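For illustration only, here is a sketch of the moving-average style of detrending mentioned above, with an arbitrarily chosen window size:

```python
import numpy as np

def moving_average_detrend(x, window=25):
    # Subtract a centred moving average instead of imposing arbitrary
    # batch boundaries on a continuous process.
    kernel = np.ones(window) / window
    trend = np.convolve(x, kernel, mode="same")
    return x - trend

t = np.linspace(0, 10, 200)
signal = 0.5 * t + np.sin(2 * np.pi * t)  # slow drift plus oscillation
detrended = moving_average_detrend(signal)
```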

How does this sort of transformation not break down training in its entirety?

If it is appropriate, it corrects irrelevant batch-wise errors, making your data more informative and consistent, which improves training by removing irrelevant noise. If it is not appropriate, it will break training (although it may do so in non-obvious ways, so it may still look like it is working during training).

Even though you likely normalized your features beforehand, aren't mini-batches going to end up having significantly different means and variances that throw everything off?

The point is that, after correction, the batches end up with comparable means and variances, and that is what makes values from different batches comparable.

A feature being equal to 0.1 in one batch might mean something completely different than that feature being equal to 0.1 in another.

Yes, and this is exactly the point of batch normalisation: to standardise the scales between batches so that we can compare values after correcting for batch-wise effects.