Solved – Fine-tuning with a subset of the same data

deep learning, machine learning, transfer learning

My understanding of fine-tuning is to take a pre-trained model trained on a similar but separate dataset and update the weights of a portion of the model on your dataset.

I'm not sure whether this is called something entirely different, or whether it's simply a terrible idea, but say you have a dataset for a classification task where the data varies to a reasonable extent depending on some internal category, to the point that an individual model would typically be trained for each category. Would it be a good idea (or even possible) to train an initial model on the whole dataset and then fine-tune that model for each subset?

For example, consider data for a binary classification task involving five different subjects, where the expression of the positive class is broadly similar but still differs from subject to subject. There are no pre-trained models available for this type of data, and no additional freely available data that could be used instead.
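To make the idea concrete, here is a rough sketch of what I have in mind, assuming Keras; the layer sizes, subject names, and data are all made-up placeholders:

```python
# Rough sketch of the idea: (1) train one model on the pooled data,
# (2) copy it per subject and retrain only the head on that subject's subset.
import numpy as np
from tensorflow import keras

def build_model(n_features):
    return keras.Sequential([
        keras.layers.Input(shape=(n_features,)),
        keras.layers.Dense(64, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, activation="sigmoid"),  # binary output
    ])

n_features = 20
X_all = np.random.rand(1000, n_features)              # placeholder pooled data
y_all = np.random.randint(0, 2, 1000)

base = build_model(n_features)
base.compile(optimizer="adam", loss="binary_crossentropy")
base.fit(X_all, y_all, epochs=5, verbose=0)           # trained on everything

subject_models = {}
for subject in ["s1", "s2", "s3", "s4", "s5"]:        # the five subjects
    X_s = np.random.rand(100, n_features)             # placeholder subset data
    y_s = np.random.randint(0, 2, 100)

    m = build_model(n_features)
    m.set_weights(base.get_weights())                 # start from the pooled weights
    for layer in m.layers[:-1]:
        layer.trainable = False                       # fine-tune only the last layer
    m.compile(optimizer=keras.optimizers.Adam(1e-4),  # small learning rate
              loss="binary_crossentropy")
    m.fit(X_s, y_s, epochs=5, verbose=0)
    subject_models[subject] = m
```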

Either I'm failing with my choice of search terms, or I'm missing something that makes this a spectacularly bad idea, because I can't seem to find anything about this sort of approach. I can clearly see overfitting being a major potential problem, but I'd still expect to find something written about it – even if only to say it's a bad idea. Is there a name for this kind of approach?

Best Answer

The answer you seek is highly domain-dependent, but before I get into that, I want to introduce some terminology that might be useful.

Transfer learning is the improvement of learning in a new task through the transfer of knowledge from a related task that has already been learned.

Reference for Transfer Learning

Boosting, bagging, and randomization are methods for improving model performance using resamples of the same data.

More specifically, boosting and bagging are ensemble methods: they create a number of classifiers and then combine them in various ways to obtain an improved model (or "fine tuning", as you put it).
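As a minimal sketch of the bagging idea, assuming recent scikit-learn (the `estimator` keyword was `base_estimator` in older versions) and synthetic data:

```python
# Bagging: many classifiers fit on bootstrap resamples of the same data,
# then combined by voting. Data here is synthetic for illustration only.
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 50 logistic-regression base learners, each fit on a bootstrap resample
bag = BaggingClassifier(
    estimator=LogisticRegression(max_iter=1000),
    n_estimators=50,
    random_state=0,
)
bag.fit(X_train, y_train)
print("bagged accuracy:", bag.score(X_test, y_test))
```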

The distinction between transfer learning and these other techniques matters because the problem domain comes into play.

In deep learning, suppose the new data is similar but not exactly the same; even the response categories can vary. In this scenario you would use transfer learning rather than the other methods (see the TensorFlow tutorial).
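Here is a sketch in the spirit of that tutorial (not its exact code): reuse a pretrained backbone, freeze its weights, and train a new head for the new response categories. The backbone choice, input shape, and data below are placeholders.

```python
# Transfer learning sketch: pretrained feature extractor + new binary head.
import numpy as np
from tensorflow import keras

# Pretrained ImageNet backbone; its original classification layer is discarded.
base = keras.applications.MobileNetV2(input_shape=(160, 160, 3),
                                      include_top=False, weights="imagenet")
base.trainable = False                                 # freeze transferred weights

model = keras.Sequential([
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(1, activation="sigmoid"),       # new head for the new task
])
model.compile(optimizer=keras.optimizers.Adam(1e-4),
              loss="binary_crossentropy", metrics=["accuracy"])

# Placeholder data standing in for the new, similar-but-not-identical dataset.
X_new = np.random.rand(32, 160, 160, 3).astype("float32")
y_new = np.random.randint(0, 2, size=32)
model.fit(X_new, y_new, epochs=1, verbose=0)
```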

But if you're predicting, say, ham vs. spam and find that your algorithm performs poorly on a small population of spam mails, or a certain population of your spam mails is too small to be significant for your classifier (say, logistic regression), you would benefit from an algorithm that accommodates resampling techniques such as boosting or bagging (for example XGBoost, one of the frequently winning algorithms on Kaggle).
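A minimal sketch of that setup, assuming the xgboost and scikit-learn packages; the data is synthetic and the parameter values are only illustrative:

```python
# XGBoost on an imbalanced binary task (e.g. ham vs. spam).
# scale_pos_weight is one common knob for handling a rare positive class.
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=5000, n_features=30,
                           weights=[0.95, 0.05], random_state=0)  # ~5% "spam"
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
                                                    random_state=0)

ratio = (y_train == 0).sum() / (y_train == 1).sum()   # negatives per positive
clf = xgb.XGBClassifier(
    n_estimators=200,
    max_depth=4,
    learning_rate=0.1,
    scale_pos_weight=ratio,       # upweight the rare positive class
    eval_metric="logloss",
)
clf.fit(X_train, y_train)
print("test accuracy:", clf.score(X_test, y_test))
```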

In theory, algorithms like XGBoost are robust against overfitting, but overfitting is still possible; a related answer explores that aspect in more detail.