Solved – Interplay between early stopping and cross validation

boosting, cross-validation, machine learning, neural networks, validation

I am a little bit confused by early stopping, and in particular by how it can be inserted into a CV framework. As far as I understand, I can fix the optimal number of epochs (for a neural network, or the number of trees for XGBoost) by early stopping, that is:

  • pick a validation set,
  • train with an increasing number of epochs until a predefined metric, evaluated on the validation set, starts worsening,
  • take the epoch at which the metric starts to worsen as the optimal number of epochs to prevent overfitting.
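Concretely, with XGBoost I mean something like this (just a rough sketch on a toy dataset with placeholder parameter values, where the number of trees plays the role of the epoch count):

```python
# Rough sketch of the procedure above with XGBoost (toy dataset, placeholder
# parameter values; the "number of trees" plays the role of the epoch count).
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)
params = {"objective": "binary:logistic", "eval_metric": "logloss", "max_depth": 3}

# Grow up to 1000 trees, but stop once the validation log-loss has not
# improved for 20 consecutive rounds.
bst = xgb.train(params, dtrain, num_boost_round=1000,
                evals=[(dval, "val")], early_stopping_rounds=20,
                verbose_eval=False)
print("optimal number of trees:", bst.best_iteration)
```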

So far, so good. But then I would like to insert this early-stopping procedure into a CV framework. Suppose I have a model with 10 hyperparameters that I want to tune via CV, and suppose there is an eleventh hyperparameter, the number of epochs.
My feeling is that one could proceed like this:

  • create the K resampled folds, each of which gives you a training set and a validation set,
  • choose a suitable grid for your 10 hyperparameters,
  • for each point on the grid, train your model in each fold with early stopping, i.e. use the fold's validation set to track the chosen metric and stop when it gets worse,
  • take the mean of the K validation metrics,
  • choose the point of the grid (i.e. the set of hyperparameters) that gives the best mean metric.
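In code, the idea would be something like this (again only a rough sketch: toy data, a 2-parameter grid instead of 10, placeholder values everywhere):

```python
# Rough sketch of the scheme above with XGBoost: per-fold early stopping on
# the fold's own validation set, then averaging scores and stopping rounds.
from itertools import product

import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold

X, y = make_classification(n_samples=2000, random_state=0)
kf = KFold(n_splits=5, shuffle=True, random_state=0)

results = []
for max_depth, eta in product([3, 5], [0.05, 0.1]):
    params = {"objective": "binary:logistic", "eval_metric": "logloss",
              "max_depth": max_depth, "eta": eta}
    fold_scores, fold_rounds = [], []
    for train_idx, val_idx in kf.split(X):
        dtrain = xgb.DMatrix(X[train_idx], label=y[train_idx])
        dval = xgb.DMatrix(X[val_idx], label=y[val_idx])
        # Early stopping is monitored on the fold's validation set
        # (this is exactly the practice questioned in point 2 below).
        bst = xgb.train(params, dtrain, num_boost_round=1000,
                        evals=[(dval, "val")], early_stopping_rounds=20,
                        verbose_eval=False)
        fold_scores.append(bst.best_score)       # validation log-loss at the best round
        fold_rounds.append(bst.best_iteration)   # stopping round of this fold
    results.append((params, np.mean(fold_scores), np.mean(fold_rounds)))

best_params, best_score, mean_round = min(results, key=lambda r: r[1])
print(best_params, "mean log-loss:", best_score, "mean stopping round:", mean_round)
```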

Questions:

  1. What number of epochs should I choose as optimal? In each of the K folds I get, in general, a different stopping epoch. The H2O documentation seems to suggest taking the mean of the K stopping epochs. Is this right?
  2. Is it actually a "fair" practice to use the validation metric coming from the early stopping as a proxy of the out-of-sample metric? As Max Kuhn seems to point out here in section 3.4.5, maybe the best thing to do would be:

    …if you want to do early stopping, then in each fold you should take your training set and split it again, holding out a small early-stopping-set to guide the early stopping, and then evaluate the model on the validation set of that fold.

    But this seems to me a really intricate process…
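If I understand the suggestion correctly, the nested variant would look roughly like this (a sketch with toy data and placeholder parameters; only the small inner hold-out guides the stopping, while the fold's validation set is used purely for evaluation):

```python
# Rough sketch of the nested variant: split each training fold again, use the
# small inner hold-out only for early stopping, and score on the untouched
# fold validation set (toy data, placeholder parameters).
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.metrics import log_loss
from sklearn.model_selection import KFold, train_test_split

X, y = make_classification(n_samples=2000, random_state=0)
params = {"objective": "binary:logistic", "eval_metric": "logloss", "max_depth": 3}

fold_scores = []
for train_idx, val_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
    # Inner split: most of the fold's training data fits the trees, a small
    # early-stopping set only decides when to stop.
    X_fit, X_stop, y_fit, y_stop = train_test_split(
        X[train_idx], y[train_idx], test_size=0.1, random_state=0)
    bst = xgb.train(params, xgb.DMatrix(X_fit, label=y_fit),
                    num_boost_round=1000,
                    evals=[(xgb.DMatrix(X_stop, label=y_stop), "stop")],
                    early_stopping_rounds=20, verbose_eval=False)
    # Evaluate on the fold's validation set, which early stopping never saw
    # (iteration_range needs a reasonably recent xgboost version).
    preds = bst.predict(xgb.DMatrix(X[val_idx]),
                        iteration_range=(0, bst.best_iteration + 1))
    fold_scores.append(log_loss(y[val_idx], preds))

print("mean out-of-sample log-loss:", np.mean(fold_scores))
```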

Unfortunately, I wasn't able to find references where this issue is presented in a clear and transparent form.

Best Answer

This topic has already been discussed from several angles in related questions. However, I think that none of those answers covers your question fully, so I will summarize:

  1. You should not use the validation fold of cross-validation for early stopping: that way you are already letting the model "see" the testing data, so you will not get an unbiased estimate of the model's performance. If you must, leave some data out of the training fold and use it for early stopping.

  2. However, this does not help you too much, for two reasons:

    1. The optimal stopping epoch may vary considerably between folds, and there is no guarantee that taking the mean is optimal in any way.
    2. If you decide to train the final model on the whole dataset, the length of an "epoch" changes: an epoch is defined as "one pass through the whole dataset", so the number of weight updates in one epoch depends on the training set size and the batch size. Early stopping generally aims at limiting the maximal number of weight updates, so an "epoch count" optimized on a dataset of a different size does not carry over.

      Thus, if anything, optimize early stopping in terms of weight updates, not epochs (a small numerical illustration follows below).
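To make the second point concrete, here is a tiny illustration of the arithmetic (batch size, dataset sizes and epoch count are arbitrary):

```python
# Why an "optimal epoch count" does not transfer between dataset sizes:
# the number of weight updates per epoch depends on the dataset size.
import math

batch_size = 32
best_epochs = 30                     # say, the stopping epoch found on a CV training fold
for n_samples in (8_000, 10_000):    # e.g. one training fold vs. the full dataset
    updates_per_epoch = math.ceil(n_samples / batch_size)
    print(f"{n_samples} samples: {updates_per_epoch} updates/epoch, "
          f"{best_epochs * updates_per_epoch} updates in {best_epochs} epochs")
# The same 30 "optimal" epochs mean ~7500 updates on the fold
# but ~9390 updates on the full dataset.
```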

Finally, I think the best approach is not to use cross-validation to tune early stopping at all: tune all the other hyperparameters with CV, and then, during the final training, set aside a small validation set which you use for early stopping.
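A minimal sketch of that final step (XGBoost again, with placeholder values; best_params stands for whatever the CV search returned):

```python
# Rough sketch of the recommended final step: hyperparameters already fixed by
# cross-validation, a small hold-out is kept only to decide when to stop.
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, random_state=0)
best_params = {"objective": "binary:logistic", "eval_metric": "logloss",
               "max_depth": 3, "eta": 0.05}   # assumed output of the CV search

# Keep ~10% of all available data aside purely for early stopping.
X_fit, X_stop, y_fit, y_stop = train_test_split(X, y, test_size=0.1, random_state=0)
final_model = xgb.train(best_params, xgb.DMatrix(X_fit, label=y_fit),
                        num_boost_round=2000,
                        evals=[(xgb.DMatrix(X_stop, label=y_stop), "stop")],
                        early_stopping_rounds=50, verbose_eval=False)
print("final model stops after", final_model.best_iteration, "rounds")
```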