Neural Networks – Effects of Class Imbalance on Neural Network Weights

classification, multi-class, natural language, neural networks, unbalanced-classes

My question is about the class-imbalance problem in the case of a neural network classifier for natural language processing (in particular, a neural network with an LSTM layer).

I want to train a neural network to discriminate data among four classes and then extract the weight vectors from the embedding layer, so that I can use them as dense numerical vector representations of my data.

If the classes are unbalanced (for example, class 1 has 27,000 examples, class 2 has 3,600, class 3 has 260 and class 4 has 10) and the metric is set to "accuracy" (imagine developing it with Keras in Python), the problem is that the classifier will tend to misclassify the classes with fewer examples in order to reach a high overall accuracy. However, there is one thing I do not really understand: how are the weights influenced by this? For example, if I extract the weight vectors that represent my data from the embedding layer and plot them, does the class imbalance mean that the data will not be well separated in the plot according to the classes they belong to? In other words, should I observe four clusters in this plot, or will they all overlap because of the class imbalance? A minimal sketch of the setup I have in mind is below.
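For concreteness, here is a minimal sketch of the kind of model I mean; the vocabulary size, layer sizes and the dummy integer-encoded data are only placeholders, not my real data:

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense

vocab_size = 5000   # hypothetical vocabulary size
seq_len = 50        # hypothetical padded sequence length
num_classes = 4

# Dummy integer-encoded sequences with the class counts from the question.
counts = [27000, 3600, 260, 10]
X = np.random.randint(1, vocab_size, size=(sum(counts), seq_len))
y = np.concatenate([np.full(n, c) for c, n in enumerate(counts)])

model = Sequential([
    Embedding(input_dim=vocab_size, output_dim=64),
    LSTM(32),
    Dense(num_classes, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])   # plain accuracy, as described above
model.fit(X, y, epochs=3, batch_size=128)

# The trained embedding matrix: one dense vector per vocabulary token,
# which could then be projected (e.g. with PCA or t-SNE) and plotted.
embedding_weights = model.layers[0].get_weights()[0]   # shape: (vocab_size, 64)
```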

Maybe this is a trivial question, so I apologize, and I thank you in advance.

Best Answer

In general, the association between trained weights and class imbalance is not easy to establish. One way to overcome the imbalance and see how the weights change is to use transfer learning: first train on a balanced version of the dataset, then continue training on the unbalanced version. However, this is still an open research question; see the survey article "Survey on deep learning with class imbalance".
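A rough sketch of that two-phase idea follows. The undersampling helper, epoch counts and learning rate are just placeholders, and `model`, `X`, `y` are assumed to be the ones from the sketch in the question; balancing by undersampling the majority classes is only one possible way to build the balanced phase-1 dataset:

```python
import numpy as np
from tensorflow.keras.optimizers import Adam

def balanced_subset(X, y, seed=0):
    """Undersample every class down to the size of the rarest class."""
    rng = np.random.default_rng(seed)
    n_min = np.bincount(y).min()
    idx = np.concatenate([
        rng.choice(np.where(y == c)[0], size=n_min, replace=False)
        for c in np.unique(y)
    ])
    rng.shuffle(idx)
    return X[idx], y[idx]

# Phase 1: train on a class-balanced subset.
X_bal, y_bal = balanced_subset(X, y)
model.fit(X_bal, y_bal, epochs=10, batch_size=8)

# Phase 2: continue training the same weights on the full, unbalanced dataset,
# with a lower learning rate so the balanced solution is not wiped out.
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(X, y, epochs=3, batch_size=128)
```

Comparing the embedding weights (or their 2-D projection) after phase 1 and after phase 2 is one way to inspect how much the imbalanced data actually moves them.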