Lasso regression on top of random forest

lasso, random forest

Some time ago, I found a paper describing the use of lasso/elastic net regression on binary variables derived from a random forest. In short, the (i,j)-th variable takes the value 1 if a given observation belongs to leaf i of tree j, and 0 otherwise. Because this results in a huge number of variables, the authors suggested using lasso regression to obtain a sparse and interpretable solution.

Does anybody know whether such a procedure is implemented in R or Python, or can at least post a link to this article (or a similar one)? Thanks in advance.
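The construction described above can be sketched in R with the randomForest, glmnet and Matrix packages. This is a minimal sketch under stated assumptions, not the paper's implementation: predict() on a randomForest object with nodes = TRUE returns the terminal node of every observation in every tree (in the "nodes" attribute), the hypothetical helper leaf_matrix() turns that into one 0/1 column per (leaf, tree) pair, and cv.glmnet() fits the lasso on those columns. The simulated data and the choices ntree = 100 and nodesize = 10 are arbitrary.

```r
library(randomForest)  # randomForest(); predict(..., nodes = TRUE)
library(glmnet)        # cv.glmnet()
library(Matrix)        # sparseMatrix()

# Hypothetical helper (not from any package): turn an n x ntree matrix of
# terminal-node ids into a sparse 0/1 matrix. Column (j - 1) * width + k
# is 1 when the observation falls in node k of tree j.
leaf_matrix <- function(nodes, width) {
  n <- nrow(nodes)
  ntree <- ncol(nodes)
  # as.vector() is column-major: all rows of tree 1, then tree 2, ...
  cols <- (rep(seq_len(ntree), each = n) - 1) * width + as.vector(nodes)
  sparseMatrix(i = rep(seq_len(n), times = ntree), j = cols, x = 1,
               dims = c(n, ntree * width))
}

# Simulated data (made up for illustration)
set.seed(1)
n <- 300
x <- matrix(runif(n * 5), n, 5, dimnames = list(NULL, paste0("x", 1:5)))
y <- 10 * (x[, 1] > 0.5) + 5 * x[, 2] + rnorm(n)
train <- sample(n, 200)

# 1. Grow the forest on the training data only
rf <- randomForest(x[train, ], y[train], ntree = 100, nodesize = 10)

# 2. Terminal node reached by every observation in every tree
nodes_tr <- attr(predict(rf, x[train, ], nodes = TRUE), "nodes")
nodes_te <- attr(predict(rf, x[-train, ], nodes = TRUE), "nodes")

# 3. Leaf indicators; node ids never exceed rf$forest$nrnodes
width <- rf$forest$nrnodes
Z_tr <- leaf_matrix(nodes_tr, width)
Z_te <- leaf_matrix(nodes_te, width)
keep <- which(colSums(Z_tr) > 0)   # drop internal / unused node ids
Z_tr <- Z_tr[, keep]
Z_te <- Z_te[, keep]

# 4. Lasso on the leaf indicators (0 < alpha < 1 would give elastic net)
cv_fit <- cv.glmnet(Z_tr, y[train], alpha = 1)
pred <- predict(cv_fit, newx = Z_te, s = "lambda.min")
mse <- mean((pred - y[-train])^2)
n_leaves <- sum(as.matrix(coef(cv_fit, s = "lambda.min"))[-1] != 0)
c(test_mse = mse, selected_leaves = n_leaves)
```

Leaves with a non-zero coefficient are the "selected" variables; each one corresponds to a path of splits in its tree, which can be traced with getTree(). Setting alpha between 0 and 1 gives the elastic net variant mentioned in the question.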

Best Answer

First, I think it is hard to say that one model "outperforms" another in general. Each model has its own pros and cons and suits different cases. For example, I would not say that random forest outperforms linear regression, because linear regression (1) is more "stable", (2) requires less computational power, and (3) is more interpretable. Moreover, if the true relationship between the features and the response really is linear, a correctly specified linear model is very hard to beat.
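A quick simulation illustrates the last point. This is a sketch on made-up data with an assumed linear truth (y = 1 + 2 x1 - 3 x2 + 0.5 x3 plus standard normal noise); it compares the test MSE of lm() and a default randomForest() on one random split, so it illustrates the claim rather than proves it.

```r
library(randomForest)

# Simulated data with a truly linear relationship (made up for illustration)
set.seed(42)
n <- 400
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
dat$y <- 1 + 2 * dat$x1 - 3 * dat$x2 + 0.5 * dat$x3 + rnorm(n)
train <- 1:200
test <- dat[-train, ]

# Fit both models on the same training half
lm_fit <- lm(y ~ ., data = dat[train, ])
rf_fit <- randomForest(y ~ ., data = dat[train, ])

# Held-out mean squared error of each model
mse_lm <- mean((predict(lm_fit, test) - test$y)^2)
mse_rf <- mean((predict(rf_fit, test) - test$y)^2)
c(lm = mse_lm, rf = mse_rf)
```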

Now, back to your question: here is code to try both approaches.

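You can easily run the experiment both ways and compare the performance. The trick is to use model.matrix in R. Here is an example from the ISL book (An Introduction to Statistical Learning) that uses model.matrix to convert factors into a numeric design matrix and then fits the lasso; setting alpha=0 instead of alpha=1 in glmnet gives ridge regression.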

# Chapter 6 Lab 2 of ISL book: Ridge Regression and the Lasso
library(ISLR)
library(glmnet)
Hitters=na.omit(Hitters)

# convert the formula/data frame into the numeric matrix glmnet expects (drop the intercept column)
x=model.matrix(Salary~.,Hitters)[,-1]
y=Hitters$Salary

set.seed(1)
train=sample(1:nrow(x), nrow(x)/2)
test=(-train)
y.test=y[test]

grid=10^seq(10,-2,length=100)

# The Lasso
lasso.mod=glmnet(x[train,],y[train],alpha=1,lambda=grid)
plot(lasso.mod)
set.seed(1)
cv.out=cv.glmnet(x[train,],y[train],alpha=1)
plot(cv.out)

# use the cross-validated lambda on the test set, then refit on all data
bestlam=cv.out$lambda.min
lasso.pred=predict(lasso.mod,s=bestlam,newx=x[test,])
mean((lasso.pred-y.test)^2)
out=glmnet(x,y,alpha=1,lambda=grid)
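To see what model.matrix does with factors, here is a toy example (the data frame is made up). Under R's default treatment contrasts, a factor with k levels becomes k - 1 dummy columns, and [, -1] drops the "(Intercept)" column because glmnet fits its own intercept and takes a numeric matrix rather than a formula.

```r
# Toy data (made up): one factor and one numeric predictor
df <- data.frame(y = c(1.2, 3.4, 2.2, 5.0),
                 league = factor(c("A", "N", "N", "A")),
                 hits = c(10, 25, 17, 40))

# Factor "league" becomes a 0/1 dummy column "leagueN"; "A" is the baseline
X <- model.matrix(y ~ ., df)[, -1]   # drop the "(Intercept)" column
print(X)
```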

On the other hand, fitting a random forest is just as easy:

library(randomForest)
rf.mod=randomForest(Salary~.,data=Hitters)
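To actually compare the two approaches, the forest can be scored on a held-out half of Hitters, just like the lasso. A sketch, assuming the ISLR and randomForest packages are installed: calling set.seed(1) and then the same sample() call as in the lasso code should reproduce its train/test split (on the same R version), so rf.mse can be compared with the lasso test MSE computed there.

```r
library(ISLR)          # Hitters data
library(randomForest)  # randomForest()

Hitters <- na.omit(Hitters)

# Same split recipe as the lasso code
set.seed(1)
train <- sample(1:nrow(Hitters), nrow(Hitters) / 2)

# Fit on the training half, evaluate on the other half
rf.mod <- randomForest(Salary ~ ., data = Hitters[train, ])
rf.pred <- predict(rf.mod, newdata = Hitters[-train, ])
rf.mse <- mean((rf.pred - Hitters$Salary[-train])^2)
rf.mse
```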
