Does the chronological order of batches and epochs matter?

deep learning, machine learning, neural networks, overfitting

BACKGROUND

I have a very large file that contains my training data. For the sake of simplicity, assume the data in this file is already shuffled and stratified, so that any idiosyncrasies are spread out as evenly as possible.

I want to train a neural network on this data file. Typically, a neural network consumes the data in batches; once all the batches have been consumed, training starts over from the first batch for the next pass over the data (i.e., the next epoch).

PROBLEM

If the file is so large that only part of it (say, half of it) can be loaded into memory, how should I consume the training data? As far as the training is concerned, what are the pros and cons of the two scenarios below?

Scenario 1:

  • For each segment of the file…
    • For each epoch…
      • For each batch…
        • Train the network on the batch.

Scenario 2:

  • For each epoch…
    • For each segment of the file…
      • For each batch…
        • Train the network on the batch.

Let's spell these out in more detail. Assume the file is made up of 8 batches b1 b2 b3 b4 b5 b6 b7 b8 and that at most 4 contiguous batches can be loaded into memory. If the training consists of 3 epochs, then scenario 1 would effectively feed the following sequence of batches to the neural network:

First ½ of the file                 Second ½ of the file    
|                                   |                                   
epoch 1     epoch 2     epoch 3     epoch 1     epoch 2     epoch 3     
|           |           |           |           |           |           
b1 b2 b3 b4 b1 b2 b3 b4 b1 b2 b3 b4 b5 b6 b7 b8 b5 b6 b7 b8 b5 b6 b7 b8

On the other hand, scenario 2 would be

epoch 1                 epoch 2                 epoch 3
|                       |                       |
First ½     Second ½    First ½     Second ½    First ½     Second ½    
|           |           |           |           |           |          
b1 b2 b3 b4 b5 b6 b7 b8 b1 b2 b3 b4 b5 b6 b7 b8 b1 b2 b3 b4 b5 b6 b7 b8 
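In loop form, the two orderings look roughly like the sketch below, which uses placeholder batch labels instead of real training data, purely to reproduce the two sequences:

  # Simulate the order in which batches reach the network under both scenarios,
  # using batch labels instead of real training data.
  batches = ["b1", "b2", "b3", "b4", "b5", "b6", "b7", "b8"]
  segments = [batches[:4], batches[4:]]   # at most 4 contiguous batches fit in memory
  n_epochs = 3

  # Scenario 1: segments outermost, epochs inside each segment.
  order_1 = [b for seg in segments for _ in range(n_epochs) for b in seg]

  # Scenario 2: epochs outermost, segments (and their batches) inside.
  order_2 = [b for _ in range(n_epochs) for seg in segments for b in seg]

  print(" ".join(order_1))  # b1 b2 b3 b4 b1 b2 b3 b4 b1 b2 b3 b4 b5 b6 b7 b8 ...
  print(" ".join(order_2))  # b1 b2 b3 b4 b5 b6 b7 b8 b1 b2 b3 b4 b5 b6 b7 b8 ...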

Presumably, scenario 1 first biases the neural network towards batches 1-4 and then towards batches 5-8, whereas scenario 2, "epoch-wise" at least, introduces no such bias since it progresses through all of batches 1-8 in each epoch. But does any of this ordering actually matter? Is there an altogether better alternative to these two scenarios?

Best Answer

Takeaway message: with SGD and large datasets, epochs are not that important. To avoid keeping the entire dataset in memory, just keep generating freshly sampled minibatches and you'll be fine.

With respect to your two scenarios, neither is recommended:

  • Generally, fixed batches are a bad idea.
  • Your first scenario is particularly bad. It is in effect equivalent to training your network on half of the data and then fine-tuning it on the second half. In some cases this leads to what is known as "catastrophic forgetting": the fine-tuning phase makes the network "forget" its ability to predict well on the first half of the data.

A better approach

It is very common practice to train on datasets so large that they don't fit into memory. Generally, what you want to do is implement a generator that prepares "fresh" batches for you. With Keras, for example, this is very easy:

  1. Write a custom generator that yields (examples, labels), with the number of examples equal to the desired batch size. The generator is expected to loop over your data indefinitely.
  2. Train your network using the fit_generator method. If you generate batches of size $b$ and your training set is of size $m$, then a single epoch consists of $\frac{m}{b}$ steps.
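A rough sketch of these two steps is below; the model architecture, the feature dimension, and the random placeholder data inside the generator are illustrative assumptions, not something from the question:

  import numpy as np
  from tensorflow import keras  # standalone Keras works the same way

  m, b = 100_000, 128      # training-set size and batch size (illustrative)
  n_features = 20          # assumed input dimensionality

  def batch_generator(batch_size):
      # Placeholder generator: yields (examples, labels) forever. In practice it
      # would read batch_size examples from the large file instead of generating
      # random values.
      while True:
          x = np.random.rand(batch_size, n_features).astype("float32")
          y = np.random.randint(0, 2, size=batch_size)
          yield x, y

  model = keras.Sequential([
      keras.layers.Dense(64, activation="relu", input_shape=(n_features,)),
      keras.layers.Dense(1, activation="sigmoid"),
  ])
  model.compile(optimizer="sgd", loss="binary_crossentropy")

  # One epoch = m / b steps. Recent Keras versions accept the generator in
  # model.fit directly with the same arguments.
  model.fit_generator(batch_generator(b), steps_per_epoch=m // b, epochs=3)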

As for the implementation of the generator, which is what you're asking about: since you're training with SGD, you want to sample $b$ random examples every time. You have two options: sampling with replacement and sampling without replacement. The naive implementation is to sample with replacement, since sampling without replacement would require you to keep track of which batches the generator has already yielded. It's true that this way you're not guaranteed to go over all $m$ training examples in every epoch, but in different epochs you'll be "missing" different examples, which more or less cancels out as training progresses. This tends to be especially insignificant for large datasets (see [2] for a more detailed discussion of this question).
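To make the two sampling schemes concrete, here is a small sketch of the index bookkeeping only; pulling the actual rows from disk (e.g. through np.memmap or a random-access record format) is left out:

  import numpy as np

  m, b = 100_000, 128             # training-set size and batch size (illustrative)
  rng = np.random.default_rng(0)

  # Sampling with replacement: no state to carry between batches or epochs;
  # an example may appear more than once while another is skipped for a while.
  batch_indices = rng.integers(0, m, size=b)

  # Sampling without replacement: shuffle once per epoch and walk through the
  # permutation, which means remembering where the previous batch stopped.
  perm = rng.permutation(m)
  first_batch = perm[:b]
  second_batch = perm[b:2 * b]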
