Solved – Interpret learning curves

cart, classification, interpretation

I am training a Decision Tree on a dataset of around 580,000 data points.

I took the following steps:

  • Split the dataset into a training set (75%) and a validation set (25%).

  • Determined the best depth for the Decision Tree by building trees with depths ranging from 1 to 100 and taking the one with the best validation score.

  • Used this optimal depth to train trees on training sets of increasing size (in steps of 5,000), and took each tree's error on the training set and on the validation set to plot the learning curves. A rough code sketch of these steps is below.
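Roughly, the procedure looks like this in scikit-learn (a minimal sketch; `make_classification` is only a stand-in for the real 580,000-point dataset):

```python
# Rough sketch of the three steps above, assuming scikit-learn.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Stand-in data: 580,000 points, 55 features.
X, y = make_classification(n_samples=580_000, n_features=55, random_state=0)

# Step 1: 75% / 25% train/validation split.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

# Step 2: pick the depth with the best validation score.
depths = range(1, 101)
scores = [DecisionTreeClassifier(max_depth=d, random_state=0)
          .fit(X_train, y_train).score(X_val, y_val)
          for d in depths]
best_depth = depths[int(np.argmax(scores))]

# Step 3: learning curves at the chosen depth, in steps of 5,000 training points.
sizes = range(5000, len(X_train) + 1, 5000)
train_err, val_err = [], []
for n in sizes:
    clf = DecisionTreeClassifier(max_depth=best_depth, random_state=0)
    clf.fit(X_train[:n], y_train[:n])
    train_err.append(1 - clf.score(X_train[:n], y_train[:n]))
    val_err.append(1 - clf.score(X_val, y_val))
```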

I got the following learning curves:

[learning curves plot: training and validation error vs. training set size]

I am not quite sure how to interpret these curves. Both the training error and the validation error are low. Since the difference between the two curves is not that big, I assume my model does not overfit (much).

Is this a correct interpretation? Anything I can do to improve on this?

Best Answer

For the plot as it stands there is little to worry about. Your validation error (and it had better be a cross-validation error rather than the error of a single train/test split) decreases at a decreasing rate, and eventually saturates, as the number of training instances grows. This is normal. Yes, you overfit, but not by a lot.

There is, however, a lot to worry about in how you arrived at this plot.

Please keep in mind that a balanced tree of depth 100 would have around 2^100 ≈ 1.27 × 10^30 leaves. That is vastly more than the number of points in your dataset, so such a depth makes no sense. Since your tree is not necessarily balanced there is no strict rule, but the optimal depth of 58 also looks suspicious: even a balanced tree of depth 20 already has 2^20 ≈ 1.05 million leaves, more than your 580,000 points. Check how many leaves you actually have; it should be far fewer than 580K.
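A quick way to check, assuming scikit-learn and reusing `X_train`, `y_train` from the sketch in the question (`get_n_leaves` exists in scikit-learn ≥ 0.21):

```python
# Inspect the actual size of the fitted tree.
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(max_depth=58, random_state=0).fit(X_train, y_train)
print(clf.get_n_leaves())    # number of leaves; should be far below 580,000
print(clf.tree_.max_depth)   # depth actually reached, often less than max_depth
```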

The way to improve depends on your purpose. If your primary goal is understanding (which seems likely, since you chose a decision tree):

  • have a look at the feature importances: probably only a few of your 55 features make a real difference; in Python's scikit-learn, inspect clf.feature_importances_ after fitting the classifier (see the sketch after this list)

  • build a tree of human-understandable depth (3 or 4, at most 5 or 6) and visualize it
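A minimal sketch of both suggestions, again assuming scikit-learn and the `X_train`, `y_train` from the question:

```python
# Feature importances and a small, readable tree.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeClassifier, plot_tree

clf = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_train, y_train)

# Rank the 55 features by importance; typically only a handful carry real weight.
order = np.argsort(clf.feature_importances_)[::-1]
for i in order[:10]:
    print(f"feature {i}: importance {clf.feature_importances_[i]:.3f}")

# A depth-4 tree is small enough to read as a diagram.
plot_tree(clf, filled=True)
plt.show()
```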

If you are hunting for accuracy, try other methods. Based on your initial choice of a tree:

  • for a single tree, run GridSearchCV over min_samples_split or min_samples_leaf combined with max_depth or max_leaf_nodes; of course, you can include other parameters too (see the sketch after this list)

  • the simplest option: with 55 features you are a good candidate for a random forest

  • bagging and boosting: read about ensemble methods, for example in Python's scikit-learn documentation, or find material for a language of your choice
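A hedged sketch of the first two suggestions (the parameter grids below are illustrative choices, not recommendations):

```python
# Tune a single tree with GridSearchCV, then compare with a random forest.
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Joint search over pre-pruning parameters, scored by 5-fold cross-validation.
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": [5, 10, 20, None],
                "min_samples_leaf": [1, 10, 100],
                "min_samples_split": [2, 20, 200]},
    cv=5, n_jobs=-1,
)
grid.fit(X_train, y_train)
print(grid.best_params_, grid.best_score_)

# A random forest is often a strong default with 55 features.
rf = RandomForestClassifier(n_estimators=200, n_jobs=-1, random_state=0)
rf.fit(X_train, y_train)
print(rf.score(X_val, y_val))
```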
