Solved – Explanation of the ‘free bits’ technique for variational autoencoders

approximate-inference, autoencoders, autoregressive, deep learning, information theory

I have been reading through a couple of papers on the variational autoencoder model: 'Variational Lossy Autoencoder' and 'Improving Variational Inference With Inverse Autoregressive Flow'. There is one (perhaps very obvious) thing that is confusing me – the former paper mentions the 'free bits' technique for training, and references the latter paper.

My question is – what is the 'free bits' technique?! As I understand, the IAF paper is concerned with allowing a more complex posterior distribution so as to better fit the true posterior, which in turn will improve the coding length of the VAE model, as detailed in the first paper. It is not clear to me what 'free bits' is referring to. Any help in understanding these papers better is appreciated!

Edit: references for the above papers:

  • Xi Chen, Diederik P. Kingma, Tim Salimans, Yan Duan, Prafulla Dhariwal, John Schulman, Ilya Sutskever, Pieter Abbeel, "Variational Lossy Autoencoder", https://arxiv.org/abs/1611.02731.
  • Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling, "Improving Variational Inference with Inverse Autoregressive Flow", https://arxiv.org/abs/1606.04934.

Best Answer

From what I can tell, and I'd love to be corrected as this seems quite interesting:

The 'IAF' paper contains the relevant description of the 'free bits' method, in particular around equation (15). It identifies the term as part of a modified objective function: "We then use the following objective, which ensures that using less than $\lambda$ nats of information per subset $j$ (on average per minibatch $M$) is not advantageous:"

$\tilde{L}_\lambda = E_{x \sim M}\, E_{q(z|x)}[\log p(x|z)] - \sum_{j=1}^{K} \max\left(\lambda,\ E_{x \sim M}\left[D_{KL}(q(z_j|x) \,\|\, p(z_j))\right]\right)$

The $E_{x \sim M}$ notation denotes, I believe, an expectation over the examples $x$ in a minibatch $M$; it relates to the stochastic gradient ascent approach and so is not of the essence to your question. Let's ignore it for now, leaving:

$E_{q(z|x)}[\log p(x|z)] - \sum_{j=1}^{K} \max\left(\lambda,\ D_{KL}(q(z_j|x) \,\|\, p(z_j))\right)$

They've split the latent variables into $K$ groups. As this seems to be icing on the cake, let's ignore it for the moment, leaving:

$E_{q(z|x)}[\log p(x|z)] - \max\left(\lambda,\ D_{KL}(q(z|x) \,\|\, p(z))\right)$

I think we're approaching solid ground here. If we dropped the maximum we would be back to the vanilla Evidence Lower Bound (ELBO) criterion used in variational Bayes methods.
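To make that concrete, here is a minimal sketch (my own, not from either paper) of the simplified objective above in PyTorch, assuming a diagonal-Gaussian $q(z|x)$ and a standard-normal prior; the function name and the default `free_bits_lambda` value are made up for illustration. Removing the clamp recovers the vanilla ELBO.

```python
import torch

def free_bits_elbo(log_px_given_z, mu, logvar, free_bits_lambda=0.5):
    """Per-example ELBO with a 'free bits' floor on the KL term.

    log_px_given_z : (batch,) reconstruction term log p(x|z)
    mu, logvar     : (batch, latent_dim) parameters of q(z|x) = N(mu, diag(exp(logvar)))
    The prior p(z) is assumed standard normal, so the KL has a closed form.
    """
    # Analytic KL(q(z|x) || N(0, I)), summed over latent dimensions -> (batch,)
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(dim=-1)
    # maximum(lambda, KL); dropping the clamp gives the vanilla ELBO
    kl_penalty = torch.clamp(kl, min=free_bits_lambda)
    return log_px_given_z - kl_penalty  # to be maximised (negate for a loss)
```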

If we look at the expression $D_{KL}(q(z|x) \,\|\, p(z))$, this is the extra message length required to encode the latent code $z$ if the prior $p$ is used instead of the variational posterior $q$. This means that if our variational posterior is close to the prior in KL divergence, the $\max$ takes the constant value $\lambda$: the objective is penalised more heavily than under a vanilla ELBO, but crucially the penalty no longer depends on $q$ at all.

In particular, if we have a solution where $D_{KL} < \lambda$, then we don't have to trade anything off in order to increase the model complexity a little (i.e. move $q$ further from $p$). I guess this is where the term 'free bits' comes from: increasing the model complexity for free, up to a certain point.
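A toy check of that claim (again my own sketch, with an arbitrary $\lambda = 1$ nat): when the KL is below the floor, the clamped penalty is constant and contributes no gradient, so nothing pulls $q$ back towards the prior.

```python
import torch

lambda_ = 1.0                                   # arbitrary free-bits floor (nats)
mu = torch.tensor(0.3, requires_grad=True)      # q = N(mu, 1), prior p = N(0, 1)
kl = 0.5 * mu.pow(2)                            # analytic KL, here ~0.045 < lambda_
penalty = torch.clamp(kl, min=lambda_)          # maximum(lambda, KL)
penalty.backward()
print(kl.item(), penalty.item(), mu.grad.item())  # gradient is 0: the bits are 'free'
```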

Bringing back the stuff we ignored: the summation over $K$ establishes a complexity freebie quota per group of latent variables. This could be useful in a model where one group of parameters would otherwise hoover up the entire quota.

[EDIT] For instance, suppose we had a (totally made up) model involving some 'filter weights' and some 'variances': if they shared the same complexity quota, perhaps after training we would find that the 'variances' were still very close to the prior because the 'filter weights' had used up the free bits. By splitting the variables into two groups, we might be able to ensure the 'variances' also used some free bits, i.e. got pushed away from the prior. [/EDIT]
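A sketch of that grouping (my own construction, assuming the latent dimensions are split into equal-sized contiguous groups; `group_size` and the function name are illustrative):

```python
import torch

def grouped_free_bits_kl(mu, logvar, group_size=4, free_bits_lambda=0.5):
    """KL penalty with a separate free-bits quota per group of latent dimensions.

    mu, logvar : (batch, latent_dim), with latent_dim divisible by group_size.
    Returns the clamped KL penalty per example, shape (batch,).
    """
    batch = mu.shape[0]
    # Analytic KL per latent dimension against a standard-normal prior
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0)        # (batch, D)
    kl_per_group = kl_per_dim.view(batch, -1, group_size).sum(dim=-1)   # (batch, K)
    # Each group gets its own floor, so one group cannot hoover up the whole quota
    return torch.clamp(kl_per_group, min=free_bits_lambda).sum(dim=-1)
```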

As for the expectation over $x$ in a minibatch: I'm not as familiar with that part, but from the quotation above the KL is averaged over the minibatch before the maximum is applied, so the quota constrains the average information use per minibatch rather than each example individually.
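Putting the minibatch expectation back in (still my own sketch, same assumptions as above): the per-group KL is averaged over the batch inside the maximum, which is how I read the quoted objective.

```python
import torch

def minibatch_free_bits_kl(mu, logvar, group_size=4, free_bits_lambda=0.5):
    """Free bits applied to the minibatch-averaged KL, per group of latent dims."""
    batch = mu.shape[0]
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0)        # (batch, D)
    kl_per_group = kl_per_dim.view(batch, -1, group_size).sum(dim=-1)   # (batch, K)
    # Average over the minibatch *before* taking the maximum, so the floor
    # constrains the average usage across examples rather than each example
    kl_avg = kl_per_group.mean(dim=0)                                   # (K,)
    return torch.clamp(kl_avg, min=free_bits_lambda).sum()
```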

[EDIT] Suppose we had a model where some of the latent variables $z$ were observation specific (think cluster indicators, matrix factors, random effects, etc.). If the quota were shared across the whole dataset, each observation would get a ration of something like $\lambda/N$ free bits, so as we got more data the ration would shrink. By making $\lambda$ minibatch specific we can fix the ration size, so that even as more data comes in the per-observation ration doesn't go to zero. [/EDIT]
