Solved – Why would gradient boosted trees generalize better than a neural network on time series classification

boosting, gradient descent, machine learning, neural networks

I'm trying to solve a binary classification task on a noisy time series dataset with highly right-skewed features. I find that an MLP and a gradient boosted tree reach a similar log loss on the training set, but the tree model generalizes much better to the validation set than the network does. Do you have any intuition as to why this is?

Best Answer

Since the gradient boosted trees generalize well to new data but the multilayer perceptron does not, the simplest explanation is that the multilayer perceptron is overfitting, and I think the difference comes down to the (stronger) regularisation built into the gradient boosted trees.

The overall objective function for both models can be written $$ \mathrm{Obj}(\theta) = L(\theta) + \Omega(\theta) $$ where $L(\theta)$ is the training loss and $\Omega(\theta)$ is the regularisation term. Even though both models presumably use the same logistic loss as the training loss $L(\theta)$, I suspect the regularisation terms are different. So if the "loss" you are comparing is really the objective function, then you are comparing two different functions.
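As a concrete illustration (assuming the usual logistic loss and, say, an L2 penalty on the parameters), the two pieces might look like $$ L(\theta) = -\sum_{i=1}^{n} \big[ y_i \log p_i(\theta) + (1 - y_i) \log\!\big(1 - p_i(\theta)\big) \big], \qquad \Omega(\theta) = \lambda \lVert \theta \rVert_2^2, $$ where $p_i(\theta)$ is the predicted probability for observation $i$ and $\lambda$ controls the strength of the penalty. Two models can reach the same $L(\theta)$ on the training set yet behave very differently out of sample if their $\Omega(\theta)$ differ.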

I think most gradient boosted tree packages use several regularisation techniques, which are outlined in the Wikipedia article on gradient boosting. One common strategy is to set a learning rate $\nu$ that shrinks the contribution of each new tree towards zero. An additional strategy, which XGBoost uses, is to prune the new trees based on their leaf weights. This is all explained nicely in this XGBoost tutorial, which I found helpful for better understanding gradient boosted trees. Packages such as gbm in R and GradientBoostingClassifier in Python also have their own strategies, like requiring a minimum number of observations to split a node. Most of these hyperparameters probably work quite well out of the box.
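To make that concrete, here is a minimal sketch of how several of these regularisers appear as hyperparameters in XGBoost's scikit-learn interface. The parameter names are real XGBoost options; the values are illustrative, not tuned for your data, and `X_train`/`y_train` etc. are hypothetical placeholders.

```python
from xgboost import XGBClassifier

model = XGBClassifier(
    n_estimators=500,      # number of boosting rounds
    learning_rate=0.05,    # shrinkage (nu): scales each new tree's contribution
    max_depth=3,           # caps tree complexity
    min_child_weight=5,    # minimum hessian sum required in a leaf to allow a split
    gamma=1.0,             # minimum loss reduction required to split (pruning)
    reg_lambda=1.0,        # L2 penalty on the leaf weights
    subsample=0.8,         # row subsampling, another form of regularisation
    eval_metric="logloss",
)
# model.fit(X_train, y_train, eval_set=[(X_val, y_val)])
```

Even if you never touch these, their defaults already impose a fair amount of regularisation, which is part of why boosted trees often do well "out of the box".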

An MLP that you implement in, say, TensorFlow or Keras will not include regularisation by default, but it is easy to add dropout or L1/L2 weight penalties.
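For example, here is a rough Keras sketch of adding both dropout and an L2 weight penalty to a small binary-classification MLP. The layer sizes, dropout rate, penalty strength, and `n_features` are all illustrative assumptions, not a recommendation for your particular dataset.

```python
from tensorflow import keras

n_features = 20  # hypothetical number of input features

model = keras.Sequential([
    keras.layers.Input(shape=(n_features,)),
    # L2 penalty shrinks the weights towards zero, loosely analogous
    # to the leaf-weight penalty in gradient boosted trees.
    keras.layers.Dense(64, activation="relu",
                       kernel_regularizer=keras.regularizers.l2(1e-4)),
    # Dropout randomly zeroes activations during training to reduce overfitting.
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation="relu",
                       kernel_regularizer=keras.regularizers.l2(1e-4)),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(1, activation="sigmoid"),
])

model.compile(optimizer="adam", loss="binary_crossentropy")
# model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50)
```

With some combination of these (plus early stopping on the validation loss), I would expect the gap between the MLP's training and validation loss to shrink considerably.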