Solved – Visualising a linear model with 6 predictors in R

data visualization, r, regression

I've recently been wondering how to visualise a linear model with >1 predictor. Today I came across a paper that managed to visualise a linear model with 6 predictors (and one response variable) in a scatterplot.

The model was similar to this:

weight ~ days since 1st Sept + site + age + year + PC1 + arrival date

Weight was a continuous response variable. Days since 1st Sept was an integer predictor, site was a factor predictor (2 levels), age was a factor predictor (2 levels), year was a factor predictor (2 levels), PC1 was a continuous predictor and arrival date was an integer predictor.

The y axis of the scatterplot showed the values of weight predicted by the model, and the x axis showed days since 1st Sept. Eight lines of best fit showed the predicted values for each combination of levels of the three factor predictors (2 x 2 x 2 = 8). PC1 and arrival date were 'held constant at their mean values'.

The plot looked like this (made using Photoshop):

[Image not reproduced: scatterplot of predicted weight against days since 1st Sept, with eight fitted lines]

Could someone recreate a similar plot using a dataset that ships with R? I'm particularly interested in how to keep two predictors 'constant at their mean values' so that the other predictors can be visualised. The mtcars dataset might be appropriate.

Best Answer

Here is some code that is hopefully self-explanatory:

set.seed(20987)     # for reproducibility

N = 200

  # variables
days_since   = rpois(N, lambda=60)
site         = factor(sample(c("site1", "site2"), N, replace=T), c("site1", "site2"))
age          = factor(sample(c("juv", "adult"),   N, replace=T), c("juv", "adult"))
year         = factor(sample(c("2012", "2013"),   N, replace=T), c("2012", "2013"))
PC1          = rnorm(N, mean=100, sd=25)
arrival_date = sample.int(365, N, replace=T)

  # betas
B0  =  13
Bds =  74
Bs  = 114
Ba  = 160
By  = 191
Bpc =  59
Bad =  11

  # response variable
weight = B0 + Bds*days_since + Bs*(site=="site2") + Ba*(age=="adult") + 
         By*(year=="2013") + Bpc*PC1 + Bad*arrival_date + rnorm(N, mean=0, sd=10)

model = lm(weight~days_since+site+age+year+PC1+arrival_date)

  # predicted values for plot
ds    = seq(min(days_since), max(days_since))
ds1j2 = predict(model, data.frame(days_since=ds, site="site1", age="juv",   
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1j3 = predict(model, data.frame(days_since=ds, site="site1", age="juv",   
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1a2 = predict(model, data.frame(days_since=ds, site="site1", age="adult", 
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1a3 = predict(model, data.frame(days_since=ds, site="site1", age="adult", 
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2j2 = predict(model, data.frame(days_since=ds, site="site2", age="juv",   
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2j3 = predict(model, data.frame(days_since=ds, site="site2", age="juv",   
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2a2 = predict(model, data.frame(days_since=ds, site="site2", age="adult", 
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2a3 = predict(model, data.frame(days_since=ds, site="site2", age="adult", 
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))

  # plot
dev.new()   # opens a new graphics device on any platform (windows() is Windows-only)
  plot(x=ds, y=ds1j2, ylim=c(11000, 14500), type="l", lty=1,
       ylab="predicted weight", xlab="days since 1st Sept")
                                points(range(ds), range(ds1j2), pch=5)
  lines(x=ds, y=ds1j3, lty=1);  points(range(ds), range(ds1j3), pch=18)
  lines(x=ds, y=ds1a2, lty=2);  points(range(ds), range(ds1a2), pch=5)
  lines(x=ds, y=ds1a3, lty=2);  points(range(ds), range(ds1a3), pch=18)
  lines(x=ds, y=ds2j2, lty=1);  points(range(ds), range(ds2j2), pch=1)
  lines(x=ds, y=ds2j3, lty=1);  points(range(ds), range(ds2j3), pch=16)
  lines(x=ds, y=ds2a2, lty=2);  points(range(ds), range(ds2a2), pch=1)
  lines(x=ds, y=ds2a3, lty=2);  points(range(ds), range(ds2a3), pch=16)

  # lty follows the legend order: juveniles solid (1), adults dashed (2)
  legend("bottomright", lty=rep(c(1,1,2,2), 2), pch=c(5,18,5,18,1,16,1,16), 
         legend=c("site 1, juveniles, 2012", "site 1, juveniles, 2013", 
                  "site 1, adults,    2012", "site 1, adults,    2013", 
                  "site 2, juveniles, 2012", "site 2, juveniles, 2013", 
                  "site 2, adults,    2012", "site 2, adults,    2013")) 

You can write much shorter code by wrapping the repeated predict() and lines() calls in a function or a loop over the factor combinations, rather than copying and pasting the same thing eight times, but the version above should be easier to follow. Here is the plot:

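[Image not reproduced: predicted weight against days since 1st Sept, eight parallel lines, one per combination of site, age and year]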
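Since the question mentions mtcars, here is a rough sketch of the same idea on that dataset, written the shorter way. The choice of variables (mpg as response, wt on the x axis, am and vs as factors, hp and qsec held at their means) is only my assumption for illustration, not anything from the paper. The key step is expand.grid(), which builds every combination of the x values and factor levels while the remaining predictors are fixed at their sample means.

```r
# Illustrative sketch on the built-in mtcars data; variable choice is an assumption.
dat <- transform(mtcars,
                 am = factor(am, labels = c("automatic", "manual")),
                 vs = factor(vs, labels = c("V-shaped", "straight")))

fit <- lm(mpg ~ wt + am + vs + hp + qsec, data = dat)

# Every combination of the x-axis variable and the factor levels,
# with the other two predictors held constant at their means.
grid <- expand.grid(wt   = seq(min(dat$wt), max(dat$wt), length.out = 50),
                    am   = levels(dat$am),
                    vs   = levels(dat$vs),
                    hp   = mean(dat$hp),
                    qsec = mean(dat$qsec))
grid$pred <- predict(fit, newdata = grid)

# One data frame, and therefore one line, per factor combination
groups <- split(grid, list(grid$am, grid$vs))

plot(range(grid$wt), range(grid$pred), type = "n",
     xlab = "weight (1000 lbs)", ylab = "predicted mpg")
for (i in seq_along(groups)) {
  lines(groups[[i]]$wt, groups[[i]]$pred, lty = i)
}
legend("topright", lty = seq_along(groups), legend = names(groups))
```

Adding another factor only means adding one more argument to expand.grid(); the plotting loop stays the same.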

This kind of plot is more interesting / useful when there are interactions (the lines aren't parallel). In this case, we just have a set of eight lines that are shifted vertically relative to each other.
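To see what non-parallel lines look like, here is a minimal sketch with an interaction term, again on mtcars and again with variables I picked purely for illustration. In the formula, wt * am expands to wt + am + wt:am, so each transmission type gets its own slope; hp is held at its mean as before.

```r
# Illustrative sketch on the built-in mtcars data; variable choice is an assumption.
dat <- transform(mtcars, am = factor(am, labels = c("automatic", "manual")))

# wt * am is shorthand for wt + am + wt:am
fit_int <- lm(mpg ~ wt * am + hp, data = dat)

# Two x values per line are enough, because each fitted line is straight
grid <- expand.grid(wt = range(dat$wt),
                    am = levels(dat$am),
                    hp = mean(dat$hp))
grid$pred <- predict(fit_int, newdata = grid)

# Slope of each line: change in prediction divided by change in wt
slopes <- tapply(grid$pred, grid$am, diff) / diff(range(dat$wt))
print(slopes)

plot(range(grid$wt), range(grid$pred), type = "n",
     xlab = "weight (1000 lbs)", ylab = "predicted mpg")
for (i in seq_along(levels(grid$am))) {
  sub <- grid[grid$am == levels(grid$am)[i], ]
  lines(sub$wt, sub$pred, lty = i)
}
legend("topright", lty = 1:2, legend = levels(grid$am))
```

The slope for automatic cars equals the wt coefficient, and the slope for manual cars equals the wt coefficient plus the wt:ammanual coefficient, which is why the two lines are no longer parallel.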
