Random Forest Predictors – Relative Importance of Predictors in Random Forest Classification in R

classification, machine learning, r, random forest

I'd like to determine the relative importance of sets of variables toward a randomForest classification model in R. The importance function provides the MeanDecreaseGini metric for each individual predictor. Is it as simple as summing this across each predictor in a set?

For example:

# Assumes df has variables a1, a2, b1, b2, and outcome
rf <- randomForest(outcome ~ ., data=df)
importance(rf)
# To determine whether the "a" predictors are more important than the "b"s,
# can I sum the MeanDecreaseGini for a1 and a2 and compare to that of b1+b2?

Best Answer

First I would like to clarify what the importance metric actually measures.

MeanDecreaseGini is a measure of variable importance based on the Gini impurity index used to choose splits during training. A common misconception is that the variable importance metric refers to the Gini coefficient used for assessing model performance, which is closely related to AUC, but this is wrong. Here is the explanation from the randomForest package written by Breiman and Cutler:

Gini importance
Every time a split of a node is made on variable m the gini impurity criterion for the two descendent nodes is less than the parent node. Adding up the gini decreases for each individual variable over all trees in the forest gives a fast variable importance that is often very consistent with the permutation importance measure.

The Gini impurity index is defined as $$ G = \sum_{i=1}^{n_c} p_i(1-p_i) $$ where $n_c$ is the number of classes in the target variable and $p_i$ is the proportion of samples in the node that belong to class $i$.

For a two-class problem, this results in the following curve, which is maximized for a 50-50 sample and minimized for homogeneous sets:

[Figure: Gini impurity for 2 class]
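To make the definition concrete, here is a minimal base-R sketch of the impurity formula. The function name gini.impurity is made up for illustration and is not part of the randomForest package.

```r
# Gini impurity of a node, given the vector of class proportions p
# (hypothetical helper, not a randomForest function)
gini.impurity <- function(p) {
  stopifnot(all(p >= 0), isTRUE(all.equal(sum(p), 1)))
  sum(p * (1 - p))
}

# Two-class case: impurity as a function of the share of class 1
p1 <- seq(0, 1, by = 0.25)
two.class <- sapply(p1, function(p) gini.impurity(c(p, 1 - p)))
names(two.class) <- p1
two.class
```

The values peak at 0.5 for the 50-50 node and drop to 0 at either homogeneous end.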

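The importance is then calculated as $$ I = G_{parent} - \frac{n_{split1}}{n_{parent}} G_{split1} - \frac{n_{split2}}{n_{parent}} G_{split2} $$ where $n$ denotes the number of samples in a node, so each child's impurity is weighted by its share of the parent's samples. This is averaged over all splits in the forest involving the predictor in question. Because it is an average, it can easily be extended to an average over all splits on variables contained in a group.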
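Here is a minimal sketch of the impurity decrease for a single split. It assumes each child's impurity is weighted by its share of the parent's samples, which is the usual convention for tree splits. The helper names gini.from.labels and gini.decrease are hypothetical and the labels are made up.

```r
# Gini impurity computed from a vector of class labels
gini.from.labels <- function(y) {
  p <- table(y) / length(y)
  sum(p * (1 - p))
}

# Decrease in impurity for one split (hypothetical helper).
# Assumption: each child is weighted by its share of the parent's samples.
gini.decrease <- function(parent, left, right) {
  n <- length(parent)
  gini.from.labels(parent) -
    (length(left) / n) * gini.from.labels(left) -
    (length(right) / n) * gini.from.labels(right)
}

# Toy node with 4 "a" and 4 "b", split into two children of size 4
parent <- c("a", "a", "a", "a", "b", "b", "b", "b")
left   <- c("a", "a", "a", "b")
right  <- c("a", "b", "b", "b")
gini.decrease(parent, left, right)
```

A split that separates the two classes perfectly would recover the full parent impurity of 0.5; the partial split above recovers less.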

Looking closer, each variable importance is an average conditional on the variable used, so the MeanDecreaseGini of the group is just the mean of these importances, weighted by how often each variable is used in the forest relative to the other variables in the same group. This holds because of the tower property $$ \mathbb{E}[\mathbb{E}[X|Y]] = \mathbb{E}[X] $$
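A small numeric check of this argument, using made-up per-split decreases (the numbers and variable names are purely illustrative, not taken from a real forest): the share-weighted mean of the per-variable averages matches the plain average over all splits in the group.

```r
# Made-up impurity decreases for individual splits, labelled by the
# variable that was split on (illustrative numbers, not from a real forest)
decrease <- c(0.10, 0.30, 0.20, 0.05, 0.15)
variable <- c("a1", "a1", "a1", "a2", "a2")

# Per-variable importance: mean decrease over the splits on that variable
var.imp <- tapply(decrease, variable, mean)

# Share of the group's splits that use each variable
share <- table(variable) / length(variable)

# Share-weighted mean of the per-variable importances ...
group.imp <- sum(as.numeric(var.imp) * as.numeric(share[names(var.imp)]))

# ... equals the plain mean over all splits in the group
all.equal(group.imp, mean(decrease))
```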

Now, to answer your question directly: it is not as simple as summing all the importances in each group to get the combined MeanDecreaseGini, but computing the weighted average will get you the answer you are looking for. We just need to find the variable frequencies within each group.

Here is a simple script to get these from a random forest object in R:

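var.share <- function(rf.obj, members) {
  # bestvar holds the index of the split variable for every node (0 = terminal node)
  bestvar <- rf.obj$forest$bestvar
  # Count splits per variable; tabulate keeps a zero for variables that are never used
  count <- tabulate(bestvar[bestvar > 0], nbins = length(rf.obj$forest$ncat))
  names(count) <- names(rf.obj$forest$ncat)
  # Share of the group's splits that fall on each member
  share <- count[members] / sum(count[members])
  return(share)
}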

Just pass in the names of the variables in the group as the members parameter.
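To show what the split counting does without fitting a model, here is a sketch on a toy stand-in for the relevant parts of a randomForest object. The object toy.rf and its values are hypothetical; only the field names bestvar and ncat mirror the ones used in the script. It uses tabulate so that a predictor that is never split on still gets a zero count.

```r
# Toy stand-in for the pieces of a fitted randomForest object that are used
# to count splits (hypothetical values; a real object comes from randomForest())
toy.rf <- list(forest = list(
  # one column per tree, one row per node; 0 marks a terminal node
  bestvar = matrix(c(1, 3, 0, 0, 0,
                     3, 4, 0, 0, 0,
                     3, 0, 1, 0, 0), ncol = 3),
  # named vector with one entry per predictor
  ncat = c(a1 = 1, a2 = 1, b1 = 1, b2 = 1)
))

# Number of splits on each predictor across all trees
bestvar <- toy.rf$forest$bestvar
count <- tabulate(bestvar[bestvar > 0], nbins = length(toy.rf$forest$ncat))
names(count) <- names(toy.rf$forest$ncat)
count

# Shares within the "b" group
b.share <- count[c("b1", "b2")] / sum(count[c("b1", "b2")])
b.share
```

Here a2 is never used for a split, b1 is used three times and b2 once, so within the "b" group the shares are 0.75 and 0.25.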

I hope this answers your question. I can write up a function to get the group importances directly if it is of interest.

EDIT:
Here is a function that gives the group importance given a randomForest object and a list of vectors with variable names. It uses var.share as previously defined. I have not done any input checking so you need to make sure you use the right variable names.

group.importance <- function(rf.obj, groups) {
  var.imp <- as.matrix(sapply(groups, function(g) {
    sum(importance(rf.obj, 2)[g, ]*var.share(rf.obj, g))
  }))
  colnames(var.imp) <- "MeanDecreaseGini"
  return(var.imp)
}

Example of usage:

library(randomForest)
data(iris)

rf.obj <- randomForest(Species ~ ., data=iris)

groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
               Petal=c("Petal.Width", "Petal.Length"))

group.importance(rf.obj, groups)

Output:

      MeanDecreaseGini
Sepal         6.187198
Petal        43.913020

It also works for overlapping groups:

overlapping.groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
                           Petal=c("Petal.Width", "Petal.Length"),
                           Width=c("Sepal.Width", "Petal.Width"), 
                           Length=c("Sepal.Length", "Petal.Length"))

group.importance(rf.obj, overlapping.groups)

Output:

       MeanDecreaseGini
Sepal          6.187198
Petal         43.913020
Width         30.513776
Length        30.386706