I am trying to implement a Gaussian mixture model (GMM) with stochastic variational inference (SVI), following this paper.
This is the probabilistic graphical model (PGM) of the Gaussian mixture.
According to the paper, the full algorithm of stochastic variational inference is:
I am still very confused about how to apply it to a GMM.
First, I thought the local variational parameter is just $q_z$ and the others are all global parameters. Please correct me if I am wrong. What does step 6 mean by "as though $x_i$ is replicated $N$ times"? What am I supposed to do to achieve this?
Could you please help me with this? Thanks in advance!
Best Answer
First, a few notes that help me make sense of the SVI paper:
In a mixture of $K$ Gaussians, our global parameters are the mean and precision (inverse variance) $\mu_k, \tau_k$ of each component. That is, $\eta_g$ is the natural parameter of their distribution, a Normal-Gamma of the form
$$\mu, \tau \sim N\big(\mu \mid \gamma, \tau(2\alpha -1)\big)\,Ga(\tau \mid \alpha, \beta)$$
with $\eta_0 = 2\alpha - 1$, $\eta_1 = \gamma(2\alpha -1)$ and $\eta_2 = 2\beta+\gamma^2(2\alpha-1)$. (Bernardo and Smith, Bayesian Theory; note this varies a little from the four-parameter Normal-Gamma you'll commonly see.) We'll use $a, b, m$ to refer to the variational parameters for $\alpha, \beta, \mu$.
The full conditional of $\mu_k, \tau_k$ is a Normal-Gamma with params $\dot\eta + \langle\sum_N z_{n,k}, \sum_N z_{n,k}x_n, \sum_N z_{n,k}x^2_{n}\rangle$, where $\dot\eta$ is the prior. (The $z_{n,k}$ in there can also be confusing; it makes sense starting with an $\exp(\ln(p))$ trick applied to $\prod_N p(x_n|z_n, \alpha, \beta, \gamma) = \prod_N\prod_K\big(p(x_n|\alpha_k,\beta_k,\gamma_k)\big)^{z_{n,k}}$, and ending with a fair amount of algebra left to the reader.)
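To make the parameterisation concrete, here is a minimal sketch of the mapping between $(\alpha, \beta, \gamma)$ and $(\eta_0, \eta_1, \eta_2)$ given above, plus the full-conditional update for one component. The function names are my own, and the inverse map is just the algebraic inversion of the three definitions; treat it as an illustration rather than the code used for the plots.

```python
def to_natural(alpha, beta, gamma):
    """Map (alpha, beta, gamma) to (eta_0, eta_1, eta_2) as defined in the text."""
    eta0 = 2.0 * alpha - 1.0
    eta1 = gamma * eta0
    eta2 = 2.0 * beta + gamma ** 2 * eta0
    return eta0, eta1, eta2


def from_natural(eta):
    """Invert to_natural; requires eta_0 != 0 (i.e. alpha != 1/2)."""
    eta0, eta1, eta2 = eta
    alpha = (eta0 + 1.0) / 2.0
    gamma = eta1 / eta0
    beta = (eta2 - eta1 ** 2 / eta0) / 2.0
    return alpha, beta, gamma


def full_conditional(eta_prior, z_k, x):
    """Prior natural parameters plus the z-weighted sufficient statistics
    <sum z, sum z*x, sum z*x^2> for a single component k."""
    s0 = sum(z_k)
    s1 = sum(z * xi for z, xi in zip(z_k, x))
    s2 = sum(z * xi ** 2 for z, xi in zip(z_k, x))
    return eta_prior[0] + s0, eta_prior[1] + s1, eta_prior[2] + s2
```

Working in natural parameters means every update is a plain addition of counts and sums; you only convert back to $(\alpha, \beta, \gamma)$ when you need expectations.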
With that, we can complete step (5) of the SVI pseudocode with:
$$\phi_{n,k} \propto \exp \big(\ln(\pi) + \mathbb E_q \ln p(x_n|\alpha_k, \beta_k, \gamma_k)\big)\\ =\exp\Big(\ln(\pi) + \mathbb E_q \Big[\langle \mu_k\tau_k, \frac{-\tau_k}{2} \rangle \cdot\langle x_n, x_n^2\rangle - \frac{\mu_k^2\tau_k - \ln \tau_k}{2}\Big]\Big) $$
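A sketch of this local step in Python, assuming $q(\mu_k, \tau_k)$ has the same Normal-Gamma form with variational parameters $a_k, b_k, m_k$ (and $a_k > 1/2$), so that $\mathbb E[\tau_k] = a_k/b_k$, $\mathbb E[\mu_k\tau_k] = m_k a_k/b_k$, $\mathbb E[\mu_k^2\tau_k] = m_k^2 a_k/b_k + 1/(2a_k - 1)$ and $\mathbb E[\ln\tau_k] = \psi(a_k) - \ln b_k$. The function name and the fixed `log_pi` argument are my own choices for illustration.

```python
import numpy as np
from scipy.special import digamma


def local_step(x_n, a, b, m, log_pi):
    """Responsibilities phi_{n,k} for one scalar observation x_n.

    a, b, m : arrays of shape (K,), variational Normal-Gamma parameters
    log_pi  : array of shape (K,), log mixing weights (assumed fixed here)
    """
    e_tau = a / b                                   # E[tau_k]
    e_mu_tau = m * e_tau                            # E[mu_k tau_k]
    e_mu2_tau = m ** 2 * e_tau + 1.0 / (2.0 * a - 1.0)  # E[mu_k^2 tau_k]
    e_log_tau = digamma(a) - np.log(b)              # E[ln tau_k]

    # <E[mu tau], -E[tau]/2> . <x, x^2> - (E[mu^2 tau] - E[ln tau]) / 2
    log_phi = (log_pi + e_mu_tau * x_n - 0.5 * e_tau * x_n ** 2
               - 0.5 * (e_mu2_tau - e_log_tau))

    log_phi = log_phi - log_phi.max()               # stabilise before exp
    phi = np.exp(log_phi)
    return phi / phi.sum()                          # normalise over k
```

Subtracting the maximum before exponentiating does not change the normalised result; it only avoids overflow when the components are far apart.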
Updating the global parameters is easier, since each parameter corresponds to a count of the data or one of its sufficient statistics:
$$ \hat \lambda_k = \dot \eta + N\phi_{n,k} \langle 1, x_n, x_n^2 \rangle $$
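This is what "as though $x_n$ is replicated $N$ times" amounts to: the single sampled point's weighted statistics are multiplied by $N$, as if the whole data set consisted of copies of it. A minimal sketch follows, assuming the intermediate estimate is then blended into the current parameters with the paper's decaying step size $\rho_t = (t + \text{delay})^{-\text{forgetting}}$; the function name, argument names and default values are mine.

```python
import numpy as np


def global_step(lam, prior, x_n, phi_n, N, t, delay=1.0, forgetting=0.7):
    """One stochastic update of the global natural parameters.

    lam   : array (K, 3), current natural parameters, one row per component
    prior : array (3,), prior natural parameters (eta-dot)
    phi_n : array (K,), responsibilities of the sampled point x_n
    N     : total number of data points
    t     : iteration counter, starting at 0
    """
    stats = np.array([1.0, x_n, x_n ** 2])          # <1, x, x^2>
    # Intermediate estimate: pretend the data set is N copies of x_n.
    lam_hat = prior + N * np.outer(phi_n, stats)    # shape (K, 3)
    rho = (t + delay) ** (-forgetting)              # decaying step size
    return (1.0 - rho) * lam + rho * lam_hat
```

With `forgetting` in $(0.5, 1]$ the step sizes satisfy the usual Robbins-Monro conditions; early iterations move the parameters a lot, later ones average over many noisy estimates.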
Here's what the marginal likelihood of the data looks like over many iterations, when trained on very artificial, easily separable data (code below). The first plot shows the likelihood with initial, random variational parameters and $0$ iterations; each subsequent plot is taken after the next power of two iterations. In the code, $a, b, m$ refer to variational parameters for $\alpha, \beta, \mu$.