How to Choose a Predictive Model After K-Fold Cross-Validation


I am wondering how to choose a predictive model after doing K-fold cross-validation.

This may be awkwardly phrased, so let me explain in more detail: whenever I run K-fold cross-validation, I split the training data into K subsets and end up with K different models.

I would like to know how to pick one of the K models, so that I can present it to someone and say "this is the best model that we can produce."

Is it OK to pick any one of the K models? Or is there some kind of best practice that is involved, such as picking the model that achieves the median test error?

Best Answer

I think you are still missing something in your understanding of the purpose of cross-validation.

Let's get some terminology straight. Generally, when we say 'a model', we mean a particular method for describing how some input data relates to what we are trying to predict. We don't generally refer to particular instances of that method as different models. So you might say 'I have a linear regression model', but you wouldn't call two different sets of trained coefficients different models, at least not in the context of model selection.

So, when you do K-fold cross-validation, you are testing how well your model can be trained on some data and then predict data it hasn't seen. We use cross-validation for this because if you train using all the data you have, you have none left for testing. You could do this once, say by using 80% of the data to train and 20% to test, but what if the 20% you happened to pick for testing contains a bunch of points that are particularly easy (or particularly hard) to predict? Then we will not have the best possible estimate of the model's ability to learn and predict.
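To make the "lucky or unlucky split" concern concrete, here is a minimal sketch, assuming synthetic data and a deliberately trivial "model" that just predicts the training mean. The helper `holdout_mse` is hypothetical and only for illustration. Running the same single 80/20 split procedure with different random seeds gives different test-error estimates:

```python
import random
import statistics

def holdout_mse(ys, seed, test_frac=0.2):
    """Score a trivial mean predictor on one random train/test split (hypothetical helper)."""
    rng = random.Random(seed)
    idx = list(range(len(ys)))
    rng.shuffle(idx)
    n_test = int(len(ys) * test_frac)
    test, train = idx[:n_test], idx[n_test:]
    # "Training": the model is just the mean of the training targets.
    mean_pred = statistics.fmean(ys[i] for i in train)
    # "Testing": mean squared error on the held-out points only.
    return statistics.fmean((ys[i] - mean_pred) ** 2 for i in test)

# Synthetic targets (made-up data, for illustration only).
data_rng = random.Random(0)
ys = [data_rng.gauss(0.0, 1.0) for _ in range(50)]

# Same procedure, different random 80/20 splits.
estimates = [holdout_mse(ys, seed) for seed in range(5)]
print([round(e, 3) for e in estimates])  # the estimate changes from split to split
```

The spread across seeds is exactly the noise a single holdout split leaves in the error estimate.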

We want to use all of the data. So, continuing the 80/20 example above, we would do 5-fold cross-validation by training the model 5 times, each time on 80% of the data, and testing it on the remaining 20%. We ensure that each data point ends up in the 20% test set exactly once. We have therefore used every data point we have to contribute to an understanding of how well our model learns from some data and predicts new data.
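The fold bookkeeping described above can be sketched in plain Python as follows. This is a minimal illustration, not any particular library's implementation: `k_fold_indices` is a hypothetical helper that shuffles the indices once and deals them into K folds, so every index lands in exactly one test fold.

```python
import random

def k_fold_indices(n, k, seed=0):
    """Yield (train, test) index lists; each index appears in exactly one test fold."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    # Deal the shuffled indices into k folds of (nearly) equal size.
    folds = [idx[i::k] for i in range(k)]
    for i in range(k):
        test = folds[i]
        # Training set = every other fold combined.
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, test

splits = list(k_fold_indices(n=10, k=5))
for train, test in splits:
    print(len(train), len(test))  # prints "8 2" for every fold: an 80/20 split
```

Libraries such as scikit-learn provide the same splitting (`sklearn.model_selection.KFold`), but the logic is short enough to read directly.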

But the purpose of cross-validation is not to come up with our final model. We don't use these 5 instances of our trained model to do any real prediction. For that we want to use all the data we have to come up with the best model possible. The purpose of cross-validation is model checking, not model building.

Now, say we have two models: a linear regression model and a neural network. How can we say which model is better? We can do K-fold cross-validation with each and see which one is better at predicting the test-set points. But once we have used cross-validation to select the better-performing model, we train that model (whether it be the linear regression or the neural network) on all the data. We don't use the model instances we trained during cross-validation for our final predictive model.
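Here is a minimal sketch of that workflow, assuming synthetic one-feature data. Two simple methods stand in for "linear regression vs. neural network": a constant mean predictor and an ordinary least-squares line. The names `fit_mean`, `fit_line` and `cv_mse` are hypothetical. Cross-validation only scores each *method*; the per-fold fits are thrown away, and the winning method is refit on all of the data.

```python
import random
import statistics

def fit_mean(xs, ys):
    # Method 1: always predict the mean of the training targets.
    m = statistics.fmean(ys)
    return lambda x: m

def fit_line(xs, ys):
    # Method 2: ordinary least squares for y = a + b*x with a single feature.
    mx, my = statistics.fmean(xs), statistics.fmean(ys)
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    a = my - b * mx
    return lambda x: a + b * x

def cv_mse(fit, xs, ys, k=5, seed=0):
    """K-fold cross-validated mean squared error of a fitting method."""
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    errors = []
    for i, test in enumerate(folds):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        model = fit([xs[j] for j in train], [ys[j] for j in train])  # discarded after scoring
        errors.extend((ys[j] - model(xs[j])) ** 2 for j in test)
    # Every point is tested exactly once, so this averages over all n points.
    return statistics.fmean(errors)

# Made-up data with a real linear trend plus noise.
rng = random.Random(1)
xs = [rng.uniform(0, 10) for _ in range(100)]
ys = [2.0 + 3.0 * x + rng.gauss(0, 1) for x in xs]

candidates = {"mean": fit_mean, "line": fit_line}
scores = {name: cv_mse(fit, xs, ys) for name, fit in candidates.items()}
best = min(scores, key=scores.get)

# Refit the chosen method on ALL the data; this is the model you present.
final_model = candidates[best](xs, ys)
print(best, scores)
```

Note that `final_model` is not any of the K fold-level fits; it is a fresh fit of the selected method on the full dataset.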

Note that there is a technique called bootstrap aggregation (usually shortened to 'bagging') that does, in a sense, build an ensemble model from many model instances trained on resampled subsets of the data, somewhat like the fits produced during cross-validation. However, that is an advanced technique beyond the scope of your question here.