Machine Learning – Training on the Full Dataset After Cross-Validation

Tags: cross-validation, machine-learning, model-selection

TL;DR: Is it ever a good idea to train an ML model on all of the available data before shipping it to production? Put another way: is it ever OK to train on all the available data without checking whether the model overfits, or without getting a final read on the model's expected performance?


Say I have a family of models parametrized by $\alpha$. I can do a search (e.g. a grid search) on $\alpha$ by, for example, running k-fold cross-validation for each candidate.

The point of using cross-validation for choosing $\alpha$ is that I can check whether the model $\beta_{\alpha_i}$ learned for a particular $\alpha_i$ has, for example, overfit, by testing it on the "unseen" data held out in each CV iteration (a validation set). After iterating through all the $\alpha_i$'s, I could then choose the model $\beta_{\alpha^*}$ learned for the parameters $\alpha^*$ that did best in the grid search, e.g. on average across all folds.
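For concreteness, here is a minimal sketch of that selection step using scikit-learn; the SVM estimator, the grid of C values, and the synthetic data are just placeholders of my own choosing, with C playing the role of $\alpha$:

```python
# Minimal sketch: grid search over alpha (here the SVM parameter C) with k-fold CV.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)  # placeholder data

param_grid = {"C": [0.01, 0.1, 1, 10, 100]}                # candidate alpha_i's

# 5-fold CV for each candidate; alpha* is the value with the best mean
# validation score across folds.
search = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=5)
search.fit(X, y)

print("alpha*        :", search.best_params_)
print("mean CV score :", search.best_score_)
```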

Now, say that after model selection I would like to use all the data I have available, in an attempt to ship the best possible model to production. For this, I could take the parameters $\alpha^*$ chosen via grid search with cross-validation and, after training on the full ($F$) dataset, get a single new learned model $\beta^{F}_{\alpha^*}$.
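In scikit-learn terms that final step might look like the sketch below (with a hard-coded placeholder value standing in for $\alpha^*$); note that GridSearchCV with refit=True, the default, performs this same refit automatically and exposes it as best_estimator_:

```python
# Minimal sketch: refit a single model on the full dataset F using alpha*.
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)  # placeholder data

alpha_star = {"C": 10}                    # placeholder value chosen via the CV above
final_model = SVC(kernel="rbf", **alpha_star).fit(X, y)    # beta^F_{alpha*}
# There is no held-out data left to evaluate final_model on, which is exactly
# the problem described below.
```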

The problem is that, if I use my entire dataset for training, I can't reliably check whether this new learned model $\beta^{F}_{\alpha^*}$ overfits, or estimate how it will perform on unseen data. So is this at all good practice? What is a good way to think about this problem?

Best Answer

The way to think of cross-validation is that it estimates the performance obtained using a method for building a model, rather than the performance of the model itself.

If you use cross-validation to estimate the hyper-parameters of a model (the $\alpha$s) and then use those hyper-parameters to fit a model to the whole dataset, then that is fine, provided that you recognise that the cross-validation estimate of performance is likely to be (possibly substantially) optimistically biased. This is because part of the model (the hyper-parameters) has been selected to optimise the cross-validation performance, so if the cross-validation statistic has a non-zero variance (and it will), there is the possibility of over-fitting the model selection criterion.
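To see why picking the best of several noisy estimates is optimistic, here is a tiny numerical illustration (a toy example of my own, not taken from the question): even when every candidate has exactly the same true performance, the best observed cross-validation score is biased upwards.

```python
# Toy illustration: selecting the maximum of noisy, equally good CV scores
# produces an optimistically biased estimate of performance.
import numpy as np

rng = np.random.default_rng(0)
true_score = 0.80                      # identical true performance for every alpha_i
n_candidates, n_repeats = 20, 10_000   # 20 grid points, repeated many times

noisy_scores = true_score + rng.normal(0.0, 0.02, size=(n_repeats, n_candidates))
print("mean of the selected (maximum) score:", noisy_scores.max(axis=1).mean())
# prints roughly 0.84, i.e. noticeably above the true performance of 0.80
```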

If you want to choose the hyper-parameters and estimate the performance of the resulting model, then you need to perform a nested cross-validation, where the outer cross-validation is used to assess the performance of the modelling procedure, and an inner cross-validation is used to determine the hyper-parameters separately within each outer fold. You build the final model by using cross-validation on the whole dataset to choose the hyper-parameters and then fitting the classifier on the whole dataset using the optimised hyper-parameters.
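A minimal sketch of that nested procedure with scikit-learn (the estimator, grid and data are placeholders chosen for illustration, not something prescribed by the method):

```python
# Minimal sketch of nested cross-validation.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, cross_val_score
from sklearn.svm import SVC

X, y = make_classification(n_samples=500, random_state=0)   # placeholder data
param_grid = {"C": [0.01, 0.1, 1, 10, 100]}

# Inner CV: model selection (choosing the hyper-parameters).
inner = GridSearchCV(SVC(kernel="rbf"), param_grid, cv=5)

# Outer CV: assesses the whole procedure (grid search + fitting), so the
# hyper-parameters are re-selected independently in every outer fold.
outer_scores = cross_val_score(inner, X, y, cv=5)
print("nested CV performance estimate:", outer_scores.mean())

# Final model: run the same selection procedure once on all the data.
inner.fit(X, y)
final_model = inner.best_estimator_
# Report outer_scores.mean() as the performance estimate, not inner.best_score_,
# which is the optimistically biased figure discussed above.
```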

This is of course computationally expensive, but worth it as the bias introduced by improper performance estimation can be large. See my paper

G. C. Cawley and N. L. C. Talbot, Over-fitting in model selection and subsequent selection bias in performance evaluation, Journal of Machine Learning Research, vol. 11, pp. 2079-2107, July 2010.

However, it is still possible to have over-fitting in model selection (nested cross-validation just allows you to test for it). A method I have found useful is to add a regularisation term to the cross-validation error that penalises hyper-parameter values likely to result in overly-complex models, see

G. C. Cawley and N. L. C. Talbot, Preventing over-fitting in model selection via Bayesian regularisation of the hyper-parameters, Journal of Machine Learning Research, vol. 8, pp. 841-861, April 2007.
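The paper gives the full Bayesian treatment; purely as an illustration of the general idea (not the method from the paper), a selection criterion along these lines adds a penalty to the cross-validation error that favours less flexible hyper-parameter settings:

```python
# Toy illustration only: penalised model-selection criterion
# (cross-validation error plus a penalty on hyper-parameter flexibility).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

X, y = make_classification(n_samples=200, random_state=0)   # placeholder data
candidates = [0.01, 0.1, 1, 10, 100, 1000]                   # candidate C values
lam = 0.01                                                   # arbitrary penalty weight

def penalised_criterion(C):
    cv_error = 1.0 - cross_val_score(SVC(kernel="rbf", C=C), X, y, cv=5).mean()
    penalty = lam * np.log(C)          # larger C => more flexible model
    return cv_error + penalty

best_C = min(candidates, key=penalised_criterion)
print("selected C:", best_C)
```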

So the answers to your question are: (i) yes, you should use the full dataset to produce your final model, as the more data you use the more likely it is to generalise well; but (ii) make sure you obtain an unbiased performance estimate via nested cross-validation, and potentially consider penalising the cross-validation statistic to further avoid over-fitting in model selection.