Machine Learning – How to Train a Final Model with TensorFlow

Tags: cross-validation, keras, machine learning, scikit learn, tensorflow

Based on a previous question and on this article, it is suggested that you split the data into training and test sets (or training/validation/test sets). But once you have settled on your model, you should retrain it on the entire dataset, so that it learns from more data points and (theoretically) achieves better performance.

This approach is easy to apply with Scikit-learn, but I am struggling to apply it with TensorFlow/Keras. In fact, TensorFlow does not seem to accommodate this approach, because you always end up dedicating part of the data to a validation set in order to drive early stopping and select the best model configuration (ModelCheckpoint, or anything else that monitors "val_loss"). On the other hand, if you retrain a model from scratch on the entire dataset without any validation-based control, you can easily end up overfitting.

Therefore, I would like to ask whether there is any procedure in TensorFlow/Keras that would allow me to update or extend the model from the training phase to the final phase, incorporating the data previously used for validation, without running into overfitting or suboptimization issues.

Best Answer

If you follow the methodology outlined in the references you provide, you end up with values for your hyperparameters. These include the number of training epochs (i.e. the number of epochs used in each training run, if you used early stopping).
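One way to obtain that epoch count in Keras is to run the model-selection training with an `EarlyStopping` callback that monitors `val_loss`, then read the best epoch off the returned `History` object. This is a minimal sketch, not a definitive recipe: the synthetic data, the hypothetical `build_model()` helper, `patience=5` and the 100-epoch ceiling are all assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

# Synthetic regression data, assumed purely for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=(500, 4)).astype("float32")
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype="float32")
y = (x @ true_w + rng.normal(scale=0.1, size=500).astype("float32")).reshape(-1, 1)


def build_model():
    """Hypothetical model factory; the architecture is an assumption."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Model-selection run: the last 20% of the (already random) data is the validation set.
split = int(0.8 * len(x))
x_train, x_val = x[:split], x[split:]
y_train, y_val = y[:split], y[split:]

# Stop once val_loss has not improved for 5 epochs (patience is an assumed value).
early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)
history = build_model().fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=100,
    callbacks=[early_stop],
    verbose=0,
)

# 1-based epoch with the lowest validation loss: the epoch count to reuse later.
best_epochs = int(np.argmin(history.history["val_loss"])) + 1
print("best number of epochs:", best_epochs)
```

Reading the minimum from `history.history["val_loss"]` gives the best epoch even though training continued for `patience` more epochs after it.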

You then use the hyperparameter values you found to build your final model. To train your final model, use all the data that you have used so far (whether as training, validation or test data) as your training data. You can't also use this data as validation or test data, as you risk overfitting and/or reporting overfitted results.
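A minimal sketch of that final training run in tf.keras: the model is rebuilt from scratch and trained on all the data for a fixed number of epochs, with no validation data and no callbacks watching `val_loss`. The synthetic data, the hypothetical `build_model()` helper and `best_epochs = 30` are placeholders assumed for illustration; in practice `best_epochs` would come from your model-selection runs.

```python
import numpy as np
import tensorflow as tf

# Stand-in for ALL data used so far (former training + validation + test splits).
rng = np.random.default_rng(0)
x_all = rng.normal(size=(500, 4)).astype("float32")
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype="float32")
y_all = (x_all @ true_w).reshape(-1, 1)

# Assumed placeholder: the epoch count found during model selection.
best_epochs = 30


def build_model():
    """Hypothetical model factory; must match the architecture that was tuned."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Fresh model, fixed training length, no validation_split/validation_data:
# the stopping point is decided in advance instead of by early stopping.
final_model = build_model()
history = final_model.fit(x_all, y_all, epochs=best_epochs, batch_size=32, verbose=0)
```

Because the combined dataset is larger, each epoch performs more weight updates at the same batch size; keeping the epoch count unchanged is simply the most direct way to reuse the tuned hyperparameter.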

If you think this could still lead to overfitting, and you have some test data that you haven't used at all up to this point, exclude that test data from the training data. You can then evaluate your final model on this test data.
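The variant with an untouched test set can be sketched as follows: the test split is set aside before any training, the final model is trained on everything else, and the test set is used exactly once for evaluation. As before, the synthetic data, the hypothetical `build_model()` helper, the 100-sample test size and `best_epochs = 30` are assumptions for illustration only.

```python
import numpy as np
import tensorflow as tf

# Synthetic data, assumed for illustration (shuffle real data before slicing).
rng = np.random.default_rng(0)
x = rng.normal(size=(500, 4)).astype("float32")
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype="float32")
y = (x @ true_w + rng.normal(scale=0.1, size=500).astype("float32")).reshape(-1, 1)

# Test set never used for training, tuning or early stopping.
n_test = 100
x_rest, x_test = x[:-n_test], x[-n_test:]
y_rest, y_test = y[:-n_test], y[-n_test:]

best_epochs = 30  # Assumed placeholder from the earlier model-selection runs.


def build_model():
    """Hypothetical model factory; the architecture is an assumption."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Train on all non-test data for the fixed epoch count, then evaluate once.
final_model = build_model()
final_model.fit(x_rest, y_rest, epochs=best_epochs, verbose=0)
test_loss = final_model.evaluate(x_test, y_test, verbose=0)  # scalar MSE (no extra metrics)
print("test MSE:", test_loss)
```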
