Solved – How does batch size affect convergence of SGD and why

gradient descent, machine learning, neural networks, optimization, stochastic gradient descent

I've seen a similar conclusion in many discussions: as the minibatch size gets larger, the convergence of SGD actually gets harder/worse, for example in this paper and this answer. I've also heard of people using tricks like small learning rates or small batch sizes in the early stage of training to address this difficulty with large batch sizes.

However, this seems counter-intuitive, since the average loss over a minibatch can be thought of as an approximation of the expected loss over the data distribution,
$$\frac{1}{|X|}\sum_{x\in X} l(x,w)\approx E_{x\sim p_{data}}[l(x,w)],$$
and the larger the batch size, the more accurate this approximation is supposed to be. Why is this not the case in practice?
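As a quick numerical check of this premise (a toy least-squares example of my own, not from the linked papers), the minibatch gradient does indeed get closer to the full-data gradient as the batch size grows:

```python
import numpy as np

# Toy least-squares problem: how far is the minibatch gradient from the
# gradient over the whole dataset, for different batch sizes?
rng = np.random.default_rng(0)

N, d = 100_000, 10
X = rng.normal(size=(N, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=N)
w = np.zeros(d)                                   # current parameters


def grad(idx):
    """Gradient of 0.5 * mean squared error over the rows in idx."""
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ w - yb) / len(idx)


full_grad = grad(np.arange(N))                    # "exact" gradient over all data

for b in (1, 10, 100, 1_000, 10_000):
    errs = [np.linalg.norm(grad(rng.choice(N, size=b, replace=False)) - full_grad)
            for _ in range(200)]
    print(f"batch size {b:>6}: mean error of gradient estimate {np.mean(errs):.4f}")
```

The error shrinks roughly like $1/\sqrt{b}$, so the approximation itself does get better; the question is why this doesn't translate into better convergence.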


Here are some of my (probably wrong) attempts at an explanation.

The parameters of the model depend heavily on each other. When the batch gets too large, it affects too many parameters at once, making it hard for the parameters to settle into a stable mutual dependency? (Something like the internal covariate shift problem mentioned in the batch normalization paper.)

Or, when nearly all the parameters are updated in every iteration, they tend to learn redundant, overlapping patterns, which reduces the effective capacity of the model? (I mean, say, for digit classification some patterns should be responsible for dots and some for edges, but here every pattern ends up trying to be responsible for all shapes.)

Or is it because when the batch size gets close to the size of the training set, the minibatches can no longer be treated as i.i.d. samples from the data distribution, since there will be a large probability of correlated minibatches?


Update
As pointed out in Benoit Sanchez's answer, one important reason is that large minibatches require more computation to complete one update, and most analyses use a fixed number of training epochs for comparison.

However, this paper (Wilson and Martinez, 2003) shows that a larger batch size is still slightly disadvantageous even given a sufficient number of training epochs. Is that generally the case?
[Table from Wilson and Martinez (2003) comparing training with different batch sizes.]

Best Answer

Sure, one update with a big minibatch is "better" (in terms of accuracy) than one update with a small minibatch. This can be seen in the table you copied into your question (call $N$ the sample size):

  • batch size 1: number of updates $27N$
  • batch size 20,000: number of updates $8343\times\frac{N}{20000}\approx 0.42N$

You can see that with bigger batches you need far fewer updates for the same accuracy.
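For concreteness: training for $E$ epochs with batch size $b$ over $N$ samples gives $E\left\lceil N/b\right\rceil$ parameter updates, and the two figures above correspond to roughly $27$ epochs at batch size $1$ and $8343$ epochs at batch size $20{,}000$:
$$\text{number of updates} = E\left\lceil\frac{N}{b}\right\rceil,\qquad 27\cdot\frac{N}{1}=27N,\qquad 8343\cdot\frac{N}{20000}\approx 0.42N.$$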

But the two can't be compared directly, because they are not processing the same amount of data. I'm quoting the first article:

"We compare the effect of executing $k$ SGD iterations with small minibatches $B_j$ versus a single iteration with a large minibatch $\displaystyle\bigcup_{1\leq j\leq k} B_j$"

Here it is about processing the same amount of data, and while there is a small overhead for multiple minibatches, this takes comparable processing resources.

There are several ways to understand why several updates are better (for the same amount of data being read). This is the key idea of stochastic gradient descent vs. gradient descent. Instead of reading everything and then correcting yourself at the end, you correct yourself along the way, which makes the subsequent reads more useful since you are correcting yourself from a better guess. Geometrically, several updates are better because you are drawing several segments, each in the direction of the (approximate) gradient at the start of that segment, while a single big update is a single segment from the very start in the direction of the (exact) gradient. It's better to change direction several times even if each direction is less precise.
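Here is a minimal sketch of that comparison (a toy least-squares problem with a fixed learning rate of my own, not the experiment from the quoted paper): $k$ updates on small minibatches versus a single update on their union, both reading the same data exactly once.

```python
import numpy as np

# Toy comparison: k updates on small minibatches B_1, ..., B_k versus a single
# update on their union, so both variants read the same data exactly once.
rng = np.random.default_rng(1)

N, d, lr = 10_000, 20, 0.1
X = rng.normal(size=(N, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=N)


def loss(w):
    return 0.5 * np.mean((X @ w - y) ** 2)


def grad(w, idx):
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ w - yb) / len(idx)


idx = rng.permutation(N)
batches = np.array_split(idx, 100)        # k = 100 small minibatches

# One big update: a single segment in the direction of the (nearly exact) gradient.
w_big = np.zeros(d) - lr * grad(np.zeros(d), idx)

# k small updates: correct the direction after every minibatch.
w_small = np.zeros(d)
for b in batches:
    w_small -= lr * grad(w_small, b)

print(f"loss after 1 big update     : {loss(w_big):.4f}")
print(f"loss after {len(batches)} small updates: {loss(w_small):.4f}")
```

The many-small-updates run typically ends at a much lower training loss, because every segment after the first starts from a better guess.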

The size of the minibatches is essentially the frequency of updates: the smaller the minibatches, the more updates. At one extreme (minibatch = dataset) you have gradient descent. At the other extreme (minibatch = one example) you have fully per-example SGD. Per-example SGD is better anyway, but bigger minibatches are suited to more efficient parallelization.

At the end of the convergence process, SGD becomes less precise than (batch) GD. But at that point, the extra precision is usually a kind of useless fitting. While you get a slightly smaller loss on the training set, you don't get any real predictive power. You are chasing the very precise optimum, but it doesn't help. If the loss function is correctly regularized (which prevents over-fitting), you don't exactly "over"-fit, you just uselessly "hyper"-fit. This shows up as a non-significant change in accuracy on the test set.
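To illustrate this last point, here is a small sketch in the same toy least-squares setting (constant learning rates, purely illustrative): full-batch GD keeps grinding down the training loss, while per-example SGD with a constant step size plateaus slightly above it, which is exactly the kind of precision that rarely matters for test accuracy.

```python
import numpy as np

# Toy illustration of the end game: full-batch GD keeps reducing the training
# loss, while per-example SGD with a constant step size plateaus slightly
# above it (the "noise floor" of SGD).
rng = np.random.default_rng(2)

N, d, lr = 2_000, 5, 0.05
X = rng.normal(size=(N, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=N)


def loss(w):
    return 0.5 * np.mean((X @ w - y) ** 2)


w_gd = np.zeros(d)
w_sgd = np.zeros(d)
for epoch in range(200):
    # Full-batch gradient descent: one exact update per epoch.
    w_gd -= lr * X.T @ (X @ w_gd - y) / N
    # Per-example SGD: N noisy updates per epoch.
    for i in rng.permutation(N):
        w_sgd -= lr * X[i] * (X[i] @ w_sgd - y[i])

print(f"final training loss, full-batch GD  : {loss(w_gd):.5f}")
print(f"final training loss, per-example SGD: {loss(w_sgd):.5f}")
```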
