Solved – Scikit-learn: get feature importance for multiclass classification using a Decision Tree

cart, classification, machine learning, predictive-models, scikit learn

I am using Scikit-learn for a multiclass classification task and would like to find out which features are most important for each class. I have three classes (say class_a, class_b, and class_c) and achieved cross-validation scores of around 70% with a Decision Tree classifier.

There is a feature_importances_ attribute in DecisionTreeClassifier that tells me the importance of each feature. However, how can I tell which feature is important for which class? I know that feature importance is computed from the Gini impurity, but how do I know whether a feature with a high score is important for distinguishing, say, class_a from class_b and class_c? Or is my understanding of feature importance completely wrong?
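For reference, here is a minimal sketch of the kind of setup I mean (the data here is just a synthetic stand-in for mine):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for my data: three classes (class_a, class_b, class_c)
X, y = make_classification(n_samples=500, n_features=10, n_informative=4,
                           n_classes=3, random_state=0)

clf = DecisionTreeClassifier(random_state=0)
print(cross_val_score(clf, X, y, cv=5).mean())  # roughly 0.7 in my case

clf.fit(X, y)
# One importance value per feature, but nothing that ties it to a class
print(clf.feature_importances_)
```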

I would appreciate it very much if someone could shed some light on this. I have been poring over the Scikit-learn documentation and other sources but couldn't find an answer, which makes me suspect there are some flaws in my understanding.

Thank you!

Best Answer

You could recast your problem as multiple one-vs-rest classifiers. For example, train a classifier to distinguish between (1) class_a and (2) the rest, and then access its feature importance. Keep in mind, though, that feature importance is not the importance of a feature for a certain class; it is a measure of how useful a single feature is for distinguishing the two classes (here, one class vs. the rest).

Therefore the feature importance attribute does not answer the question "Which feature is important for a class?" but rather the question "Which feature helps best at distinguishing class_a from the other classes present?"
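As a rough sketch of that idea (with synthetic data standing in for yours, and classes 0/1/2 playing the role of class_a/b/c):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Stand-in data with three classes
X, y = make_classification(n_samples=500, n_features=10, n_informative=4,
                           n_classes=3, random_state=0)

# One one-vs-rest tree per class: relabel the target as "this class vs. the rest"
for cls in np.unique(y):
    y_binary = (y == cls).astype(int)
    tree = DecisionTreeClassifier(random_state=0).fit(X, y_binary)
    # These importances measure how useful each feature is for separating
    # this one class from all the others combined
    print(f"class {cls}:", np.round(tree.feature_importances_, 3))
```

Equivalently, sklearn.multiclass.OneVsRestClassifier(DecisionTreeClassifier()) fits the same per-class trees and exposes them through its estimators_ attribute, so you can read feature_importances_ off each of them.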