Identify overfitting from an LSTM plot of predictions on trained + unseen data

loss-functions, lstm, overfitting, recurrent-neural-network

I am currently learning about LSTM-RNN models and have run some tests to see how they work. As with most neural networks, overfitting and underfitting are problems in ML. I have read articles such as this one: https://machinelearningmastery.com/learning-curves-for-diagnosing-machine-learning-model-performance/, this one: https://towardsdatascience.com/learning-curve-to-identify-overfitting-underfitting-problems-133177f38df5, and this one: Dealing with LSTM overfitting. All of them talk about detecting overfitting and underfitting using loss curves: the training loss and the test/validation loss. In papers I find through Google, however, I see plots of the real dataset + predictions on the training dataset + predictions on unseen data, and I haven't seen anyone showing loss curves. So my question is: how can I tell whether an LSTM-RNN model works well and doesn't overfit/underfit from a plot of (real dataset + predictions on the training dataset + predictions on the unseen dataset)? Is that possible?

Best Answer

For future readers: I clarified my understanding of the question in the comments.

EDIT: This answer is not specific to LSTMs or neural networks; it applies to any predictive algorithm.

Response: In general, you probably can tell overfitting/underfitting from a single plot of the true values (train and test), the training-set predictions, and the test-set predictions. However, there are some pretty big issues with doing this, and I don't see why you wouldn't just use more objective methods instead.

How to do it from the plot: It's pretty straightforward. Overfitting, by definition, means the model does much better in training than in test. Visually, you will detect this when the model's predictions match the true values very closely in the training-set section of the plot but are noticeably worse/farther away/messier-looking in the test-set section.

For underfitting, you will see that the predictions are bad/messy/far from the true values in both the training-set section of the plot and the test-set section. As a general note, it is pretty unlikely that you are underfitting with a neural network.
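If you do want to make such a plot, here is a minimal sketch. It assumes you already have 1-D arrays of true values and predictions for each set; the names (`y_train`, `train_preds`, `y_test`, `test_preds`) are placeholders I'm introducing, not something from your setup.

```python
# Minimal sketch, assuming 1-D NumPy arrays of true values and predictions.
# All variable names here are placeholders, not from the original post.
import numpy as np
import matplotlib.pyplot as plt

def plot_fit(y_train, train_preds, y_test, test_preds):
    """Plot true values plus train/test predictions on one time axis."""
    n_train = len(y_train)
    t_train = np.arange(n_train)
    t_test = np.arange(n_train, n_train + len(y_test))

    # True values across the full series
    plt.plot(np.concatenate([t_train, t_test]),
             np.concatenate([y_train, y_test]), label="true values")
    # Predictions, split by set
    plt.plot(t_train, train_preds, label="train predictions")
    plt.plot(t_test, test_preds, label="test predictions")
    # Mark the train/test boundary
    plt.axvline(n_train, linestyle="--", color="grey")

    plt.legend()
    plt.show()
```

Overfitting would show up as the orange (train) line hugging the true values while the green (test) line wanders off; underfitting would show both prediction lines far from the true values.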

The problems with doing that (please read!):

  1. You are working with 10,000 measurements. Visually detecting over-/underfitting from a plot of 10,000 points is going to be very difficult: you'll have to zoom in a ton to be able to tell what's going on. I literally mean that there aren't enough pixels on your screen to easily distinguish the train and test sets on a single plot, so unless you zoom in a lot and scroll side to side (annoying and difficult), this will be a pain.
  2. This method of eyeballing it from the plot is pretty subjective. If it's truly extremely obvious overfitting, you will be able to tell. But besides that, why do this subjective method when you can use an objective one?

My recommendations:

  1. The most straightforward and objective way to tell if you're overfitting is to compare the error in your training (better yet, cross validation) set vs. the error in your test set. That is, compute the average error across all points in both sets. If training/cross-validation set error is significantly lower than test set error, there's overfitting. If they're about equal, but both are bad, there's underfitting.
  2. If you insist on having a plot, I would recommend plotting the errors (prediction minus true value at every point), NOT the true values vs. predictions as you are suggesting (because, again, it's visually hard to tell what's going on). Plot the errors and maybe even run a simple moving average over them so the plot is easier to read (and doesn't look like crazy white noise). If you plot this as I'm describing (perhaps one color for the train-set error and one for the test-set error), you will probably be able to visually compare the error (and performance) between the two sets. However, why not just do option 1 and have a quantifiable result? A short sketch of both options follows this list.
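Here is a rough sketch of both recommendations, again using the placeholder arrays `y_train`, `train_preds`, `y_test`, `test_preds` from the earlier snippet. I'm using mean absolute error as the summary metric purely for illustration; any sensible error measure works.

```python
# Sketch of option 1 (single error number per set) and option 2
# (smoothed error series). Variable names are placeholders.
import numpy as np
import matplotlib.pyplot as plt

def mean_abs_error(y_true, y_pred):
    """Average absolute error across all points in a set."""
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred)))

def moving_average(x, window=50):
    """Simple moving average to smooth a noisy error series."""
    return np.convolve(x, np.ones(window) / window, mode="valid")

# Option 1: compare one error number per set.
train_mae = mean_abs_error(y_train, train_preds)
test_mae = mean_abs_error(y_test, test_preds)
print(f"train MAE: {train_mae:.4f}, test MAE: {test_mae:.4f}")
# train MAE much lower than test MAE -> overfitting;
# both roughly equal but large     -> underfitting.

# Option 2: plot smoothed error series instead of raw predictions.
train_err = np.abs(np.asarray(train_preds) - np.asarray(y_train))
test_err = np.abs(np.asarray(test_preds) - np.asarray(y_test))
plt.plot(moving_average(train_err), label="train |error| (smoothed)")
plt.plot(moving_average(test_err), label="test |error| (smoothed)")
plt.legend()
plt.show()
```

Option 1 gives you a single quantifiable comparison; option 2 is the compromise if you really want a visual.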

Best of luck!
