How to train an LSTM when the sequence has imbalanced classes

Tags: lstm, neural-networks, unbalanced-classes

I'm labelling sequences at every time step, but some labels in the dataset occur only very briefly between two much more common labels. As a result, the network is biased towards the common labels. I can't just show it more examples of the rare labels, because the state is very important in the sequence, and breaking the sequence up would hurt its predictive ability.

Here is a very simple example of what the data and labels look like:

Data:

1 0 0 0 0 2 0 3 0 0 0 0 0 0 0 0 0 0 4 0 3 0 0 0 0 0 0 2 0 1 0 0 0 0 2 0 0 3 0 0

Labels:

1 1 1 1 1 2 2 3 3 3 3 3 3 3 3 3 3 3 4 4 3 3 3 3 3 3 3 2 2 1 1 1 1 1 2 2 2 3 3 3

Basically, an indicator appears in the sequence, and that indicator is the label for every time step until a new indicator comes along; a 0 in the data means "no new indicator" (hence the importance of the state).
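The labelling rule above is a forward fill over the indicators. As a minimal sketch (the function name `labels_from_indicators` is hypothetical, and it assumes 0 means "no new indicator"), the label sequence can be reproduced from the data like this:

```python
def labels_from_indicators(data):
    """Forward-fill indicators: each nonzero value becomes the label for the
    current and following time steps until the next nonzero value replaces it.
    Hypothetical helper; assumes 0 means "no new indicator"."""
    labels = []
    current = None  # no indicator seen yet
    for value in data:
        if value != 0:
            current = value
        labels.append(current)
    return labels


DATA = "1 0 0 0 0 2 0 3 0 0 0 0 0 0 0 0 0 0 4 0 3 0 0 0 0 0 0 2 0 1 0 0 0 0 2 0 0 3 0 0"
data = [int(t) for t in DATA.split()]
labels = labels_from_indicators(data)
print(" ".join(map(str, labels)))
```

Because each label can depend on an indicator many steps back, a model has to carry that information in its state, which is why chopping the sequence into short windows to oversample rare labels is unattractive here.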

Best Answer

Inversely proportional contributions to cost function

Another way of dealing with imbalanced data is to weight each label's contribution to the cost function in inverse proportion to that label's frequency. In your example above, I count the following class frequencies:

1: 10
2:  7
3: 21
4:  2

So you could multiply the cost on a sample-by-sample (here, time-step-by-time-step) basis by $\frac{1}{10}$ when the true label is 1, $\frac{1}{7}$ when it is 2, $\frac{1}{21}$ when it is 3, and $\frac{1}{2}$ when it is 4. You'll see 5 times as many 1 labels as 4 labels, but each 1 will contribute only $\frac{1}{5}$ as much to your overall cost function. In other words, each label's multipliers add up to the same total over the sequence ($n_c \cdot \frac{1}{n_c} = 1$ for a label with $n_c$ occurrences), so you can expect each label to have roughly the same impact on your cost function on average.

In practice, I would use the label frequencies across the whole training set and choose the numerator so that the multipliers sum to 1. In the example above, that numerator is $1 / (\frac{1}{10} + \frac{1}{7} + \frac{1}{21} + \frac{1}{2}) \approx 1.265$, giving the fractions $\frac{1.265}{10}$, $\frac{1.265}{7}$, $\frac{1.265}{21}$, $\frac{1.265}{2}$, which add up to ~1. You don't need to scale this way, but if you don't, you are in effect changing your learning rate.
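The weighting arithmetic is easy to check in a few lines of plain Python. This sketch counts the labels in the example sequence, shows that with multipliers $\frac{1}{n_c}$ every label's multipliers sum to the same total over the sequence, and computes the normalizing numerator $k = 1 / \sum_c \frac{1}{n_c}$. The helper `inverse_frequency_weights` is a hypothetical name, not a library function.

```python
from collections import Counter

# Label sequence from the question (one label per time step).
LABELS = [int(t) for t in
          "1 1 1 1 1 2 2 3 3 3 3 3 3 3 3 3 3 3 4 4 3 3 3 3 3 3 3 2 2 1 1 1 1 1 2 2 2 3 3 3".split()]


def inverse_frequency_weights(labels, normalize=True):
    """Return {label: multiplier}, each proportional to 1 / count(label).

    With normalize=True the numerator k = 1 / sum_c(1 / n_c) is used, so the
    multipliers sum to 1; with normalize=False the numerator is 1.
    """
    counts = Counter(labels)
    raw = {c: 1.0 / n for c, n in counts.items()}
    k = 1.0 / sum(raw.values()) if normalize else 1.0
    return {c: k * w for c, w in raw.items()}


counts = Counter(LABELS)
print(dict(sorted(counts.items())))  # {1: 10, 2: 7, 3: 21, 4: 2}

# Unnormalized multipliers: summed over the sequence, each label totals n_c * (1 / n_c) = 1.
raw = inverse_frequency_weights(LABELS, normalize=False)
totals = Counter()
for y in LABELS:
    totals[y] += raw[y]
print({c: round(t, 6) for c, t in sorted(totals.items())})

# Normalized multipliers: the four values themselves sum to 1.
norm = inverse_frequency_weights(LABELS)
print(round(sum(norm.values()), 6))
print(round(1.0 / sum(raw.values()), 3))  # the numerator k
```

The list `[raw[y] for y in LABELS]` (or the normalized equivalent) is one row of per-time-step weights of the kind a library's sample-weight argument expects.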

A danger of this approach (as with resampling) is an increased chance of overfitting to the rare labels, so you'll likely want to regularize your model somehow (e.g. with dropout or weight decay) if you use it.

On a practical note, I believe most deep learning libraries offer this functionality. For example, in the Python library Keras, the keras.models.Model.fit() method has a sample_weight parameter:

sample_weight: Optional Numpy array of weights for the training samples, used for weighting the loss function (during training only). You can either pass a flat (1D) Numpy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile().
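To make the temporal case concrete, here is a NumPy-only sketch of what a per-time-step weight does to the loss. It does not call Keras; it builds a weight array of shape (samples, sequence_length), the shape the quoted documentation describes, and applies it to a per-step cross-entropy. The helper names, the 0-based class indices, and averaging over all time steps are assumptions for illustration; real libraries differ in how they reduce the weighted loss. The quoted text also comes from an older Keras release, and `sample_weight_mode` may not exist in newer versions, so check the documentation for the version you use.

```python
import numpy as np

# Label sequence from the question as 0-based class indices (labels 1..4 -> 0..3),
# arranged as (samples, sequence_length) = (1, 40).
LABELS = "1 1 1 1 1 2 2 3 3 3 3 3 3 3 3 3 3 3 4 4 3 3 3 3 3 3 3 2 2 1 1 1 1 1 2 2 2 3 3 3"
y_true = np.array([[int(t) - 1 for t in LABELS.split()]])


def temporal_weights(y_true, num_classes):
    """Per-time-step inverse-frequency weights, shape (samples, sequence_length).

    Class multipliers are 1 / count, rescaled so they sum to 1.
    Assumes every class occurs at least once.
    """
    counts = np.bincount(y_true.ravel(), minlength=num_classes)
    class_weight = 1.0 / counts
    class_weight /= class_weight.sum()
    return class_weight[y_true]  # look up the weight of each step's true class


def temporal_weighted_crossentropy(y_true, y_prob, sample_weight):
    """Weighted sparse categorical cross-entropy, averaged over all time steps.

    y_prob has shape (samples, sequence_length, num_classes) and holds predicted
    class probabilities; sample_weight has shape (samples, sequence_length).
    """
    # Probability assigned to the true class at each time step.
    p_true = np.take_along_axis(y_prob, y_true[..., None], axis=-1)[..., 0]
    per_step_loss = -np.log(p_true)
    return float(np.mean(sample_weight * per_step_loss))


w = temporal_weights(y_true, num_classes=4)
print(w.shape)  # (1, 40)

# Placeholder predictions (uniform over the 4 classes), just to exercise the loss.
y_prob = np.full((1, 40, 4), 0.25)
print(temporal_weighted_crossentropy(y_true, y_prob, w))
```

Setting a time step's weight to 0 removes it from the loss entirely, and larger weights on rare-label steps make mistakes on those steps cost more. In Keras, an array shaped like `w` is what the quoted documentation says to pass as `sample_weight` to `fit()`.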

Lastly, I encourage you to make sure you have a performance metric you trust. An approach like this may cause your model to predict the rare labels more often than is actually desirable. As Tim said in a comment:

If something is more common, it is reasonable that it gets predicted more commonly.
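One way to act on this advice, offered as a sketch rather than a prescription, is to look at per-class precision and recall alongside overall accuracy, since over-predicting a rare label shows up as low precision for that label. The predictions below are made up for illustration (the true labels with two extra time steps wrongly predicted as 4 around the real 4s), and the helper names are hypothetical.

```python
import numpy as np

# True labels from the question, as 0-based class indices (labels 1..4 -> 0..3).
LABELS = "1 1 1 1 1 2 2 3 3 3 3 3 3 3 3 3 3 3 4 4 3 3 3 3 3 3 3 2 2 1 1 1 1 1 2 2 2 3 3 3"
y_true = np.array([int(t) - 1 for t in LABELS.split()])

# Made-up predictions: correct everywhere except that the time steps just before
# and just after the run of 4s are also predicted as 4 (index 3).
y_pred = y_true.copy()
y_pred[17] = 3
y_pred[20] = 3


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes (counts of time steps)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true.ravel(), y_pred.ravel()), 1)
    return cm


def per_class_precision_recall(cm):
    """Precision and recall per class; NaN where a class is never predicted/present."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # how often each class was predicted
    actual = cm.sum(axis=1)     # how often each class actually occurs
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = np.where(predicted > 0, tp / predicted, np.nan)
        recall = np.where(actual > 0, tp / actual, np.nan)
    return precision, recall


cm = confusion_matrix(y_true, y_pred, num_classes=4)
precision, recall = per_class_precision_recall(cm)
print("accuracy: ", np.mean(y_true == y_pred))
print("precision:", np.round(precision, 3))
print("recall:   ", np.round(recall, 3))
```

Here overall accuracy is 0.95, yet only half of the predicted 4s are correct (precision 0.5 for label 4), which is the kind of over-prediction a single accuracy number hides.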