The relationship between Probabilistic Graphical Models and Graph Neural Networks

Tags: bayesian network, graph theory, neural networks

I would like to learn more about one or both of these. I lean towards Bayesian networks and PGMs, but since Battaglia et al., 2018 I have had half an eye on the various kinds of GNN.

You seem to be able to do similar things with both PGMs and GNNs – inference and learning (of probabilities/weights and of structure) on graphs. Do the two have different strengths?

Are they just rival academic camps? E.g., the new book "Deep Learning on Graphs" (Ma & Tang, 2021) doesn't have Koller & Friedman, 2009 in the bibliography and doesn't mention the term "graphical model".

Are the approaches complementary or antagonistic?

References

Battaglia et al., 2018, "Relational inductive biases, deep learning, and graph networks", https://deepmind.com/research/publications/relational-inductive-biases-deep-learning-and-graph-networks

Koller & Friedman, 2009, "Probabilistic Graphical Models: Principles and Techniques", https://mitpress.mit.edu/books/probabilistic-graphical-models

Ma & Tang, 2021, "Deep Learning on Graphs", http://cse.msu.edu/~mayao4/dlg_book/

Best Answer

The other answers and comments have done a good job of highlighting how PGMs and GNNs are traditionally used in different settings.

However, there is a deep underlying connection between them: basic GNNs can be derived through embedded mean field variational inference on the joint distribution of the node features and their latent representations.

The work that first showed the above is [H. Dai, B. Dai, and L. Song. Discriminative embeddings of latent variable models for structured data. In ICML, 2016].

Moreover, William L. Hamilton's incredible book "Graph Representation Learning" has an entire chapter dedicated to the theoretical motivations of graph neural networks, including spectral graph convolutions, PGMs, and the Weisfeiler-Lehman graph isomorphism test.

Below I have summarised the main steps for this construction, which Hamilton (whose notation I use) and of course the authors of the paper explain in much greater detail.

Given a graph $G = (V, E)$, a GNN tries to learn an embedding $z_u$ for every node $u \in V$. From a probabilistic perspective, we can think of these embeddings as latent variables to be inferred in order to explain the observed data, namely the graph structure (adjacency matrix $A$) and the input node features $X$.

More concretely, a graph defines a Markov random field:

$$p(\{x_v\}, \{z_v\}) \propto \prod_{v \in V}{\Phi(x_v, z_v)} \prod_{(u, v) \in E}{\Psi(z_u, z_v)}$$

where $\Phi$ and $\Psi$ are non-negative potential functions. The above says that the joint distribution factorises according to the graph structure.
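To make the factorisation concrete, here is a minimal numpy sketch (a toy example of my own, not from the paper) that evaluates the unnormalised density for a three-node path graph with Gaussian-style potentials:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy path graph: 3 nodes, edges (0,1) and (1,2).
edges = [(0, 1), (1, 2)]
num_nodes = 3

# Observed node features x_v and candidate latent embeddings z_v (2-d here).
x = rng.normal(size=(num_nodes, 2))
z = rng.normal(size=(num_nodes, 2))

def phi(x_v, z_v):
    # Node potential Phi(x_v, z_v): large when the embedding is close to the feature.
    return np.exp(-0.5 * np.sum((x_v - z_v) ** 2))

def psi(z_u, z_v):
    # Edge potential Psi(z_u, z_v): large when neighbouring embeddings agree.
    return np.exp(-0.5 * np.sum((z_u - z_v) ** 2))

# Unnormalised joint: p({x_v}, {z_v}) ∝ prod_v Phi(x_v, z_v) * prod_(u,v) Psi(z_u, z_v)
unnormalised = np.prod([phi(x[v], z[v]) for v in range(num_nodes)]) \
             * np.prod([psi(z[u], z[v]) for (u, v) in edges])
print(unnormalised)
```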

In order to infer the distribution of latent embeddings, we would need to compute the posterior $p(\{ z_v\} \mid \{ x_v\})$, i.e. the distribution over a set of embeddings given the observed features. This calculation is generally intractable, even with well-chosen potential functions, so we resort to approximate inference.

We assume that the posterior over the latent variables factorises into $|V|$ independent distributions, one per node:

$$p(\{ z_v\} \mid \{ x_v\}) \approx q(\{ z_v\}) = \prod_{v \in V}{q_v(z_v)}$$

To obtain the optimal $\{ q_v\}$, we proceed as usual by minimising the KL divergence between the approximation and the true posterior:

$$KL\left[\, q(\{ z_v\}) \,\|\, p(\{ z_v\} \mid \{ x_v\}) \,\right] = \int_{(\mathbb{R}^d)^{|V|}} \prod_{v \in V} q_v(z_v)\, \log\!\left(\frac{\prod_{v \in V} q_v(z_v)}{p(\{ z_v\} \mid \{ x_v \})}\right) \prod_{v \in V}{dz_v}$$

It can be shown through standard variational inference that the $q_v(z_v)$ minimising the KL divergence must satisfy the fixed-point equation:

$$\log q_v(z_v) = c_v + \log \Phi(x_v, z_v) + \sum_{u \in N(v)} \int_{\mathbb{R}^d} q_u(z_u) \log \Psi(z_u, z_v)\, dz_u$$

which can be iteratively approximated through:

$$\log q_v^{(t)}(z_v) = c_v + \log \Phi(x_v, z_v) + \sum_{u \in N(v)} \int_{\mathbb{R}^d} q_u^{(t-1)}(z_u) \log \Psi(z_u, z_v)\, dz_u$$

where $N(v)$ denotes the neighbourhood of node $v$ in the graph.
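To see the iteration in action, here is a small sketch (my own toy example; the graph and potentials are made up) where each latent variable takes one of $K$ discrete states, so the integral over $\mathbb{R}^d$ becomes a sum:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy path graph with 4 nodes; each latent z_v takes one of K discrete states.
K = 3
neighbours = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

# Made-up log-potentials: log_phi[v, k] stands for log Phi(x_v, z_v = k),
# log_psi[j, k] stands for log Psi(z_u = j, z_v = k).
log_phi = rng.normal(size=(4, K))
log_psi = rng.normal(size=(K, K))
log_psi = (log_psi + log_psi.T) / 2  # symmetric edge potential

# Initialise every q_v to uniform and iterate the mean-field fixed point.
q = np.full((4, K), 1.0 / K)
for _ in range(50):
    new_q = np.empty_like(q)
    for v in neighbours:
        # log q_v(k) = const + log Phi(x_v, k) + sum_{u in N(v)} E_{q_u}[log Psi(z_u, k)]
        logits = log_phi[v] + sum(q[u] @ log_psi for u in neighbours[v])
        logits -= logits.max()                       # for numerical stability
        new_q[v] = np.exp(logits) / np.exp(logits).sum()
    q = new_q

print(q)  # approximate per-node marginals q_v(z_v)
```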

The remaining tool needed to make this work is a sufficient statistic for the distribution $q_v(z_v)$, obtained by embedding it in a Hilbert space, so that any computation carried out on the distribution can equivalently be carried out on its embedding.

The Hilbert space embedding of a distribution $p(x)$ over $x\in \mathbb{R}^d$ is defined, for a suitable feature map $\phi$ chosen to make the embedding injective, as:

$$\mu_x = \int_{\mathbb{R}^d}\phi(x)\, p(x)\, dx$$
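For intuition, such a mean embedding can be approximated from samples of $p(x)$, e.g. with random Fourier features standing in for $\phi$; the sketch below is a generic illustration rather than the specific construction used by Dai et al.:

```python
import numpy as np

rng = np.random.default_rng(0)

d, m = 2, 256                                   # data dimension, number of random features
samples = rng.normal(loc=1.0, size=(5000, d))   # draws from some distribution p(x)

# Random Fourier features as a finite-dimensional stand-in for the feature map phi.
W = rng.normal(size=(d, m))
b = rng.uniform(0.0, 2.0 * np.pi, size=m)

def feature_map(x):
    return np.sqrt(2.0 / m) * np.cos(x @ W + b)

# Empirical mean embedding: mu_x ≈ (1/n) * sum_i phi(x_i), a Monte Carlo
# estimate of the integral of phi(x) p(x) dx.
mu_x = feature_map(samples).mean(axis=0)
print(mu_x.shape)   # (256,)
```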

By embedding $q_v(z_v)$ we can transform the fixed-point equation above to:

$$\mu_v^{(t)} = c + f\left(\mu_v^{(t-1)}, x_v, \{\mu_u^{(t-1)} : u \in N(v)\}\right)$$

Here, $f$ is a vector-valued function that aggregates information from the embeddings of neighboring nodes, which is then used to update the node's current embedding.

This corresponds exactly to a form of neural message passing!

We could then try to learn Hilbert space embeddings that correspond to an underlying probabilistic model, but instead we can simply define $f$ directly as a basic GNN:

$$\mu_v^{(t)} = \sigma \left( W_{\text{self}}^{(t)} x_v + W_{\text{neigh}}^{(t)} \sum_{u \in N(v)}{\mu_u^{(t-1)}}\right)$$
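For completeness, here is a minimal numpy sketch of this basic GNN layer (the graph, feature sizes, and the choice of ReLU for $\sigma$ are placeholder assumptions; in practice each iteration $t$ would have its own weights and you would use a deep learning framework):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph as an adjacency matrix A (no self-loops) plus input node features X.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
num_nodes, d_in, d_hidden = 3, 4, 8
X = rng.normal(size=(num_nodes, d_in))

def gnn_layer(mu_prev, X, W_self, W_neigh):
    # mu_v^(t) = sigma( W_self x_v + W_neigh * sum_{u in N(v)} mu_u^(t-1) )
    neighbour_sum = A @ mu_prev                  # row v sums the embeddings of N(v)
    return np.maximum(0.0, X @ W_self.T + neighbour_sum @ W_neigh.T)  # ReLU as sigma

# Two rounds of message passing starting from zero embeddings
# (weights are shared across rounds here only for brevity).
W_self = rng.normal(size=(d_hidden, d_in))
W_neigh = rng.normal(size=(d_hidden, d_hidden))
mu = np.zeros((num_nodes, d_hidden))
for _ in range(2):
    mu = gnn_layer(mu, X, W_self, W_neigh)

print(mu.shape)   # (num_nodes, d_hidden)
```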