Solved – Trying to implement the Jensen-Shannon Divergence for Multivariate Gaussians

Tags: entropy, information theory, mutual information, normal distribution

Given two multivariate Gaussian distributions $P \equiv \mathcal{N}(\mu_p, \Sigma_p)$ and $Q \equiv \mathcal{N}(\mu_q, \Sigma_q)$, I am trying to calculate the Jensen-Shannon divergence between them.

I am following the discussion of the JSD for multivariate Gaussians in this question. There it is suggested that one can approximate the mixture (midpoint) measure $M$ using Monte Carlo sampling.

Specifically, it is pointed out that the JSD for continuous random variables (Gaussians, in my case) is given by

$$
\mathrm{JSD} = \frac{1}{2} \big(D_{KL}(P\,\|\,M) + D_{KL}(Q\,\|\,M)\big) = h(M) - \frac{1}{2} \big(h(P) + h(Q)\big) \,,
$$

where $h(P)$ and $h(Q)$ are the differential entropies of the multivariate normals. These have a well-known closed form and are easy to calculate, e.g.

$$
h(P) = \frac{1}{2} \log_2\big((2\pi e)^n |\Sigma_p|\big)
$$
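
As a quick illustration, here is a minimal sketch of that closed form in code (returning bits); the function name and the use of np.linalg.slogdet for numerical stability are my own choices:

import numpy as np

def mvn_entropy_bits(sigma_p: np.ndarray) -> float:
    """Differential entropy, in bits, of a multivariate Gaussian with covariance sigma_p."""
    n = sigma_p.shape[0]
    # slogdet gives the natural-log determinant; divide by log(2) to convert to base 2
    _, logdet = np.linalg.slogdet(sigma_p)
    return 0.5 * (n * np.log2(2 * np.pi * np.e) + logdet / np.log(2))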

What is causing me trouble is $M$: the mixture of two Gaussians has no closed-form differential entropy, so it has to be approximated. I believe I have misunderstood or incorrectly implemented the Monte Carlo estimate for it.

User FrankD says that for the JSD approximation:
$$
\mathrm{JSD}(P\,\|\,Q) = \frac{1}{2} \big(D_{KL}(P\,\|\,M) + D_{KL}(Q\,\|\,M)\big)
$$
we can use Monte Carlo estimates for the individual components. The Kullback-Leibler divergence is defined as:
$$
D_{KL}(P\,\|\,M) = \int P(x) \log\!\left(\frac{P(x)}{M(x)}\right) dx
$$
The Monte Carlo approximation of this is:
$$
D_{KL}^{\mathrm{approx}}(P\,\|\,M) = \frac{1}{n} \sum^n_{i=1} \log\!\left(\frac{P(x_i)}{M(x_i)}\right)
$$

where the $x_i$ have been sampled from $P(x)$, which is easy since it is a Gaussian in our case. As $n \to \infty$, $D_{KL}^{\mathrm{approx}}(P\,\|\,M) \to D_{KL}(P\,\|\,M)$. $M(x_i)$ can be calculated as

$$
M(x_i) = \frac{1}{2}P(x_i) + \frac{1}{2}Q(x_i) \,.
$$
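
As a sanity check on this kind of estimator (my own illustration, not from the linked discussion), the same Monte Carlo average can be compared against the closed-form KL divergence between two Gaussians, where an exact answer exists; the example means and covariances below are made up:

import numpy as np
from scipy.stats import multivariate_normal as MVN

# Illustrative parameters for P and Q (assumed, for demonstration only)
mu_p, sigma_p = np.zeros(2), np.eye(2)
mu_q, sigma_q = np.ones(2), 2.0 * np.eye(2)

# Monte Carlo estimate of D_KL(P || Q) in nats: average log-density ratio over samples from P
x = MVN.rvs(mean=mu_p, cov=sigma_p, size=100_000)
kl_mc = np.mean(MVN.logpdf(x, mean=mu_p, cov=sigma_p) - MVN.logpdf(x, mean=mu_q, cov=sigma_q))

# Closed-form KL divergence between two multivariate Gaussians, also in nats
diff = mu_q - mu_p
inv_q = np.linalg.inv(sigma_q)
kl_exact = 0.5 * (np.trace(inv_q @ sigma_p) + diff @ inv_q @ diff
                  - len(mu_p) + np.log(np.linalg.det(sigma_q) / np.linalg.det(sigma_p)))

print(kl_mc, kl_exact)  # the two values should agree to roughly two decimal places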

Here is my attempt:

import numpy as np
from scipy.stats import multivariate_normal as MVN

def jsd(mu_1: np.ndarray, sigma_1: np.ndarray, mu_2: np.ndarray, sigma_2: np.ndarray):
    """
    Monte Carlo approximation to the Jensen-Shannon divergence (in bits)
    between two multivariate Gaussians.
    """
    assert mu_1.shape == mu_2.shape, "Shape mismatch."
    assert sigma_1.shape == sigma_2.shape, "Shape mismatch."

    # Number of Monte Carlo samples
    MC_samples = 1000

    # Draw samples from P and Q
    P_samples = MVN.rvs(mean=mu_1, cov=sigma_1, size=MC_samples)
    Q_samples = MVN.rvs(mean=mu_2, cov=sigma_2, size=MC_samples)

    # Densities of P, Q and the mixture M = (P + Q) / 2
    P = lambda x: MVN.pdf(x, mean=mu_1, cov=sigma_1)
    Q = lambda x: MVN.pdf(x, mean=mu_2, cov=sigma_2)
    M = lambda x: 0.5 * P(x) + 0.5 * Q(x)

    # Density ratios appearing inside the logarithms
    P_div_M = lambda x: P(x) / M(x)
    Q_div_M = lambda x: Q(x) / M(x)

    # Monte Carlo estimates of D_KL(P || M) and D_KL(Q || M), in bits
    D_KL_approx_PM = lambda x: np.mean(np.log2(P_div_M(x)))
    D_KL_approx_QM = lambda x: np.mean(np.log2(Q_div_M(x)))

    return 0.5 * D_KL_approx_PM(P_samples) + 0.5 * D_KL_approx_QM(Q_samples)

Suffice it to say, this does not quite produce what it should.

Best Answer

Actually, building on the answer in https://stackoverflow.com/questions/26079881/kl-divergence-of-two-gmms (and noting that its author factored the 1/2 out of the logarithm and had the Monte Carlo approximation sample from both distributions and average the results), I would say that the symmetrized numerical code for the Jensen-Shannon divergence using Monte Carlo integration, which works even for general scipy.stats distributions (distribution_p and distribution_q), should look like this:

import numpy as np
import scipy.stats as st


def distributions_js(distribution_p, distribution_q, n_samples=10 ** 5):
    # Jensen-Shannon divergence. (The Jensen-Shannon distance is the square root of the divergence.)
    # All logarithms are base 2, so the result is in bits (information entropy).
    X = distribution_p.rvs(n_samples)
    p_X = distribution_p.pdf(X)
    q_X = distribution_q.pdf(X)
    log_mix_X = np.log2(p_X + q_X)

    Y = distribution_q.rvs(n_samples)
    p_Y = distribution_p.pdf(Y)
    q_Y = distribution_q.pdf(Y)
    log_mix_Y = np.log2(p_Y + q_Y)

    # Average of D_KL(P || M) and D_KL(Q || M), where log2 M = log2(p + q) - log2(2)
    return (np.log2(p_X).mean() - (log_mix_X.mean() - np.log2(2))
            + np.log2(q_Y).mean() - (log_mix_Y.mean() - np.log2(2))) / 2

print("should be different:")
print(distributions_js(st.norm(loc=10000), st.norm(loc=0)))
print("should be same:")
print(distributions_js(st.norm(loc=0), st.norm(loc=0)))
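
With base-2 logarithms the Jensen-Shannon divergence is bounded above by 1 bit, so up to Monte Carlo noise the first value should come out close to 1 (the two normals barely overlap) and the second should be close to 0.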

For discrete distributions, change .pdf to .pmf (i.e. the probabilities of the samples).
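
For example, a minimal sketch of that discrete variant (my own adaptation of the function above, assuming frozen scipy.stats discrete distributions):

def distributions_js_discrete(distribution_p, distribution_q, n_samples=10 ** 5):
    # Same estimator as above, but using .pmf in place of .pdf
    X = distribution_p.rvs(n_samples)
    log_mix_X = np.log2(distribution_p.pmf(X) + distribution_q.pmf(X))

    Y = distribution_q.rvs(n_samples)
    log_mix_Y = np.log2(distribution_p.pmf(Y) + distribution_q.pmf(Y))

    return (np.log2(distribution_p.pmf(X)).mean() - (log_mix_X.mean() - np.log2(2))
            + np.log2(distribution_q.pmf(Y)).mean() - (log_mix_Y.mean() - np.log2(2))) / 2

print(distributions_js_discrete(st.poisson(3), st.poisson(30)))  # well separated, close to 1
print(distributions_js_discrete(st.poisson(3), st.poisson(3)))   # identical, close to 0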