I was playing around with Gaussian distributions on my machine and was interested in making a pretty plot. I wanted to show the distribution of $x_1$ given $x_2$, where $x_1, x_2$ are jointly distributed according to a multivariate normal distribution.
The image feels "wrong". I would expect the red likelihood curve in the center of the blue distribution to be wider than the curves near the edge. After double-checking the code, I suspect the issue might be in the maths.
The maths on Wikipedia, as well as in my books, confirms that the distribution of $x_1$ conditional on $x_2 = a$ is multivariate normal, $(x_1 \mid x_2 = a) \sim N(\bar{\boldsymbol\mu}, \overline{\boldsymbol\Sigma})$, with mean
$$
\bar{\boldsymbol\mu}
=
\boldsymbol\mu_1 + \boldsymbol\Sigma_{12} \boldsymbol\Sigma_{22}^{-1}
\left(
\mathbf{a} - \boldsymbol\mu_2
\right)
$$
and covariance matrix
$$
\overline{\boldsymbol\Sigma}
=
\boldsymbol\Sigma_{11} - \boldsymbol\Sigma_{12} \boldsymbol\Sigma_{22}^{-1} \boldsymbol\Sigma_{21}.
$$
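For the two-dimensional case in the plot, every block in these formulas is a scalar, so they reduce to a couple of lines. A minimal stdlib sketch, assuming $\boldsymbol\mu = (\mu_1, \mu_2)$ and a 2×2 covariance matrix; the function name `conditional_x1_given_x2` is hypothetical and not from any library:

```python
def conditional_x1_given_x2(mu, cov, a):
    """Mean and variance of x_1 | x_2 = a for a bivariate normal.

    mu:  (mu_1, mu_2)
    cov: [[S11, S12], [S21, S22]]
    a:   observed value of x_2
    """
    (mu1, mu2), ((s11, s12), (s21, s22)) = mu, cov
    cond_mean = mu1 + s12 / s22 * (a - mu2)  # mu_1 + S12 S22^-1 (a - mu_2)
    cond_var = s11 - s12 / s22 * s21         # S11 - S12 S22^-1 S21: 'a' never appears
    return cond_mean, cond_var


# Same joint distribution and slider values as in the plotting code
mu = (0.0, 0.0)
cov = [[1.0, 0.9], [0.9, 1.0]]
for a in (-3.5, 0.2, 2.9):
    mean, var = conditional_x1_given_x2(mu, cov, a)
    print(f"a={a:+.1f}  mean={mean:+.3f}  var={var:.4f}")
```

Only the mean moves with `a`; the variance expression has no `a` in it.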
Looking at the maths, the variance of $p(x_1 \mid x_2 = a)$ does not depend on the value of $a$ at all. This feels very counterintuitive, so I am wondering whether I am missing something.
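One way to check this numerically, independently of the plotting code, is to sample the same joint distribution and measure the spread of $x_1$ among samples whose $x_2$ lands in a narrow window around $a$. A rough Monte Carlo sketch using only the standard library; the sample size, window half-width, seed, and helper names (`sample_joint`, `empirical_conditional_var`) are arbitrary choices for illustration. Each sample is built as $x_2 \sim N(0, 1)$ and $x_1 = \rho x_2 + \sqrt{1-\rho^2}\,z$ with independent $z \sim N(0, 1)$, which gives unit variances and covariance $\rho = 0.9$, matching the plot.

```python
import random
import statistics

RHO = 0.9          # covariance of (x_1, x_2); both variances are 1
N = 400_000        # number of joint samples (arbitrary)
HALF_WIDTH = 0.05  # window half-width around x_2 = a (arbitrary)


def sample_joint(n, rho, seed=0):
    """Draw n samples of (x_1, x_2) from N(0, [[1, rho], [rho, 1]])."""
    rng = random.Random(seed)
    noise_sd = (1.0 - rho ** 2) ** 0.5
    samples = []
    for _ in range(n):
        x2 = rng.gauss(0.0, 1.0)
        x1 = rho * x2 + noise_sd * rng.gauss(0.0, 1.0)
        samples.append((x1, x2))
    return samples


def empirical_conditional_var(samples, a, half_width):
    """Count and sample variance of x_1 among samples with |x_2 - a| < half_width."""
    x1_in_window = [x1 for x1, x2 in samples if abs(x2 - a) < half_width]
    return len(x1_in_window), statistics.variance(x1_in_window)


samples = sample_joint(N, RHO)
print(f"formula: {1.0 - RHO ** 2:.4f}")
for a in (-2.0, 0.0, 2.0):
    count, var = empirical_conditional_var(samples, a, HALF_WIDTH)
    print(f"a={a:+.1f}  n={count:6d}  empirical var={var:.4f}")
```

The number of samples per window drops sharply away from the center, but the empirical variance should stay close to $1 - 0.9^2 = 0.19$ in every window, up to sampling noise.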
The code that generated the plot:

```python
import matplotlib.pyplot as plt
import torch
from torch.distributions import Normal as norm
from torch.distributions.multivariate_normal import MultivariateNormal as mvnorm

#@title different given values { run: "auto" }
g1 = -3.5 #@param {type:"slider", min:-4, max:4, step:0.1}
g2 = 0.2 #@param {type:"slider", min:-4, max:4, step:0.1}
g3 = 2.9 #@param {type:"slider", min:-4, max:4, step:0.1}

# Joint distribution of (x_1, x_2): zero mean, unit variances, covariance 0.9
m = torch.tensor([0.0, 0.0])
c = torch.tensor([[1.0, 0.9], [0.9, 1.0]])
s = mvnorm(m, c).sample(sample_shape=(5000,))
s_np = s.numpy()  # shape (5000, 2): column 0 is x_1, column 1 is x_2

plt.figure(figsize=(6, 5))
plt.scatter(s_np[:, 0], s_np[:, 1], alpha=0.3)
xs = torch.linspace(-4, 4, 300)
for g in [g1, g2, g3]:
    # Conditional x_1 | x_2 = g (2-D case of the formulas above)
    mu_pred = m[0] + c[0][1] / c[1][1] * (g - m[1])
    var_pred = c[0][0] - c[0][1] / c[1][1] * c[1][0]
    # Normal expects a standard deviation (scale), not a variance
    fitted_distr = norm(mu_pred, var_pred.sqrt())
    print(f"g:{g:.3}, mu:{mu_pred:.2}, var:{var_pred:.4}")
    likelihood = torch.exp(fitted_distr.log_prob(xs)).numpy()
    # Draw the conditional density as a horizontal slice at height x_2 = g
    plt.plot(xs.numpy(), g + likelihood, c='red')
plt.show()
```
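A side note on the plotting code: `torch.distributions.Normal(loc, scale)` takes the standard deviation as `scale`, so the conditional variance needs a square root before it is passed in. Here the variance (0.19) is below 1, so passing it directly would draw each curve narrower and taller than the true conditional density. The same pitfall can be illustrated with the standard library's `statistics.NormalDist(mu, sigma)`, which is also parameterized by the standard deviation; this sketch does not use torch:

```python
import math
from statistics import NormalDist

cond_var = 1.0 - 0.9 ** 2  # conditional variance from the formula, about 0.19

correct = NormalDist(mu=0.0, sigma=math.sqrt(cond_var))  # sigma is a standard deviation
mistaken = NormalDist(mu=0.0, sigma=cond_var)            # variance wrongly passed as sigma

for name, dist in (("sqrt(var) as sigma", correct), ("var as sigma", mistaken)):
    peak = dist.pdf(0.0)
    # Full width at half maximum of a normal density is 2*sqrt(2 ln 2)*sigma
    fwhm = 2.0 * math.sqrt(2.0 * math.log(2.0)) * dist.stdev
    print(f"{name:20s} peak={peak:.3f}  FWHM={fwhm:.3f}")
```

Either way, every slice shares the same width; only the correct version matches the actual conditional density.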
Best Answer
This question was also asked on another Stack Exchange site, where it has been answered.