Solved – How to correctly use validation and test sets for Neural Network training

conv-neural-network, machine-learning, neural-networks, train, validation

I have been in the machine learning business for a long time, but this fundamental point still confuses me, since every paper, article, and book describes a different usage for the validation and test sets. I aim to set my mind at ease on this issue once and for all:

Let's say I have a relatively small Convolutional Neural Network and I want to train it on the MNIST dataset. Naturally, I want to find the best hyperparameters for the given CNN, like the weight decay coefficient $\lambda$, the learning rate $\alpha$, etc. MNIST has 60K training images and 10K test images, and the basic practice is to allocate 10K of the training set as the validation set.

Assume that we are doing a grid search over the $\lambda$ and $\alpha$ hyperparameters and that we have fixed a $(\lambda^*,\alpha^*)$ pair. Then we start training the CNN with SGD, using the 50K training set and measuring the performance on the 10K validation set at the end of each epoch (a complete pass over the 50K training set). Usually we stop training either when a fixed budget of epochs has been depleted or when accuracy on the validation set starts to drop, and we keep the model with the best validation performance. Then, only a single time in the whole training process of the CNN with the $(\lambda^*,\alpha^*)$ pair, we measure the real performance on the test set. This usage is intuitive and does not cause the network to unfairly overfit the test set. But the problem is, we never used the 10K validation samples in the actual SGD training.

So, what is the correct way to use the validation set here? After we finish training with the $(\lambda^*,\alpha^*)$ pair, either by running out of allowed epochs or by exiting early because of overfitting on the validation set, should we merge the validation set into the training set and train for a few more epochs? At that point our learning rate may already be very small due to gradual decay. Or should we train everything from scratch on the 60K training + validation set, resetting the learning rate to its initial value?

Another alternative may be to train the CNN with every $(\lambda,\alpha)$ pair under consideration in our grid search, using only the 50K training samples in SGD and the 10K validation samples for measuring accuracy; no interaction with the test set is allowed. After we have trained with every $(\lambda,\alpha)$ pair, we pick the pair that yields the highest accuracy on the validation set. Then we train the CNN with the picked hyperparameters from scratch, this time using the 60K training + validation samples directly in SGD, for a fixed number of epochs. After the training ends, we use the test set once and for all to declare our final accuracy.
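
In code, this alternative would look roughly like the sketch below. The helpers (`train_with_early_stopping`, `train_for_epochs`, `evaluate`), the dataset variables (`train_50k`, `val_10k`, `train_60k`, `test_10k`), and the hyperparameter grids are all hypothetical placeholders, not an actual implementation:

```python
# Hypothetical helpers and dataset variables stand in for real training code.
candidate_lambdas = [1e-4, 1e-3, 1e-2]      # illustrative grid for the weight decay
candidate_alphas  = [0.1, 0.01, 0.001]      # illustrative grid for the learning rate

best_pair, best_val_acc = None, 0.0
for lmbda in candidate_lambdas:
    for alpha in candidate_alphas:
        # Train on the 50K subset, early-stop on the 10K validation set.
        model = train_with_early_stopping(train_50k, val_10k, weight_decay=lmbda, lr=alpha)
        val_acc = evaluate(model, val_10k)
        if val_acc > best_val_acc:
            best_pair, best_val_acc = (lmbda, alpha), val_acc

# Retrain from scratch on all 60K samples with the selected pair, for a fixed number
# of epochs, then touch the test set exactly once.
lmbda, alpha = best_pair
final_model = train_for_epochs(train_60k, weight_decay=lmbda, lr=alpha, epochs=30)
test_acc = evaluate(final_model, test_10k)
```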

This method, in my mind, has the following issue: the hyperparameters we picked would be optimal for the 50K training samples! Since we used a smaller subset of the actual training set, the model would most likely overfit more easily, seeing less data variation, so our hyperparameter search would tend to select a higher $\lambda$. And with this $\lambda$ we would probably see lower accuracy on the actual test set when training on the full 60K samples, since the over-regularized model cannot exploit the variation added by the extra 10K samples.

So, I am not certain what the ultimately correct way to use the validation and test sets is. What would be the valid, logical approach here?

Best Answer

The bottom line is:

As soon as you use a portion of the data to choose which model performs better, you are already biasing your model towards that data.1

Machine learning in general

In general machine learning scenarios, you would use cross-validation to find the optimal combination of hyperparameters, then fix them and train on the whole training set. In the end, you would evaluate on the test set, only once, to get a realistic idea of the model's performance on new, unseen data.
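
As a concrete (if simplified) illustration, here is one way this looks with scikit-learn; the estimator, dataset, and parameter grid are just placeholders, not the CNN from the question:

```python
# Minimal sketch: cross-validated hyperparameter selection, refit on all training data,
# then a single test-set evaluation at the very end.
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# 5-fold cross-validation over the training set selects the hyperparameters ...
grid = GridSearchCV(
    LogisticRegression(max_iter=5000),
    param_grid={"C": [0.01, 0.1, 1.0, 10.0]},
    cv=5,
)
grid.fit(X_train, y_train)   # refit=True (the default) retrains on the whole training set

# ... and the test set is touched exactly once, for the final performance estimate.
print(grid.best_params_, grid.score(X_test, y_test))
```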

If you then trained a different model and selected whichever of the two performs better on the test set, you would already be using the test set as part of your model-selection loop, and you would need yet another, independent test set to estimate the true test performance.

Neural networks

Neural networks are a bit special in that their training is usually very long, so cross-validation is not used very often (if training takes 1 day, 10-fold cross-validation already takes over a week on a single machine). Moreover, one of the important hyperparameters is the number of training epochs. The optimal training length varies with different initializations and different training sets, so fixing the number of epochs and then training on all the data (training + validation) for that fixed number is not a very reliable approach.

Instead, as you mentioned, some form of early stopping is used: the model is potentially trained for a long time, "snapshots" are saved periodically, and eventually the snapshot with the best performance on some validation set is picked. To enable this, you always have to keep some portion of the data aside as a validation set2. As a consequence, you never train the neural net on all of the samples.
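
A minimal sketch of this snapshot logic, assuming `train_one_epoch` and `evaluate` are stand-ins for your framework and that `model`, `train_set`, `val_set`, `max_epochs`, and `patience` are already defined:

```python
import copy

# Hypothetical training loop with validation-based early stopping and snapshotting.
best_snapshot, best_val_acc, epochs_since_best = None, 0.0, 0

for epoch in range(max_epochs):
    train_one_epoch(model, train_set)            # SGD updates use the training set only
    val_acc = evaluate(model, val_set)           # validation set is kept aside
    if val_acc > best_val_acc:
        best_snapshot = copy.deepcopy(model)     # save a "snapshot" of the best model so far
        best_val_acc, epochs_since_best = val_acc, 0
    else:
        epochs_since_best += 1
    if epochs_since_best >= patience:            # stop once validation stops improving
        break

model = best_snapshot                            # keep the best snapshot, not the last one
```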

Finally, there are plenty of other hyperparameters, such as the learning rate, weight decay, and dropout ratios, but also the network architecture itself (depth, number of units, size of conv. kernels, etc.). You could potentially tune these with the same validation set you use for early stopping, but since that set already drives early stopping, it would give you a biased estimate. Ideally, you would use yet another, separate validation set for this. Once you have fixed all the remaining hyperparameters, you can merge this second validation set into your final training set.


To wrap it up (a code sketch of the full protocol follows the list):

  1. Split all your data into training + validation 1 + validation 2 + testing
  2. Train the network on training, using validation 1 for early stopping
  3. Evaluate on validation 2, change the hyperparameters, repeat step 2.
  4. Select the best hyperparameter combination from step 3, train the network on training + validation 2, using validation 1 for early stopping
  5. Evaluate on testing. This is your final (real) model performance.
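
Expressed as a sketch, where `split`, `train_early_stopped`, `evaluate`, and `merge` are hypothetical helpers and the hyperparameter grid is illustrative:

```python
# Hypothetical end-to-end protocol following steps 1-5 above.
train, val1, val2, test = split(all_data)                      # step 1

hyperparameter_grid = [
    {"lr": 0.1,  "weight_decay": 1e-4},
    {"lr": 0.01, "weight_decay": 1e-3},
]

scored = []
for hp in hyperparameter_grid:                                 # steps 2-3
    model = train_early_stopped(train, val1, **hp)             # early stopping on validation 1
    scored.append((evaluate(model, val2), hp))                 # model selection on validation 2

best_val2_acc, best_hp = max(scored, key=lambda t: t[0])       # step 4
final_model = train_early_stopped(merge(train, val2), val1, **best_hp)

print(evaluate(final_model, test))                             # step 5: final (real) performance
```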

1 This is exactly why Kaggle challenges have two test sets: a public one and a private one. You can use the public test set to check the performance of your model, but in the end it is the performance on the private test set that matters, and if you overfit to the public test set, you lose.

2 Amari et al. (1997) in their article Asymptotic Statistical Theory of Overtraining and Cross-Validation recommend setting the ratio of samples used for early stopping to $1/\sqrt{2N}$, where $N$ is the size of the training set.
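Taking that recommendation at face value for the MNIST example above, with $N = 50{,}000$ training samples the ratio would be $1/\sqrt{2 \cdot 50000} \approx 0.0032$, i.e., only roughly 160 samples held out for early stopping.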
