Solved – R – Plotting Random Forest

r, random forest

I'm working on an analysis that predicts wine quality from a number of characteristics of the wine. I'm grouping wines into 2 categories: 1 for 'High Quality' and 0 for 'Average Quality'.

After performing model selection to find the model with the lowest overall error rate, I want to use Random Forests to take the analysis further. I first determined the optimal value of mtry (the number of variables considered at each split that produces the lowest overall error rate), then ran a Random Forest on my data using that value.

My random forest results are what I expected. Is there a good way to visualize a sample tree from the forest, or are there other visualizations that would help me explain the results? (I'm already creating importance plots for the variables.)

Any thoughts?

Results from determining the optimal mtry value:

[Image: results of the mtry search]

Current Code:

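library(randomForest)
library(corrplot)

set.seed(8, sample.kind = "Rounding")
wine.bag = randomForest(quality01 ~ alcohol + volatile_acidity + sulphates + residual_sugar +
    chlorides + free_sulfur_dioxide + fixed_acidity + pH + density +
    citric_acid, data = wine, mtry = 3, importance = TRUE)
wine.bag

# error rates as trees are added to the forest
plot(wine.bag)

# variable importance, as a table and as a plot
importance(wine.bag)
varImpPlot(wine.bag)

# correlation plot of the remaining columns after dropping columns 12 to 14
test = wine[, c(-12, -13, -14)]
rest = cor(test)
corrplot(rest, type = "upper", order = "hclust",
         tl.col = "black", tl.srt = 45)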

Best Answer

This is how I got something similar to your data:

library(randomForest)

# download from https://rpubs.com/YasmeenMubarak/462147
wine = read.csv("wineQualityReds.csv", row.names = 1)
# binary label: 1 if quality is 7 or higher, otherwise 0
wine$quality01 = factor(ifelse(wine$quality >= 7, 1, 0))
wine$quality = NULL
set.seed(8, sample.kind = "Rounding")
wine.bag = randomForest(quality01 ~ ., data = wine, mtry = 3, importance = TRUE)
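
For completeness, the mtry search described in the question could look roughly like the sketch below. This is not the asker's actual code: it assumes the out-of-bag (OOB) error is the criterion, and it uses a binary label built from the built-in iris data as a stand-in so that it runs without the wine file. The names d, label, mtry_grid and mtry_results are made up for illustration.

```r
library(randomForest)

# Stand-in data (assumption): a binary factor built from iris,
# mirroring how quality01 is built from quality above
d <- iris
d$label <- factor(ifelse(d$Species == "virginica", 1, 0))
d$Species <- NULL

set.seed(8)
mtry_grid <- 1:4  # the stand-in data has 4 predictors

# Fit one forest per candidate mtry and record its final OOB error
oob_error <- sapply(mtry_grid, function(m) {
  fit <- randomForest(label ~ ., data = d, mtry = m, ntree = 500)
  # last row of err.rate is the error after all trees; "OOB" is the overall rate
  fit$err.rate[fit$ntree, "OOB"]
})

mtry_results <- data.frame(mtry = mtry_grid, oob_error = oob_error)
print(mtry_results)

# candidate with the lowest OOB error (first one in case of ties)
best_mtry <- mtry_results$mtry[which.min(mtry_results$oob_error)]
```

With the wine data you would swap d and label for wine and quality01 and widen mtry_grid to the number of predictors. The differences between neighbouring mtry values are often within the noise of a single seed, so it may be worth repeating the loop with a few seeds before settling on one value.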

It's not so easy to visualize a whole tree, especially when you have a lot of variables. The partial dependence plot @carlo mentioned can be made with partialPlot from randomForest. Below I plot the partial dependence of class 1 on chlorides, using the training data:

partialPlot(wine.bag,wine,"chlorides",1)

[Image: partial dependence plot for chlorides]
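
If you want the same kind of plot for several predictors at once, something along these lines might work. It is only a sketch: it again uses a binary label built from iris as stand-in data, picks the two predictors with the highest mean decrease in accuracy, and the names d, label, top_vars and pd are made up for illustration. The do.call wrapper is there because partialPlot captures its x.var argument unevaluated, which can misbehave when the variable name is held in a loop variable.

```r
library(randomForest)

# Stand-in data (assumption): binary factor built from iris
d <- iris
d$label <- factor(ifelse(d$Species == "virginica", 1, 0))
d$Species <- NULL

set.seed(8)
fit <- randomForest(label ~ ., data = d, importance = TRUE)

# Rank predictors by mean decrease in accuracy (type = 1) and keep the top two
imp <- importance(fit, type = 1)
top_vars <- rownames(imp)[order(imp[, 1], decreasing = TRUE)][1:2]

# do.call passes the variable name as a plain string;
# plot = FALSE returns the x/y values instead of drawing
pd <- lapply(top_vars, function(v) {
  do.call(partialPlot, list(x = fit, pred.data = d, x.var = v,
                            which.class = "1", plot = FALSE))
})
names(pd) <- top_vars

# Draw the curves side by side
op <- par(mfrow = c(1, 2))
for (v in top_vars) {
  plot(pd[[v]]$x, pd[[v]]$y, type = "l",
       xlab = v, ylab = "partial dependence (class 1)", main = v)
}
par(op)
```

For the wine model you would pass wine.bag and wine instead of fit and d. Keep in mind that for classification the y axis of partialPlot is on a logit-like scale, not a probability.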

One option is to plot the tree using this function from shirin's GitHub, but I think you will need to refine the function a bit for your use. I reproduce the function below:

library(dplyr)
library(ggraph)
library(igraph)

tree_func <- function(final_model,tree_num) {

  tree <- randomForest::getTree(final_model, 
                                k = tree_num, 
                                labelVar = TRUE) %>%
    tibble::rownames_to_column() %>%
    # set the split points of leaf nodes to NA, so the 0s won't get plotted
    mutate(`split point` = ifelse(is.na(prediction), `split point`, NA))

  graph_frame <- data.frame(from = rep(tree$rowname, 2),
                            to = c(tree$`left daughter`, tree$`right daughter`))

  graph <- graph_from_data_frame(graph_frame) %>% delete_vertices("0")

  V(graph)$node_label <- gsub("_", " ", as.character(tree$`split var`))
  V(graph)$leaf_label <- as.character(tree$prediction)
  V(graph)$split <- as.character(round(tree$`split point`, digits = 2))

  plot <- ggraph(graph, 'dendrogram') + 
    theme_bw() +
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
    geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
    geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE, 
                    repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
    theme(panel.grid.minor = element_blank(),
          panel.grid.major = element_blank(),
          panel.background = element_blank(),
          plot.background = element_rect(fill = "white"),
          panel.border = element_blank(),
          axis.line = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          plot.title = element_text(size = 18))

  return(plot)
}

ggsave("test.png", plot = tree_func(wine.bag, 1), width = 12, height = 8)

As you can see, the result is rather cluttered:

[Image: tree 1 of the forest plotted with tree_func]
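
Two things that might help with the clutter, sketched below on stand-in data (a binary label built from iris, not the wine data). First, getTree with labelVar = TRUE returns the tree as a plain data frame, which you can read or summarize without plotting. Second, you could fit a separate, deliberately small forest with maxnodes just to get a tree that is readable. Note that this small forest is a different model from the one you report results for, so treat its trees as an illustration of how a tree splits, not as a picture of your fitted forest. The names d, label, small_fit and tree1 are made up for illustration.

```r
library(randomForest)

# Stand-in data (assumption): binary factor built from iris
d <- iris
d$label <- factor(ifelse(d$Species == "virginica", 1, 0))
d$Species <- NULL

set.seed(8)
# maxnodes caps the number of terminal nodes per tree,
# so each tree stays small enough to read
small_fit <- randomForest(label ~ ., data = d, ntree = 50, maxnodes = 4)

# One row per node: daughters, split variable, split point, status, prediction
tree1 <- getTree(small_fit, k = 1, labelVar = TRUE)
print(tree1)

# status == -1 marks terminal nodes (leaves)
n_leaves <- sum(tree1$status == -1)

# How often each variable is used for a split in this tree (leaves are NA and dropped)
split_counts <- table(tree1$`split var`)
print(split_counts)
```

If tree_func from above is defined, tree_func(small_fit, 1) should then give a much less crowded picture than a full-depth tree from the original forest.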
