Variance of conditional multivariate gaussian


I was playing around with Gaussian distributions on my machine and was interested in making a pretty plot: if $x_1, x_2$ follow a bivariate normal distribution, I wanted to show the distribution of $x_1$ given a value of $x_2$.

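[Plot: samples from the bivariate normal (blue scatter) with the conditional densities of $x_1$ for three given values of $x_2$ overlaid (red)]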

The image feels "wrong". I would expect the red conditional density at the centre of the blue cloud to be wider than the one at the edge. After double-checking the code, I suspect the issue is in the maths.

The maths on Wikipedia, as well as in my books, confirms that the distribution of $x_1$ conditional on $x_2 = a$ is multivariate normal, $(x_1 \mid x_2 = a) \sim N(\bar{\boldsymbol\mu}, \overline{\boldsymbol\Sigma})$, with mean

$$
\bar{\boldsymbol\mu}
=
\boldsymbol\mu_1 + \boldsymbol\Sigma_{12} \boldsymbol\Sigma_{22}^{-1}
\left(
\mathbf{a} - \boldsymbol\mu_2
\right)
$$

and covariance matrix

$$
\overline{\boldsymbol\Sigma}
=
\boldsymbol\Sigma_{11} - \boldsymbol\Sigma_{12} \boldsymbol\Sigma_{22}^{-1} \boldsymbol\Sigma_{21}.
$$
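In the two-dimensional case of the plot, every block is a scalar. Writing $\boldsymbol\Sigma_{11} = \sigma_1^2$, $\boldsymbol\Sigma_{22} = \sigma_2^2$ and $\boldsymbol\Sigma_{12} = \boldsymbol\Sigma_{21} = \rho\sigma_1\sigma_2$ (the symbols $\sigma_1, \sigma_2, \rho$ are introduced here only for illustration), the two formulas reduce to

$$
\bar\mu = \mu_1 + \rho\frac{\sigma_1}{\sigma_2}(a - \mu_2),
\qquad
\bar\Sigma = \sigma_1^2 - \frac{(\rho\sigma_1\sigma_2)^2}{\sigma_2^2} = \sigma_1^2\,(1 - \rho^2).
$$

With the values used in the code ($\mu_1 = \mu_2 = 0$, $\sigma_1 = \sigma_2 = 1$, $\rho = 0.9$), this gives $\bar\mu = 0.9\,a$ and $\bar\Sigma = 0.19$ (a standard deviation of about $0.436$), so $a$ enters only the mean.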

Looking at the maths, the variance of $p(x_1 \mid x_2 = a)$ does not depend on the value of $a$ at all. This feels very counterintuitive, so I am wondering whether I am missing something.
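One way to check this numerically is a Monte Carlo sketch: draw many samples, keep only those whose $x_2$ lies in a narrow slice around $a$, and look at the spread of $x_1$ within each slice. It uses NumPy's `Generator.multivariate_normal` rather than torch; the slice half-width of 0.05 and the sample size are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.0, 0.0])
sigma = np.array([[1.0, 0.9], [0.9, 1.0]])

# Many samples so that even the slices in the tails contain enough points
samples = rng.multivariate_normal(mu, sigma, size=2_000_000)
x1, x2 = samples[:, 0], samples[:, 1]

half_width = 0.05  # arbitrary width of each x_2 slice around a
results = {}
for a in (-2.0, 0.0, 2.0):
    # Approximate conditioning on x_2 = a by keeping points with x_2 close to a
    in_slice = np.abs(x2 - a) < half_width
    n = int(in_slice.sum())
    mean_hat = x1[in_slice].mean()
    var_hat = x1[in_slice].var()
    results[a] = (n, mean_hat, var_hat)
    print(f"a={a:+.1f}  n={n:7d}  mean={mean_hat:+.3f}  var={var_hat:.3f}")

# Closed-form conditional variance sigma_1^2 * (1 - rho^2)
print("theory: var =", 1 - 0.9**2)
```

The estimated mean should track $0.9\,a$ while the estimated variance stays close to $0.19$ in every slice. The slices near the edge of the cloud contain far fewer points, which may be why the blue cloud looks narrower there even though the spread of $x_1$ within a slice is about the same.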

The code that generated the plot:

```python
import matplotlib.pyplot as plt

import torch
from torch.distributions import Normal as norm
from torch.distributions.multivariate_normal import MultivariateNormal as mvnorm

#@title different given values { run: "auto" }
g1 = -3.5 #@param {type:"slider", min:-4, max:4, step:0.1}
g2 = 0.2 #@param {type:"slider", min:-4, max:4, step:0.1}
g3 = 2.9 #@param {type:"slider", min:-4, max:4, step:0.1}

m = torch.tensor([0.0, 0.0])
c = torch.tensor([[1.0, 0.9], [0.9, 1.0]])

# 5000 samples of (x_1, x_2): column 0 is x_1 (horizontal), column 1 is x_2 (vertical)
s = mvnorm(m, c).sample(sample_shape=(5000,))
s_np = s.numpy()
plt.figure(figsize=(6,5))
plt.scatter(s_np[:, 0], s_np[:, 1], alpha=0.3)

for g in [g1, g2, g3]:
    # Conditional mean and variance of x_1 given x_2 = g
    mu_pred = m[0] + c[0][1]/c[1][1]*(g - m[1])
    var_pred = c[0][0] - c[0][1]/c[1][1]*c[1][0]
    # Normal expects the standard deviation (scale), not the variance
    sigma_pred = var_pred.sqrt()
    fitted_distr = norm(mu_pred, sigma_pred)
    print(f"g:{g:.3}, mu:{mu_pred:.2}, sigma:{sigma_pred:.4}")

    xs = torch.linspace(-4, 4, 300)
    likelihood = torch.exp(fitted_distr.log_prob(xs)).numpy()

    # Draw the conditional density of x_1 shifted up to the height x_2 = g
    plt.plot(xs.numpy(), g + likelihood, c='red')

plt.show()
```
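The loop hard-codes the scalar versions of the conditioning formulas. As a sketch that is not part of the original code, the block formulas can be written once for any dimension with NumPy. `condition_gaussian` is a hypothetical helper name, and it uses `np.linalg.solve` instead of forming $\boldsymbol\Sigma_{22}^{-1}$ explicitly, which is the usual numerically safer choice.

```python
import numpy as np


def condition_gaussian(mu, sigma, idx1, idx2, a):
    """Hypothetical helper: mean and covariance of x[idx1] given x[idx2] = a
    for x ~ N(mu, sigma), using the block formulas for the conditional normal."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    a = np.atleast_1d(np.asarray(a, dtype=float))
    mu1, mu2 = mu[idx1], mu[idx2]
    s11 = sigma[np.ix_(idx1, idx1)]
    s12 = sigma[np.ix_(idx1, idx2)]
    s21 = sigma[np.ix_(idx2, idx1)]
    s22 = sigma[np.ix_(idx2, idx2)]
    # Solve linear systems with Sigma_22 rather than inverting it
    mean = mu1 + s12 @ np.linalg.solve(s22, a - mu2)
    cov = s11 - s12 @ np.linalg.solve(s22, s21)  # note: a does not appear here
    return mean, cov


mu = [0.0, 0.0]
sigma = [[1.0, 0.9], [0.9, 1.0]]
for a in (-3.5, 0.2, 2.9):
    mean, cov = condition_gaussian(mu, sigma, [0], [1], a)
    print(f"a={a:+.1f}  mean={mean[0]:+.3f}  var={cov[0, 0]:.3f}")
```

For the covariance used in the plot, the printed mean moves with $a$ (it equals $0.9\,a$), while the variance stays at $1 - 0.9^2 = 0.19$ for every $a$, because $a$ only appears in the mean formula.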

Best Answer

This question was also asked on another Stack Exchange site, where it has been answered.
