MCMC – Semi-Automating MCMC Convergence Diagnostics to Set Burn-In Length

bayesian, markov-chain-montecarlo, r

I would like to automate the choice of burn-in for an MCMC chain, e.g. by removing the first n rows based on a convergence diagnostic.

To what extent can this step be safely automated? Even if I still double-check the autocorrelation, MCMC trace, and PDFs, it would be nice to have the choice of burn-in length automated.

My question is general, but it would be great if you could provide specifics for dealing with an R mcmc.object; I am using the rjags and coda packages in R.

Best Answer

Here is one approach at the automation. Feedback much appreciated. This is an attempt to replace initial visual inspection with computation, followed by subsequent visual inspection, in keeping with standard practice.

This solution combines two steps: first, calculate the burn-in by discarding the part of the chain that precedes the point where a convergence threshold is reached; then, use the autocorrelation matrix to calculate the thinning interval.

  1. calculate a vector of the maximum median Gelman-Rubin convergence diagnostic shrink factor (grsf) across all variables in the chain
  2. find the minimum number of samples at which the grsf for all variables is below some threshold, e.g. 1.1 in the example, perhaps lower in practice
  3. subsample the chains from this point to the end of the chain
  4. thin the chain using the autocorrelation of the most autocorrelated chain
  5. visually confirm convergence with trace, autocorrelation, and density plots

The mcmc object can be downloaded here: jags.out.Rdata

# jags.out is the mcmc.object with m variables
library(coda)    
load('jags.out.Rdata')
# 1. calculate max.gd.vec, the maximum (across variables) median
#    shrink factor at each iteration bin
max.gd.vec     <- apply(gelman.plot(jags.out)$shrink[, ,'median'], 1, max)
# 2. find the first iteration at which max.gd.vec < 1.1, but no earlier than 100
# 3. use window() to subsample the jags.out mcmc.object from that point
window.start   <- max(100, min(as.numeric(names(which(max.gd.vec - 1.1 < 0)))))
jags.out.trunc <- window(jags.out, start = window.start)
# 4. calculate thinning interval
# thin.int is the chain thin interval
# this step is very slow
# 4.1 find n most autocorrelated variables
acm             <- autocorr.diag(jags.out.trunc)
n               <- min(3, ncol(acm))
# order() (not rank()) gives the column indices sorted by total autocorrelation
acm.subset      <- colnames(acm)[order(-colSums(acm))][1:n]
jags.out.subset <- jags.out.trunc[,acm.subset]
# 4.2 calculate the thinning interval
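# ac.int is the time step interval for autocorrelation matrix
ac.int          <- 500 #set high to reduce computation time
# assumption: acm2 is the autocorrelation of the most autocorrelated
# variables at lags 0, ac.int, 2 * ac.int, ...
acm2            <- autocorr.diag(jags.out.subset,
                                 lags = seq(0, niter(jags.out.subset) - 1, by = ac.int))
# first lag with negative autocorrelation (row 1 is lag 0), at least 50
thin.int        <- max((apply(acm2 < 0, 2, function(x) match(TRUE, x)) - 1) * ac.int,
                       50, na.rm = TRUE)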
# 4.3 thin the chain 
jags.out.thin   <- window(jags.out.trunc, thin = thin.int)
# 5. plots for visual diagnostics
plot(jags.out.thin)
autocorr.plot(jags.out.thin)
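
For readers without jags.out.Rdata, steps 1 to 3 can be tried on simulated chains using base R only. This is a minimal sketch, not the coda implementation: `psrf`, `first.below`, `sim.chain`, and the `draws` matrix are hypothetical names, and the shrink factor is the basic Gelman-Rubin ratio for a single parameter. coda's `gelman.diag` additionally applies a degrees-of-freedom correction and, by default, uses only the second half of each window, so its values will differ somewhat.

```r
# Basic Gelman-Rubin shrink factor for ONE parameter.
# `draws` is a hypothetical iterations-by-chains numeric matrix.
psrf <- function(draws) {
  n        <- nrow(draws)
  W        <- mean(apply(draws, 2, var))   # mean within-chain variance
  B.over.n <- var(colMeans(draws))         # between-chain variance / n
  sqrt(((n - 1) / n * W + B.over.n) / W)
}

# First checkpoint (every `step` iterations) at which the shrink factor
# computed on iterations 1..checkpoint is below `threshold`; NA if never.
first.below <- function(draws, threshold = 1.1, step = 50) {
  ends <- seq(step, nrow(draws), by = step)
  rhat <- sapply(ends, function(e) psrf(draws[1:e, , drop = FALSE]))
  ends[match(TRUE, rhat < threshold)]
}

# Simulated data: three AR(1) chains started from dispersed values,
# so each chain has an initial transient before it settles.
set.seed(1)
sim.chain <- function(start, n = 2000, rho = 0.9) {
  x    <- numeric(n)
  x[1] <- start
  for (i in 2:n) x[i] <- rho * x[i - 1] + rnorm(1)
  x
}
draws <- sapply(c(-20, 0, 20), sim.chain)

# Choose the burn-in and drop it from every chain
burn.in     <- first.below(draws)
draws.trunc <- draws[(burn.in + 1):nrow(draws), , drop = FALSE]
```

As in the coda version, the result should still be checked visually, e.g. with `matplot(draws.trunc, type = 'l')`.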

--update--

As implemented in R, computing the autocorrelation matrix is slower than desirable (>15 min in some cases); computing the GR shrink factor is also slow, though to a lesser extent. There is a question on Stack Overflow about how to speed up step 4.

--update part 2--

additional answers:

  1. It is not possible to diagnose convergence, only to diagnose lack of convergence (Brooks, Giudici, and Philippe, 2003)

  2. The function autorun.jags from the package runjags automates calculation of run length and convergence diagnostics. It does not start monitoring the chain until the Gelman-Rubin diagnostic is below 1.05; it calculates the chain length using the Raftery and Lewis diagnostic.

  3. Gelman et al. (Gelman 2004, Bayesian Data Analysis, p. 295; Gelman and Shirley, 2010) state that they use a conservative approach of discarding the first half of the chain. Although this is a relatively simple solution, in practice it is sufficient to solve the issue for my particular set of models and data.


#code for answer 3
chain.length <- summary(jags.out)$end
jags.out.trunc <- window(jags.out, start = chain.length / 2)
# thin based on autocorrelation if < 50, otherwise ignore
lags <- c(1, 5, 10, 15, 25)
acm <- autocorr.diag(jags.out.trunc, lags = lags)
# first lag at which a variable's autocorrelation is negative, capped at 50
thin.int <- min(apply(acm < 0, 2, function(x) lags[match(TRUE, x)]), 50, na.rm = TRUE)
# require visual inspection, check acceptance rate
if (thin.int == 50) stop('check acceptance rate, inspect diagnostic figures')
jags.out.thin <- window(jags.out.trunc, thin = thin.int)
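
The same "discard the first half, then thin" logic can be tried on a single simulated chain with base R, without coda or a saved mcmc object. This is a rough sketch under stated assumptions: `halve.and.thin` is a hypothetical helper, the chain is a plain numeric vector, and the thinning interval is taken as the first lag with negative sample autocorrelation, falling back to `max.thin` if no such lag is found.

```r
# Discard the first half of a chain, then thin it at the first lag whose
# sample autocorrelation is negative (capped at `max.thin`).
# `x` is a hypothetical numeric vector of draws for one parameter.
halve.and.thin <- function(x, max.thin = 50) {
  kept <- x[(floor(length(x) / 2) + 1):length(x)]
  # acf() returns lag 0 first; drop it so position k corresponds to lag k
  rho  <- acf(kept, lag.max = max.thin, plot = FALSE)$acf[-1]
  thin.int <- match(TRUE, rho < 0)
  if (is.na(thin.int)) thin.int <- max.thin   # never negative: use the cap
  list(thin = thin.int, draws = kept[seq(1, length(kept), by = thin.int)])
}

# Simulated AR(1) chain standing in for MCMC output
set.seed(42)
x   <- as.numeric(arima.sim(list(ar = 0.5), n = 4000))
res <- halve.and.thin(x)
```

If `res$thin` comes back equal to `max.thin`, that is the case where the code above stops and asks for visual inspection.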