Classification – Multi-Class Classification with Imbalanced Classes: Techniques and Strategies

classification unbalanced-classes

I have data from 5 classes and I would like to build a classifier. However, the number of feature vectors in each class is very different: one has about 5,000, one about 200,000, one about 1,000,000, one about 10,000,000 and one about 1,000,000,000.

As the largest class is too large to build a classifier with, I will have to downsample it in any case.

I am currently using scikit-learn and Random Forests, although I can use another tool if that would be better. If it were a binary classification problem, I could have trained with balanced classes and computed the ROC curve to find the false positive rate I can tolerate. However, I have no idea what the right thing to do is in this multiclass case.

Are there best-practice recommendations for this situation? I don't want the classifier to simply ignore one of the classes, for example.

Best Answer

Because your class sizes are so large, I would first pre-downsample to something like 5,000 + 10,000 + 10,000 + 10,000 + 10,000 samples. Do you really need more samples? If you do, downsample again, train a forest on each subsample independently, and aggregate the forests afterwards. That will save time and memory. During modeling you can even bootstrap only ~5,000 samples for each tree to speed up training. The bootstrap for each tree can be stratified so that 1,000 samples are drawn from each class.
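Since the question mentions scikit-learn, here is a minimal Python sketch of the two sampling steps described above: a one-off capped pre-downsampling, followed by repeated balanced bootstrap draws. It uses only the standard library and works on row indices. The function name `stratified_downsample`, the toy class sizes, and the caps are illustrative assumptions, not part of the original answer.

```python
import random
from collections import defaultdict

def stratified_downsample(labels, per_class, seed=0, replace=False):
    """Return row indices, drawing per_class[c] rows from each class c.

    Hypothetical helper: without replacement the draw is capped at the
    class size; with replace=True it is a bootstrap draw of exactly
    per_class[c] rows.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    chosen = []
    for label in sorted(by_class):
        idxs = by_class[label]
        k = per_class[label]
        if replace:
            chosen.extend(rng.choices(idxs, k=k))
        else:
            chosen.extend(rng.sample(idxs, min(k, len(idxs))))
    rng.shuffle(chosen)
    return chosen

# Toy labels standing in for the real data (sizes are made up and far smaller).
labels = [0] * 50 + [1] * 200 + [2] * 1000 + [3] * 5000 + [4] * 20000

# Step 1: one-off pre-downsampling with a cap per class.
caps = {0: 50, 1: 100, 2: 100, 3: 100, 4: 100}
pool = stratified_downsample(labels, caps, seed=1)

# Step 2: independent balanced bootstrap draws from the pool,
# e.g. one per forest (or one per tree).
pool_labels = [labels[i] for i in pool]
draws = []
for s in range(3):
    local = stratified_downsample(pool_labels, {c: 20 for c in caps},
                                  seed=s, replace=True)
    draws.append([pool[j] for j in local])  # map back to original row indices
```

Each entry of `draws` could then be used to fit a separate scikit-learn `RandomForestClassifier`, and the forests' `predict_proba` outputs averaged to aggregate them. That last step is a suggestion for wiring this into scikit-learn rather than something spelled out in the answer.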

Here's a thread on how to train a balanced multiclass forest with downsampling and 1-vs-rest ROC plots.

And here's an R code example of 1-vs-rest ROC plots:

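library(AUC)

# Simulated probabilistic predictions (yhat) vs. true class (y)
obs    <- 500
nClass <- 5
y <- sample(1:nClass, obs, replace = TRUE)

# For each observation, draw random class scores, give the true class a
# 0.2 head start, and normalise so the scores sum to 1.
# sapply() returns an nClass x obs matrix: one row per class.
yhat <- sapply(y, function(trueClass) {
  pred.prob <- rep(0, nClass)
  pred.prob[trueClass] <- 0.2
  pred.prob <- pred.prob + runif(nClass)
  pred.prob / sum(pred.prob)
})

# Plot 1-vs-all ROC curves, one curve (and colour) per class;
# add = TRUE overlays every curve after the first on the same plot.
for (i in 1:nClass) {
  plot(roc(predictions = yhat[i, ],
           labels = as.factor(y == i)),
       add = i != 1,
       col = i)
}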
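For readers working in Python, as the question does, the same 1-vs-rest idea can be sketched without any plotting library by computing each class's ROC points and the area under them. This is a rough standard-library translation of the simulation above, not the original author's code. The helper names `roc_points` and `auc` are made up for this example, and the random numbers will differ from the R run.

```python
import random

def roc_points(scores, positives):
    """ROC curve as (fpr, tpr) points for one class against the rest.

    Assumes at least one positive and one negative example.
    """
    n_pos = sum(positives)
    n_neg = len(positives) - n_pos
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    points = [(0.0, 0.0)]
    tp = fp = 0
    prev = None
    for i in order:
        # Emit a point only when the score changes, so tied scores
        # are treated as a single threshold.
        if prev is not None and scores[i] != prev:
            points.append((fp / n_neg, tp / n_pos))
        if positives[i]:
            tp += 1
        else:
            fp += 1
        prev = scores[i]
    points.append((fp / n_neg, tp / n_pos))
    return points

def auc(points):
    """Area under the curve by the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))

# Simulated predictions: random scores, 0.2 bonus for the true class,
# normalised to sum to 1 (mirrors the R simulation).
rng = random.Random(0)
n_obs, n_class = 500, 5
y = [rng.randrange(n_class) for _ in range(n_obs)]
yhat = []
for true_class in y:
    p = [rng.random() + (0.2 if c == true_class else 0.0)
         for c in range(n_class)]
    total = sum(p)
    yhat.append([v / total for v in p])

# One ROC curve (and AUC) per class: class c vs. all other classes.
aucs = {}
for c in range(n_class):
    scores = [row[c] for row in yhat]
    positives = [label == c for label in y]
    aucs[c] = auc(roc_points(scores, positives))
```

The points returned by `roc_points` can be passed to any plotting tool to draw one curve per class, and a per-class threshold can be read off each curve at the false positive rate you can tolerate, just as in the binary case.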
