Neural Networks – Classification with a Neural Network When One Class Has Disproportionately Many Entries

classification, neural-networks, unbalanced-classes

I am trying to train a neural network on a dataset with several classes $c_1, c_2, \dotsc, c_{10}$. The class $c_1$ has many more entries in the training set than the other classes, and this makes my neural network classify most of the test set entries as $c_1$.

What preprocessing should I do?

Best Answer

You are coping with an imbalanced dataset. Lucky for you, you are not alone. This is a common problem.

For surveys on the topic, see "Editorial: Special Issue on Learning from Imbalanced Data Sets" (6 pages) and "Learning from Imbalanced Data" (22 pages).

The method I like best is based on the boosting algorithm Robert E. Schapire presented in "The Strength of Weak Learnability" (Machine Learning, 5(2):197–227, 1990).

In this paper, Schapire presented a boosting algorithm based on recursively combining triplets of weak learners. By the way, this was the first boosting algorithm.

We can use the first step of the algorithm (even without the recursion) to cope with the lack of balance.

The algorithm trains the first learner, L1, on the original data set. The second learner, L2, is trained on a set on which L1 has a 50% chance of being correct (by sampling from the original distribution). The third learner, L3, is trained on the cases on which L1 and L2 disagree. As output, return the majority vote of the three classifiers. See the paper for why this improves classification.
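Here is a minimal sketch of that first step in Python, assuming scikit-learn-style learners and integer binary labels in {0, 1}. The function name `schapire_triplet` is mine, not from the paper, and the sketch skips the degenerate cases where L1 is already perfect or where L1 and L2 never disagree:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def schapire_triplet(X, y, make_learner, seed=0):
    """One non-recursive step of Schapire's boosting for binary labels
    y in {0, 1}. Returns a majority-vote predictor over L1, L2, L3."""
    rng = np.random.default_rng(seed)

    # L1: trained on the original (possibly imbalanced) data set.
    l1 = make_learner().fit(X, y)

    # L2: trained on a set where L1 is correct ~50% of the time --
    # equal-sized samples of the points L1 gets right and wrong.
    correct = l1.predict(X) == y
    right, wrong = np.flatnonzero(correct), np.flatnonzero(~correct)
    n = min(len(right), len(wrong))
    idx2 = np.concatenate([rng.choice(right, n, replace=False),
                           rng.choice(wrong, n, replace=False)])
    l2 = make_learner().fit(X[idx2], y[idx2])

    # L3: trained only on the points where L1 and L2 disagree.
    disagree = l1.predict(X) != l2.predict(X)
    l3 = make_learner().fit(X[disagree], y[disagree])

    def predict(X_new):
        # Majority vote: at least two of the three learners say 1.
        votes = sum(m.predict(X_new) for m in (l1, l2, l3))
        return (votes >= 2).astype(int)
    return predict

# Example with decision stumps as the weak learners:
# predict = schapire_triplet(X, y, lambda: DecisionTreeClassifier(max_depth=1))
```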

Now, for the application of the method to an imbalanced set: assume the concept is binary and the majority of the samples are labeled true.

Let L1 always return true. L2 is trained where L1 has a 50% chance of being right; since L1 always predicts true, L2 is trained on a balanced data set. L3 is trained where L1 and L2 disagree, that is, where L2 predicts false. The ensemble predicts by majority vote; hence, it predicts false only when both L2 and L3 predict false.
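Specialising the sketch above to this imbalanced case (again assuming {0, 1} labels with 1 the majority/true class; `imbalanced_triplet` is a hypothetical name, and the degenerate case where L2 never predicts 0 is not handled), the constant L1 drops out of the code entirely:

```python
import numpy as np

def imbalanced_triplet(X, y, make_learner, seed=0):
    """Schapire's first step with L1 = 'always predict 1' (the majority
    class). L2 trains on a balanced undersample; L3 trains on the
    points where L2 predicts 0."""
    rng = np.random.default_rng(seed)

    # L2: all minority points plus an equal-sized random undersample
    # of the majority class, i.e. a balanced training set.
    minority, majority = np.flatnonzero(y == 0), np.flatnonzero(y == 1)
    idx2 = np.concatenate([minority,
                           rng.choice(majority, len(minority), replace=False)])
    l2 = make_learner().fit(X[idx2], y[idx2])

    # L1 always says 1, so L1 and L2 disagree exactly where L2 says 0.
    disagree = l2.predict(X) == 0
    l3 = make_learner().fit(X[disagree], y[disagree])

    def predict(X_new):
        p2, p3 = l2.predict(X_new), l3.predict(X_new)
        # Majority vote with L1 fixed at 1: the ensemble outputs 0
        # only when both L2 and L3 output 0.
        return np.where((p2 == 0) & (p3 == 0), 0, 1)
    return predict
```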

I have used this method in practice many times, and it is very useful. It also has a theoretical justification, so all fronts are covered.
