Solved – KL divergence between two bivariate Gaussian distributions

bivariate, kullback-leibler, normal distribution

The KL divergence between two multivariate Gaussians and between two univariate Gaussians has been discussed elsewhere. I was wondering whether there is a simpler expression for the KL divergence between two bivariate Gaussians in terms of their means, variances and correlation coefficients, without using the more general multivariate form.

Best Answer

We have for two $d$-dimensional multivariate Gaussian distributions $P = \mathcal{N}(\mu, \Sigma)$ and $Q = \mathcal{N}(m, S)$ that

$$\DeclareMathOperator{\tr}{tr} \mathbb{D}_{\textrm{KL}}(P \Vert Q) = \frac{1}{2} \left( \tr(S^{-1}\Sigma) - d + (m - \mu)^\top S^{-1}(m-\mu) + \log\frac{|S|}{|\Sigma|} \right). $$

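As a quick numerical sanity check of this general formula, here is a minimal NumPy sketch (the helper name `kl_mvn` and the test values are my own):

```python
import numpy as np

def kl_mvn(mu, Sigma, m, S):
    """D_KL(N(mu, Sigma) || N(m, S)) via the general multivariate formula."""
    d = len(mu)
    S_inv = np.linalg.inv(S)
    diff = m - mu
    return 0.5 * (
        np.trace(S_inv @ Sigma) - d
        + diff @ S_inv @ diff
        + np.log(np.linalg.det(S) / np.linalg.det(Sigma))
    )

# Identical distributions should give (numerically) zero divergence
mu = np.array([0.5, -1.0])
Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
print(kl_mvn(mu, Sigma, mu, Sigma))  # ≈ 0 up to floating point error
```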
For the bivariate case, i.e. $d=2$, parameterising in terms of the component means, standard deviations and correlation coefficients, we define the mean vectors and covariance matrices as

$$ \mu = \begin{pmatrix} \mu_1\\ \mu_2 \end{pmatrix},~ \Sigma = \begin{pmatrix} \sigma_1^2 & \rho\sigma_1\sigma_2 \\ \rho\sigma_1\sigma_2 & \sigma_2^2 \end{pmatrix} \quad\textrm{and}\quad m = \begin{pmatrix} m_1 \\ m_2 \end{pmatrix},~ S = \begin{pmatrix} s_1^2 & r s_1 s_2 \\ r s_1 s_2 & s_2^2 \end{pmatrix}. $$

Using the standard formulas for the determinant and inverse of a $2\times 2$ matrix, we have that

$$ |\Sigma| = \sigma_1^2\sigma_2^2(1-\rho^2),~ |S| = s_1^2 s_2^2 (1 - r^2) ~\textrm{and}~ S^{-1} = \frac{1}{s_1^2 s_2^2 (1 - r^2)} \begin{pmatrix} s_2^2 & -r s_1 s_2 \\ -r s_1 s_2 & s_1^2 \end{pmatrix}. $$

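These expressions can be checked directly, for example symbolically with SymPy:

```python
from sympy import symbols, Matrix, eye, simplify

s1, s2, r = symbols('s_1 s_2 r')
S = Matrix([[s1**2, r*s1*s2], [r*s1*s2, s2**2]])

# Determinant matches s_1^2 s_2^2 (1 - r^2)
assert simplify(S.det() - s1**2 * s2**2 * (1 - r**2)) == 0

# The stated inverse multiplies S back to the identity
S_inv = Matrix([[s2**2, -r*s1*s2],
                [-r*s1*s2, s1**2]]) / (s1**2 * s2**2 * (1 - r**2))
print(simplify(S * S_inv) == eye(2))  # True
```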
Substituting these terms into the above and simplifying gives

\begin{align} \mathbb{D}_{\textrm{KL}}(P \Vert Q) =&\, \frac{1}{2(1-r^2)} \left( \frac{(\mu_1-m_1)^2}{s_1^2} - 2r \frac{(\mu_1-m_1)(\mu_2-m_2)}{s_1 s_2} + \frac{(\mu_2-m_2)^2}{s_2^2} \right) \\ &+ \frac{1}{2(1-r^2)} \left( \frac{\sigma_1^2-s_1^2}{s_1^2} - 2r \frac{\rho\sigma_1\sigma_2 - r s_1 s_2}{s_1 s_2} + \frac{\sigma_2^2-s_2^2}{s_2^2} \right) \\ &+ \log\left( \frac{s_1 s_2 \sqrt{1-r^2}}{\sigma_1\sigma_2\sqrt{1-\rho^2}} \right). \end{align}

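A quick numeric sanity check of this closed form (the function name and test values are my own): when the two distributions coincide, every term vanishes and the divergence is exactly zero.

```python
import numpy as np

def kl_bivariate(mu1, mu2, sigma1, sigma2, rho, m1, m2, s1, s2, r):
    """D_KL(P || Q) from the bivariate closed form above."""
    c = 1.0 / (2.0 * (1.0 - r**2))
    # Mean (quadratic) term
    quad = c * ((mu1 - m1)**2 / s1**2
                - 2*r*(mu1 - m1)*(mu2 - m2) / (s1*s2)
                + (mu2 - m2)**2 / s2**2)
    # Covariance mismatch term
    cov = c * ((sigma1**2 - s1**2) / s1**2
               - 2*r*(rho*sigma1*sigma2 - r*s1*s2) / (s1*s2)
               + (sigma2**2 - s2**2) / s2**2)
    # Log-determinant ratio term
    logdet = np.log(s1 * s2 * np.sqrt(1 - r**2)
                    / (sigma1 * sigma2 * np.sqrt(1 - rho**2)))
    return quad + cov + logdet

# P = Q gives zero divergence
print(kl_bivariate(0.2, -0.4, 1.5, 0.8, 0.3,
                   0.2, -0.4, 1.5, 0.8, 0.3))  # 0.0
```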
This can be verified symbolically with SymPy as follows:

from sympy import *

d = 2
# Parameters of Q = N(m, S): means, standard deviations, correlation
s1, s2, r, m1, m2 = symbols('s_1 s_2 r m_1 m_2')
# Parameters of P = N(mu, Sigma)
sigma1, sigma2, rho, mu1, mu2 = symbols(r'\sigma_1 \sigma_2 \rho \mu_1 \mu_2')
m = Matrix([m1, m2])
S = Matrix([[s1**2, r*s1*s2], [r*s1*s2, s2**2]])
mu = Matrix([mu1, mu2])
Sigma = Matrix([[sigma1**2, rho*sigma1*sigma2], [rho*sigma1*sigma2, sigma2**2]])
# General multivariate expression for D_KL(P || Q)
lhs = (
    trace(S**(-1) * Sigma) - d +
    ((m - mu).T * S**(-1) * (m - mu))[0] +
    log(det(S) / det(Sigma))
) / 2
# Bivariate closed form derived above
rhs = (
    ((mu1-m1)**2/s1**2 - 2*r*(mu1-m1)*(mu2-m2)/(s1*s2) + (mu2-m2)**2/s2**2) /
    (2 * (1 - r**2)) +
    ((sigma1**2-s1**2)/s1**2 - 2*r*(rho*sigma1*sigma2-r*s1*s2)/(s1*s2) +
     (sigma2**2-s2**2)/s2**2) /
    (2 * (1 - r**2)) +
    log((s1**2 * s2**2 * (1-r**2)) / (sigma1**2 * sigma2**2 * (1-rho**2))) / 2
)
# Prints True if the two expressions agree
print(simplify(lhs - rhs) == 0)