Solved – Stochastic gradient descent: why randomise training set

gradient-descent, logistic-regression, sampling, stochastic-approximation

I'm given a dataset of 200 million training examples. The stochastic gradient descent method requires me to sample these randomly, to avoid it getting 'stuck'.

First of all, I don't see how it gets stuck. So the fact that the sample needs to be random, and that I cannot just traverse the dataset in sequential order, is a riddle to me right now.

The following paragraph, which I found here, is not clear enough.

The first step of the procedure requires that the order of the training dataset is randomized. This is to mix up the order that updates are made to the coefficients. Because the coefficients are updated after every training instance, the updates will be noisy jumping all over the place, and so will the corresponding cost function. By mixing up the order for the updates to the coefficients, it harnesses this random walk and avoids it getting distracted or stuck.

I would assume entries in a dataset are independent of each other?

But more importantly, if I'm tasked with shuffling this dataset of 200 million examples, doesn't that introduce a big overhead? Surely shuffling a dataset of 200 million samples is going to take some time?

Best Answer

Generally, if your data is ordered (see e.g. the MNIST data set), SGD will have problems. Also, if you run through the data multiple times (so-called epochs), keeping the same order on every pass will probably lead to problems such as getting stuck in local minima or slower convergence. So you should reshuffle on each epoch.
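Here is a minimal sketch in NumPy of what reshuffling on each epoch can look like; the function name `sgd_logistic` and the logistic-regression loss (suggested by the question's tags) are my own illustration, not part of the answer. The key point is that only an index array is permuted each epoch, never the 200 million rows themselves:

```python
import numpy as np

def sgd_logistic(X, y, lr=0.01, epochs=5):
    """Plain SGD for logistic regression, reshuffling the visit order every epoch.

    X: (n, d) feature matrix, y: (n,) labels in {0, 1}. Only an index array is
    permuted, so the data rows are never physically moved.
    """
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(epochs):
        order = np.random.permutation(n)          # new random visiting order each epoch
        for i in order:
            p = 1.0 / (1.0 + np.exp(-X[i] @ w))   # predicted probability for one example
            w -= lr * (p - y[i]) * X[i]           # gradient step on that single example
    return w
```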

With a huge number of data points it is not advisable to shuffle the whole data set. Instead you can follow another random strategy: it may be enough to randomly sample (without replacement) a mini-batch of data points, fit the model on that mini-batch by SGD, and repeat this across further mini-batches until convergence. This procedure will also find the solution and will probably use far less data than the full set (though the latter depends on the type of model and data), so most of the time it is the much cheaper procedure.
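A sketch of that mini-batch alternative, again with illustrative names (`minibatch_sgd_logistic`, `batch_size`, `n_batches`) that I've chosen myself: each step draws a fresh random batch of indices without replacement within the batch, so the full 200-million-row set is never shuffled at all.

```python
import numpy as np

def minibatch_sgd_logistic(X, y, lr=0.01, batch_size=256, n_batches=10_000):
    """Mini-batch SGD for logistic regression: each step samples a small random
    batch of indices (without replacement within the batch) instead of
    shuffling the entire data set up front."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(n_batches):
        idx = np.random.choice(n, size=batch_size, replace=False)  # random mini-batch
        p = 1.0 / (1.0 + np.exp(-X[idx] @ w))                      # batch predictions
        grad = X[idx].T @ (p - y[idx]) / batch_size                # average gradient
        w -= lr * grad
    return w
```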
