Training models for multiple epochs vs one “super epoch”

machine learning

This is a conceptual question regarding model training (e.g. CNNs). I have since solved the issue that raised this question, but I was still curious.

Preliminaries:
In the typical training setting, we have $N$ training examples, which we batch to use with mini-batch SGD (or similar optimization). One run through all training examples is a single epoch. Let's say you plan to run $M$ epochs.

Now the question: during training it is generally good practice to shuffle the data so that the mini-batches differ from one epoch to the next. If the mini-batches are only used to compute gradient updates, can one instead train over a single "super-epoch" by pre-shuffling the data and feeding the model $N \cdot M$ training examples? That is, perform $M$ independent random shuffles of the $N$ examples (each one a permutation, i.e. sampling without replacement), concatenate them into a sequence of $N \cdot M$ examples, and then train for only a single epoch over that sequence.

Is the only downside of this (say, for TensorFlow) that you cannot check accuracy and other metrics on your validation set after each epoch? I know that is a pretty big downside, and I'm not advocating for this method; I was just curious whether there was anything else I was missing in my understanding.

Best Answer

As long as you’re using mini-batching for both, you’ve described the same procedure in two different ways. Dividing the $M \times N$ shuffled samples into $M$ epochs yields the same number of parameter updates as processing all $M \times N$ samples concatenated together: with mini-batches of size $B$ (and $B$ dividing $N$), both give $M \times N / B$ updates. If you further require that each of the $N$ samples appear exactly once before any sample is used again, your method is just a different description of the ordinary method of training for several epochs. (The ordinary method does let you do "special" things at epoch boundaries, including common early-stopping methods that check model quality against a validation set on a regular basis, so you might not complete all $M$ epochs before a termination condition is satisfied.)
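To make the equivalence concrete, here is a minimal NumPy sketch that builds the mini-batch index sequence both ways and compares them. The function names `epoch_batches` and `super_epoch_batches`, the sizes, and the seed are illustrative choices, not part of any library. It assumes the batch size divides $N$; if it does not, the super-epoch version produces mini-batches that straddle two shuffles, so the two sequences are no longer identical.

```python
import numpy as np

def epoch_batches(n, m, batch_size, seed=0):
    """Mini-batch index arrays from M separately shuffled epochs."""
    rng = np.random.default_rng(seed)
    batches = []
    for _ in range(m):
        perm = rng.permutation(n)  # every sample appears exactly once per epoch
        for start in range(0, n, batch_size):
            batches.append(perm[start:start + batch_size])
    return batches

def super_epoch_batches(n, m, batch_size, seed=0):
    """Mini-batch index arrays from one concatenated pass of length N * M."""
    rng = np.random.default_rng(seed)
    # M independent permutations, concatenated into a single long ordering
    order = np.concatenate([rng.permutation(n) for _ in range(m)])
    return [order[s:s + batch_size] for s in range(0, n * m, batch_size)]

N, M, B = 12, 3, 4  # hypothetical sizes; B divides N
a = epoch_batches(N, M, B)
b = super_epoch_batches(N, M, B)
same = len(a) == len(b) and all(np.array_equal(x, y) for x, y in zip(a, b))
print(len(a), same)  # 9 True
```

With the same seed, both functions draw the same $M$ permutations, so the optimizer would see exactly the same mini-batches in the same order. Trying `N = 10` with `B = 4` shows the caveat: the per-epoch version yields 9 mini-batches (the last one in each epoch is short), while the concatenated version yields 8.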

Checking model statistics at the end of an epoch is a common way to see how well the model is doing, but it's purely a matter of convention. You can report validation statistics every $k\ge 1$ mini-batches if you want, with the understanding that when $k$ is smaller than the number of mini-batches per epoch you'll be doing more computation, because you're evaluating the model more frequently than once per epoch.
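As a rough illustration of evaluating every $k$ mini-batches rather than once per epoch, the sketch below trains a toy linear model with mini-batch SGD in plain NumPy. The data, the model, the helper names `val_mse` and `train`, and the hyperparameters are all made up for the example; a real framework would normally handle this through its own evaluation hooks or callbacks.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: y = 2x + 1 plus a little noise
X = rng.normal(size=(200, 1))
y = 2.0 * X[:, 0] + 1.0 + 0.1 * rng.normal(size=200)
X_train, y_train = X[:160], y[:160]
X_val, y_val = X[160:], y[160:]

def val_mse(w, b):
    """Mean squared error of the linear model on the validation split."""
    return float(np.mean((w * X_val[:, 0] + b - y_val) ** 2))

def train(num_epochs=5, batch_size=16, k=7, lr=0.1):
    w, b = 0.0, 0.0
    step = 0
    history = []  # (step, validation MSE) pairs
    n = len(X_train)
    for _ in range(num_epochs):
        perm = rng.permutation(n)  # reshuffle at the start of every epoch
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            xb, yb = X_train[idx, 0], y_train[idx]
            err = w * xb + b - yb
            # Gradient step on the mini-batch mean squared error
            w -= lr * 2.0 * np.mean(err * xb)
            b -= lr * 2.0 * np.mean(err)
            step += 1
            if step % k == 0:  # evaluate every k mini-batches, ignoring epoch boundaries
                history.append((step, val_mse(w, b)))
    return w, b, history

w, b, history = train()
print([s for s, _ in history])
```

Here an epoch is 10 mini-batches, so `k=7` evaluates more often than once per epoch and the checkpoints do not line up with epoch boundaries. An early-stopping rule could inspect `history` at each of those checkpoints just as it would at the end of an epoch.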