Marginal likelihood in variational Bayes

Tags: bayes-theorem, bayesian, machine-learning, statistical-inference, statistics

In this paper, https://arxiv.org/abs/1312.6114, I have some questions regarding equations (1) and (2).

Equation (1) starts with:

The marginal likelihood is composed of a sum over the marginal likelihoods of individual datapoints, $\log p_\theta(x^{(1)}, \ldots, x^{(N)}) = \sum_{i=1}^N \log p_\theta(x^{(i)})$, which can be rewritten as…

$$
\log p_\theta(x^{(i)}) = D_{KL}(q_\phi(z|x^{(i)}) || p_\theta(z|x^{(i)}))+ \mathcal{L} (\theta, \phi; x^{(i)})
$$

How is this the same as the marginal likelihood? I've been looking at this equation for quite some time, and I can't reason through it like I can with the standard marginal likelihood.

The only thing I have been able to deduce from it (I think) is that the two terms on the RHS of the equation will be negatively correlated, because as the divergence gets smaller, the likelihood term goes up, and vice versa. Correct?

In the second equation, the author makes a statement about the likelihood and I also don't see how that came about…

$$
\log p_\theta(x^{(i)}) \ge \mathcal{L}(\theta, \phi; x^{(i)}) = \mathbb{E}_{q_\phi(z|x^{(i)})} \big[ -\log q_\phi(z|x) + \log p_\theta(x,z) \big]
$$

Why is this equal to the joint likelihood of $\theta$ and $\phi$? Why is the input to $p_\theta$ different (joint) from that of $q_\phi$, which is conditional?

Best Answer

This is standard material from variational inference, which is worth looking into in more depth, since it appears in more places than just VAEs (e.g., in variational Bayesian NNs). Luckily, we can derive it all fairly simply with some basic probability theory.

The idea is that the true posterior $p_\theta(z|x_i)$ is too difficult to compute directly (i.e., via Bayes rule), so we instead approximate it with $q_\phi(z|x_i)$. We then optimize $\phi$ instead of working with probabilities directly. Mathematically we have that
$$
p_\theta(z|x) = \frac{p_\theta(x|z) p_\theta(z)}{p_\theta(x)} = \frac{p_\theta(x|z) p_\theta(z)}{\int p_\theta(x|z) p_\theta(z)\, dz} \approx q_\phi(z|x).
$$
It turns out that the log-marginal likelihood can be expressed via
$$
\log p_\theta(x) = \mathcal{D}_\text{KL}\left[ q_\phi(z|x)\mid\mid p_\theta(z|x) \right] + \mathcal{L}(\theta,\phi\mid x), \tag{1}
$$
where $\mathcal{L}$ is the evidence lower bound (ELBO), which may be written
\begin{align}
\mathcal{L}(\theta,\phi\mid x) &= \mathbb{E}_{q_\phi(z|x)}\left[ -\log q_\phi(z|x) + \log p_\theta(x,z) \right] \tag{a} \\
&= \int q_\phi(z|x)\left[ -\log q_\phi(z|x) + \log p_\theta(x|z) + \log p_\theta(z) \right]dz \\
&= \int q_\phi(z|x) \log\left(\frac{p_\theta(z)}{q_\phi(z|x)}\right) dz + \int q_\phi(z|x) \log p_\theta(x|z)\, dz \\
&= -\mathcal{D}_\text{KL}\left[ q_\phi(z|x) \mid\mid p_\theta(z) \right] + \mathbb{E}_{q_\phi(z|x)}\left[ \log p_\theta(x|z) \right], \tag{b}
\end{align}
where (a) is the form in eq. (2) in the paper and (b) is the form in eq. (3) in the paper. We used the fact that
$$
\mathcal{D}_\text{KL}\left[ p_\xi(y) \mid\mid p_\eta(y) \right] = \int \log\left(\frac{p_\xi(y)}{p_\eta(y)}\right) p_\xi(y) \,dy = \mathbb{E}_{p_\xi(y)}\left[ \log\frac{p_\xi(y)}{p_\eta(y)} \right].
$$
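To see concretely that forms (a) and (b) agree, here is a quick numerical sanity check (my own toy example, not anything from the paper): a 1-D model with prior $z \sim \mathcal{N}(0,1)$, likelihood $x \mid z \sim \mathcal{N}(z, \sigma_x^2)$, and an arbitrary Gaussian approximation $q_\phi(z|x) = \mathcal{N}(m, s^2)$; the values of $x$, $\sigma_x$, $m$, $s$ below are made up.

```python
# Monte Carlo check that ELBO forms (a) and (b) agree, in a toy 1-D Gaussian model.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x, sigma_x = 1.3, 0.5     # observed datapoint and likelihood std (assumed values)
m, s = 0.8, 0.6           # arbitrary variational parameters of q(z|x) = N(m, s^2)

z = rng.normal(m, s, size=200_000)        # samples z ~ q(z|x)
log_q = norm.logpdf(z, m, s)              # log q(z|x)
log_prior = norm.logpdf(z, 0.0, 1.0)      # log p(z)
log_lik = norm.logpdf(x, z, sigma_x)      # log p(x|z)

# Form (a): E_q[ -log q(z|x) + log p(x,z) ]
elbo_a = np.mean(-log_q + log_lik + log_prior)

# Form (b): -KL[q(z|x) || p(z)] + E_q[log p(x|z)], with the analytic Gaussian KL
kl_q_prior = np.log(1.0 / s) + (s**2 + m**2 - 1.0) / 2.0
elbo_b = -kl_q_prior + np.mean(log_lik)

print(elbo_a, elbo_b)     # the two estimates agree up to Monte Carlo error
```

The closed-form Gaussian KL used here is the 1-D case of the expression derived in the paper's appendix B.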

Ok, but we still need to derive equation (1):
\begin{align}
\log p_\theta(x) &= \mathbb{E}_{q_\phi(z|x)}\left[ \log p_\theta(x) \right] \\
&= \mathbb{E}_{q_\phi(z|x)}\left[ \log\left( \frac{p_\theta(x,z)}{p_\theta(z|x)} \frac{q_\phi(z|x)}{q_\phi(z|x)} \right) \right] \\
&= \underbrace{ \mathbb{E}_{q_\phi(z|x)}\left[ \log\left( \frac{q_\phi(z|x)}{p_\theta(z|x)} \right) \right] }_{ \mathcal{D}_\text{KL}[q_\phi(z|x)\mid\mid p_\theta(z|x)]} + \underbrace{ \mathbb{E}_{q_\phi(z|x)}\left[ \log\left( \frac{p_\theta(x,z)}{q_\phi(z|x)} \right) \right]}_{\text{Exactly (a) above}} \\
&= \mathcal{D}_\text{KL}[q_\phi(z|x)\mid\mid p_\theta(z|x)] + \mathcal{L}(\theta,\phi\mid x),
\end{align}
where the first step (taking the expectation) used the fact that there is no dependence on $z$ in the marginal, and the second used the identity $p_\theta(x,z) = p_\theta(z|x)\, p_\theta(x)$. This gives us equation (1) from the paper.
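Because everything in the toy conjugate-Gaussian model above is available in closed form (the posterior, the marginal, and the ELBO), you can also verify equation (1) exactly rather than by Monte Carlo. Again, this is my own illustrative setup, not the paper's model:

```python
# Exact check of  log p(x) = KL[q(z|x) || p(z|x)] + ELBO  in a conjugate 1-D model:
#   z ~ N(0,1),  x|z ~ N(z, sigma_x^2),  q(z|x) = N(m, s^2) arbitrary.
import numpy as np
from scipy.stats import norm

def kl_gauss(m1, s1, m2, s2):
    """KL[ N(m1, s1^2) || N(m2, s2^2) ]."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5

x, sigma_x = 1.3, 0.5     # observed datapoint and likelihood std (assumed values)
m, s = 0.8, 0.6           # arbitrary variational parameters

# Exact posterior p(z|x) and marginal p(x) for this conjugate model.
post_mean = x / (1 + sigma_x**2)
post_std = sigma_x / np.sqrt(1 + sigma_x**2)
log_px = norm.logpdf(x, 0.0, np.sqrt(1 + sigma_x**2))

# ELBO in form (b): -KL[q || p(z)] + E_q[log p(x|z)], both in closed form.
expected_loglik = -0.5 * np.log(2 * np.pi * sigma_x**2) - ((x - m)**2 + s**2) / (2 * sigma_x**2)
elbo = -kl_gauss(m, s, 0.0, 1.0) + expected_loglik

kl_q_post = kl_gauss(m, s, post_mean, post_std)
print(log_px, kl_q_post + elbo)    # identical up to floating-point error
```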

Next we derive the "lower-bounding" part (equation 2 in the paper).

Firstly, though, I want to say that because the KL-divergence is non-negative, we must immediately have equation (2), $$ \log p_\theta(x) \geq \mathcal{L}(\theta,\phi\mid x), $$ trivially from equation (1).

However, it's more common to see people get to (2) as follows. Recall Jensen's inequality, which states that for a concave function $f$ (like $\log$), we get that $f(\mathbb{E}[y]) \geq \mathbb{E}[f(y)]$, so we see that
$$
\log p(x) = \log\int p(x,y) \frac{q(y)}{q(y)}\, dy = \log \mathbb{E}_{q(y)}\left[ \frac{p(x,y)}{q(y)} \right] \geq \mathbb{E}_{q(y)}\left[ \log\left( \frac{p(x,y)}{q(y)} \right) \right]. \tag{JI}
$$
Using this, we get
\begin{align}
\log p_\theta(x) &= \log \int p_\theta(x,z)\, dz \\
&= \log\int q_\phi(z|x)\frac{p_\theta(x,z)}{q_\phi(z|x)}\, dz \\
&= \log \mathbb{E}_{q_\phi(z|x)}\left[\frac{p_\theta(x,z)}{q_\phi(z|x)}\right] \\
&\geq \mathbb{E}_{q_\phi(z|x)}\left[ \log \left(\frac{p_\theta(x,z)}{q_\phi(z|x)}\right)\right] && \text{from (JI)} \\
&= \mathbb{E}_{q_\phi(z|x)}\left[ \log p_\theta(x,z) - \log q_\phi(z|x) \right],
\end{align}
which is exactly equation (a).
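One more remark on the Jensen step (again using my toy Gaussian model and made-up numbers): the quantity inside the log, $\mathbb{E}_{q_\phi(z|x)}[p_\theta(x,z)/q_\phi(z|x)]$, is an importance-sampling estimate of $p_\theta(x)$, whereas pushing the log inside gives the (smaller) ELBO. You can see the gap directly:

```python
# Jensen's inequality in action: log of the mean weight vs. mean of the log weight.
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x, sigma_x = 1.3, 0.5     # assumed values, as before
m, s = 0.8, 0.6

z = rng.normal(m, s, size=200_000)            # z ~ q(z|x)
log_w = (norm.logpdf(z, 0.0, 1.0)             # log p(z)
         + norm.logpdf(x, z, sigma_x)         # + log p(x|z)
         - norm.logpdf(z, m, s))              # - log q(z|x)

log_px_est = np.log(np.mean(np.exp(log_w)))   # log E_q[p(x,z)/q]  ~  log p(x)
elbo_est = np.mean(log_w)                     # E_q[log(p(x,z)/q)] =  ELBO
print(log_px_est, elbo_est)                   # the first is larger, as (JI) predicts
```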


As to your comments:

How is this the same as the marginal likelihood? I've been looking at this equation for quite some time, and I can't reason through it like I can with the standard marginal likelihood.

As noted in the derivation, it can be interpreted as approximating the true posterior with a variational distribution. The reasoning is then that we decompose the log-marginal likelihood into two terms: one (the ELBO) optimizes the likelihood using the approximate distribution, and the other (the KL term) forces the approximation to match the true posterior.

The only thing I have been able to deduce from it (I think) is that the two terms on the RHS of the equation will be negatively correlated, because as the divergence gets smaller, the likelihood term goes up, and vice versa. Correct?

For a fixed $\theta$ and a given datapoint $x^{(i)}$, that intuition is essentially right: the left-hand side $\log p_\theta(x^{(i)})$ does not depend on $\phi$, so the two terms on the RHS sum to a constant and trade off exactly as you vary $q_\phi$. Once $\theta$ is being optimized at the same time, however, both terms can move in the same direction, so the negative correlation need not hold over the course of training.

[For the second equation,] Why is this equal to the joint likelihood of $\theta$ and $\phi$? Why is the input to $p_\theta$ different (joint) from that of $q_\phi$, which is conditional?

Well, perhaps the derivation cleared this up for you, but note that $\mathcal{L}(\theta,\phi\mid x^{(i)})$ is not a joint likelihood of $\theta$ and $\phi$; the notation simply records that the bound is a function of both the generative parameters $\theta$ and the variational parameters $\phi$. As for the arguments: $x$ is fixed beforehand on both sides of the equation, whereas the expectation on the RHS is taken over $z\sim q_\phi(z|x)$, so both $x$ and $z$ are well-defined ($p_\theta$ takes the joint $(x,z)$, while $q_\phi$ is a conditional over $z$ given the fixed $x$).


One thing to notice is that since the KL divergence is non-negative, the ELBO $\mathcal{L}$ is always a lower bound on the log-marginal likelihood, so maximizing it pushes up a lower bound on the true log-marginal likelihood. Another thing to notice is that when the variational approximate posterior $q_\phi$ is perfect (meaning it matches the true posterior exactly, so the KL term vanishes), we get $\log p_\theta(x) = \mathcal{L}(\theta,\phi\mid x)$, i.e., the ELBO is exactly the log-marginal likelihood. In other words, we are simultaneously (1) optimizing $q_\phi$ to match the true posterior and (2) optimizing the marginal likelihood. It happens that as you do (1) (improving $q_\phi$), (2) also becomes easier (as the ELBO $\mathcal{L}$ approaches the true log-marginal likelihood)!
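To make the tightness point concrete with the same toy conjugate model as above (my own example, not the paper's): if $q_\phi$ is set equal to the exact posterior, the KL term is zero and the ELBO coincides with $\log p_\theta(x)$.

```python
# When q(z|x) equals the exact posterior, the ELBO equals log p(x) exactly.
import numpy as np
from scipy.stats import norm

def kl_gauss(m1, s1, m2, s2):
    """KL[ N(m1, s1^2) || N(m2, s2^2) ]."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5

x, sigma_x = 1.3, 0.5                        # assumed values, as before
m = x / (1 + sigma_x**2)                     # q mean = exact posterior mean
s = sigma_x / np.sqrt(1 + sigma_x**2)        # q std  = exact posterior std

log_px = norm.logpdf(x, 0.0, np.sqrt(1 + sigma_x**2))
expected_loglik = -0.5 * np.log(2 * np.pi * sigma_x**2) - ((x - m)**2 + s**2) / (2 * sigma_x**2)
elbo = -kl_gauss(m, s, 0.0, 1.0) + expected_loglik
print(log_px, elbo)                          # equal: the bound is tight
```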

The even more common interpretation comes from equation (b), which states that we should choose a latent prior distribution $p_\theta(z)$, and force our approximate posterior to match it, while also optimizing the conditional likelihood (usually some kind of reconstruction error). Hence we get the VAE as a kind of regularized stochastic autoencoder.
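For what it's worth, interpretation (b) is also the way the objective is usually written in code. Below is a minimal PyTorch sketch of a per-batch negative-ELBO loss with a Bernoulli decoder and the analytic Gaussian KL term; the architecture and layer sizes are placeholders of my own choosing rather than the paper's exact setup, and the reconstruction term uses a single reparameterized sample.

```python
# Minimal VAE loss sketch: negative ELBO = reconstruction term + KL regularizer.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyVAE(nn.Module):
    def __init__(self, x_dim=784, h_dim=400, z_dim=20):   # hypothetical sizes
        super().__init__()
        self.enc = nn.Linear(x_dim, h_dim)
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.Tanh(),
                                 nn.Linear(h_dim, x_dim))

    def forward(self, x):
        # Encoder: q_phi(z|x) = N(mu, diag(sigma^2))
        h = torch.tanh(self.enc(x))
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        # Decoder: Bernoulli p_theta(x|z); one-sample estimate of E_q[log p(x|z)]
        rec_loglik = -F.binary_cross_entropy_with_logits(self.dec(z), x, reduction='sum')
        # Analytic KL[ N(mu, sigma^2) || N(0, I) ], summed over the batch
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return kl - rec_loglik          # negative ELBO, to be minimized
```

Minimizing this returned value over $\theta$ (decoder weights) and $\phi$ (encoder weights) is exactly maximizing form (b) of the ELBO, with the KL term acting as the regularizer described above.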
