Solved – the relationship between VAEs and the EM algorithm

bayesian, expectation-maximization, inference, neural-networks, variational-bayes

What's the relationship between Variational Autoencoders and the Expectation Maximization Algorithm?

I know that the EM algorithm is used in latent variable models, specifically to do maximum likelihood estimation iteratively. Similarly, the VAE can be used for latent variable models and, although they are usually used for generative modelling or posterior inference, they can also be used for parameter inference. So I was wondering what's the relationship between them and when it's better to use one or the other.

Best Answer

What is the relationship between VAE and EM?

$\newcommand{\vect}[1]{\boldsymbol{\mathbf{#1}}} \newcommand{\vx}{\vect{x}} \newcommand{\vz}{\vect{z}} \newcommand{\vtheta}{\vect{\theta}} \newcommand{\Ebb}{\mathbb{E}} \newcommand{\vphi}{\vect{\phi}} \newcommand{\Lcal}{\mathcal{L}} \newcommand{\elbo}{\Lcal_{\vtheta, \vphi}(\vx)} \newcommand{\felbo}{\Lcal_{\vx}(\vtheta, q_{\vphi})}$

This answer is only a partial summary; I've written a blog post about this that goes into the nitty-gritty details!

Notation

Observed data: $\mathcal{D} = \{\vx_1, \vx_2, \ldots, \vx_N\}$

Latent variables denoted by $\vz$.

Expectation Maximization Algorithm (Standard Version)

The EM algorithm is often (e.g. see Wikipedia) described as follows.

Start with a guess $\vtheta^{(0)}$, then until convergence:

  • Compute expectations $\Ebb_{p(\vz \mid \vx, \vtheta^{(t)})}[\log p_{\vtheta}(\vx, \vz)]$ for every data point $\vx\in \mathcal{D}$.
  • Choose parameter value $\vtheta^{(t+1)}$ to maximize expectations $$ \vtheta^{(t+1)} = \arg\max_{\vtheta} \sum_{\vx\in\mathcal{D}}\Ebb_{p(\vz \mid \vx, \vtheta^{(t)})}[\log p_{\vtheta}(\vx, \vz)] $$
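
For concreteness, here is a minimal NumPy sketch of these two steps for a two-component Gaussian mixture, a model where both steps are available in closed form. The model, data, and variable names are my own illustrative choices, not something from the original answer or the blog post: the E-step computes the posterior responsibilities $p(\vz\mid \vx, \vtheta^{(t)})$ needed to form the expectation, and the M-step maximizes the summed expectations in closed form.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: two 1-D Gaussian clusters.
x = np.concatenate([rng.normal(-2.0, 1.0, 300), rng.normal(3.0, 1.0, 700)])

# Initial guess theta^(0) = (mixing weights, means, variances).
pi, mu, var = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])

def log_normal(v, mean, variance):
    return -0.5 * (np.log(2 * np.pi * variance) + (v - mean) ** 2 / variance)

for _ in range(100):
    # E-step: posterior p(z | x, theta^(t)) for every data point (the "responsibilities"
    # that define the expected complete-data log-likelihood).
    log_r = np.log(pi) + log_normal(x[:, None], mu, var)        # shape (N, 2)
    r = np.exp(log_r - log_r.max(axis=1, keepdims=True))
    r /= r.sum(axis=1, keepdims=True)

    # M-step: theta^(t+1) maximizes the summed expectations, in closed form for a GMM.
    nk = r.sum(axis=0)
    pi = nk / len(x)
    mu = (r * x[:, None]).sum(axis=0) / nk
    var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / nk

print(pi, mu, var)   # should approach the true mixture parameters
```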

Expectation-Maximization Algorithm (Rewritten)

One can rewrite the algorithm above in a slightly different way. Rather than computing expectations in the first step, we compute the posterior distributions $p(\vz\mid \vx, \vtheta^{(t)})$ themselves. The EM algorithm then looks as follows:

Start with a guess $\vtheta^{(0)}$, until convergence:

  • Compute distributions $\left\{p(\vz\mid \vx, \vtheta^{(t)}) \, : \, \vx \in \mathcal{D}\right\}$
  • Choose a new parameter value in the same way as before $$ \vtheta^{(t+1)} = \arg\max_{\vtheta} \sum_{\vx\in\mathcal{D}}\Ebb_{p(\vz \mid \vx, \vtheta^{(t)})}[\log p_{\vtheta}(\vx, \vz)] $$

Variational Autoencoders

Why did I rewrite it like that? Because one can write the ELBO, which is usually considered as a function of $\vx$ parametrized by $\vtheta$ and $\vphi$ ($\vphi$ being the parameters of the encoder $q_{\vphi}$), as a functional of $q_{\vphi}$ and a function of $\vtheta$, parameterized by $\vx$ (indeed, the data is fixed). This means the ELBO can be written as:

\begin{equation*}
\mathcal{L}_{\vx}(\vtheta, q_{\vphi}) =
\begin{cases}
\displaystyle \log p_{\vtheta}(\vx) - \text{KL}(q_{\vphi}\,\,||\,\,p_{\vtheta}(\vz\mid \vx)) \qquad \qquad &(1)\\
\qquad \\
\displaystyle \Ebb_{q_{\vphi}}[\log p_{\vtheta}(\vx, \vz)] - \Ebb_{q_{\vphi}}[\log q_{\vphi}(\vz)] \qquad \qquad &(2)
\end{cases}
\end{equation*}
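
As a numerical sanity check (my own illustrative addition, not part of the original answer), here is a minimal NumPy sketch for a toy conjugate model $p(\vz) = N(0, 1)$, $p_{\vtheta}(\vx\mid\vz) = N(\theta z, 1)$, where both forms can be evaluated: form $(1)$ in closed form, and form $(2)$ by Monte Carlo. The values of `theta`, `x`, `m`, and `s` are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: p(z) = N(0, 1), p_theta(x | z) = N(theta * z, 1).
theta, x = 2.0, 1.3            # fixed decoder parameter and one observation
m, s = 0.4, 0.8                # variational q_phi(z) = N(m, s^2), chosen arbitrarily

def log_normal(v, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

# Form (1): log p_theta(x) - KL(q_phi || p_theta(z | x)), both available in closed form here.
marg_var = theta ** 2 + 1.0    # marginal: x ~ N(0, theta^2 + 1)
post_mean = theta * x / marg_var   # posterior: p(z | x) = N(theta x / (theta^2+1), 1/(theta^2+1))
post_var = 1.0 / marg_var
kl = np.log(np.sqrt(post_var) / s) + (s ** 2 + (m - post_mean) ** 2) / (2 * post_var) - 0.5
form1 = log_normal(x, 0.0, marg_var) - kl

# Form (2): E_q[log p_theta(x, z)] - E_q[log q_phi(z)], estimated by Monte Carlo.
z = rng.normal(m, s, size=100_000)
form2 = np.mean(log_normal(z, 0.0, 1.0) + log_normal(x, theta * z, 1.0) - log_normal(z, m, s ** 2))

print(form1, form2)   # the two numbers should agree up to Monte Carlo error
```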

Now we can recover exactly the two steps of the EM algorithm by maximizing the ELBO with respect to $q_{\vphi}$ first, and then with respect to $\vtheta$:

  • E-step: Maximize $(1)$ with respect to $q_{\vphi}$ (this makes the KL divergence zero and the bound is tight) $$ \left\{p_{\vtheta^{(t)}}(\vz\mid \vx) = \arg\max_{q_{\vphi}} \mathcal{L}_{\vx}(\vtheta^{(t)}, q_{\vphi})\,\, : \,\, \vx\in\mathcal{D}\right\} $$
  • M-step: Maximize $(2)$ with respect to $\vtheta$ $$ \vtheta^{(t+1)} = \arg\max_{\vtheta} \sum_{\vx\in\mathcal{D}} \mathcal{L}_{\vx}(\vtheta, p_{\vtheta^{(t)}}(\vz\mid \vx)) $$
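
Continuing the same toy model as in the previous sketch, now with a dataset (again my own illustrative example rather than anything from the original answer), the two steps above look like this in code: the E-step sets each $q$ to the exact posterior, so the KL term vanishes and the bound is tight, and the M-step has a closed-form update for $\vtheta$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Data generated from the toy model p(z) = N(0,1), p_theta(x|z) = N(theta z, 1) with theta = 2.
true_theta = 2.0
z_true = rng.normal(size=2000)
x = true_theta * z_true + rng.normal(size=2000)

theta = 0.5                                   # initial guess theta^(0)
for _ in range(50):
    # E-step: q_x = p(z | x, theta^(t)) = N(theta x / (theta^2+1), 1/(theta^2+1)) for every x;
    # with this choice the KL term in form (1) is zero and the bound is tight.
    post_var = 1.0 / (theta ** 2 + 1.0)
    post_mean = theta * x * post_var

    # M-step: maximize sum_x E_q[log p_theta(x, z)] over theta; here the argmax is closed form:
    # d/dtheta sum_x E_q[-(x - theta z)^2 / 2] = 0  =>  theta = sum_x x E[z] / sum_x E[z^2].
    theta = np.sum(x * post_mean) / np.sum(post_mean ** 2 + post_var)

# Compare with the exact MLE obtained from the marginal x ~ N(0, theta^2 + 1).
print(theta, np.sqrt(max(np.mean(x ** 2) - 1.0, 0.0)))
```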

The relationship between the Expectation Maximization algorithm and Variational Auto-Encoders can therefore be summarized as follows:

  • The EM algorithm and the VAE optimize the same objective function.

  • When the expectations are available in closed form, one should use the EM algorithm, which performs coordinate ascent on this objective.

  • When the expectations are intractable, the VAE performs stochastic gradient ascent on an unbiased estimator of the objective function (the ELBO).
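
To make the last point concrete, here is a minimal sketch of the VAE-style alternative for the same toy model, once more my own illustration: instead of closed-form coordinate updates, it takes stochastic gradient steps on a single-sample reparameterized estimate of form $(2)$. The gradients are derived by hand for this Gaussian toy model and the variational parameters are kept per data point; a real VAE would instead use an amortized encoder/decoder network and automatic differentiation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same toy model as above: p(z) = N(0,1), p_theta(x|z) = N(theta z, 1).
true_theta = 2.0
x = true_theta * rng.normal(size=2000) + rng.normal(size=2000)

# Parameters: shared decoder theta, plus per-datapoint variational parameters (m_i, rho_i),
# with s_i = exp(rho_i). (A real VAE would amortize m and rho with an encoder network.)
theta = 0.5
m = np.zeros_like(x)
rho = np.zeros_like(x)

lr = 1e-3
for step in range(20_000):
    idx = rng.integers(0, len(x), size=64)            # minibatch
    xb, mb, sb = x[idx], m[idx], np.exp(rho[idx])

    # Single-sample reparameterized estimate of form (2): z = m + s * eps, eps ~ N(0, 1).
    eps = rng.normal(size=len(idx))
    z = mb + sb * eps

    # Hand-derived gradients of f = log p(z) + log p_theta(x|z) - log q(z) for this toy model.
    g_z = -z + theta * (xb - theta * z)               # df/dz
    grad_theta = np.mean((xb - theta * z) * z)        # df/dtheta, averaged over the minibatch
    grad_m = g_z                                      # df/dm
    grad_rho = g_z * sb * eps + 1.0                   # df/drho, with s = exp(rho)

    # Stochastic gradient ascent on noisy but unbiased estimates of the ELBO gradient.
    theta += lr * grad_theta
    m[idx] += lr * grad_m
    rho[idx] += lr * grad_rho

print(theta)   # should drift toward the EM / maximum-likelihood solution (about 2)
```

Both procedures target the same ELBO: EM exploits the closed-form posterior and M-step, while the stochastic version only needs unbiased gradient estimates, which is what makes it applicable when those closed forms are unavailable.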
