Solved – Computing KL divergence in loss function of Bayesian neural networks

deep-learning, machine-learning, variational-bayes

Hi, I am trying to understand how the loss function for Bayesian neural networks (BNNs) is computed. The TensorFlow Probability documentation illustrates a BNN in practice, training the network to minimise the negative of the ELBO (as seen below).

import tensorflow as tf
import tensorflow_probability as tfp

model = tf.keras.Sequential([
    tf.keras.layers.Reshape([32, 32, 3]),
    tfp.layers.Convolution2DFlipout(
        64, kernel_size=5, padding='SAME', activation=tf.nn.relu),
    tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                 strides=[2, 2],
                                 padding='SAME'),
    tf.keras.layers.Flatten(),
    tfp.layers.DenseFlipout(10),
])

# `features` and `labels` are assumed to come from the input pipeline.
logits = model(features)
# Per-example cross-entropy, i.e. the negative log-likelihood of the labels.
neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)
# KL term: sum of the divergences registered by the tfp layers.
kl = sum(model.losses)
loss = neg_log_likelihood + kl
train_op = tf.train.AdamOptimizer().minimize(loss)

However, they seem to be computing the KL divergence simply as the sum of the losses attached to the network's layers, i.e. kl = sum(model.losses). This is not how the KL divergence should be computed.

They repeat it here too:

  # Compute the -ELBO as the loss, averaged over the batch size.
  neg_log_likelihood = -tf.reduce_mean(labels_distribution.log_prob(labels))
  kl = sum(neural_net.losses) / mnist_data.train.num_examples
  elbo_loss = neg_log_likelihood + kl

Am I missing something very basic? The exact KL divergence is quite difficult to compute unless you make assumptions about the underlying probability distribution of the weights, for example that they follow a Normal distribution, in which case the KL divergence can be computed in closed form as:

$$\mathrm{KL}\big(q(\theta \mid \mu, \sigma)\,\|\,q(\theta \mid \mu', \sigma')\big) = \log\frac{\sigma'}{\sigma} + \frac{\sigma^2 + (\mu - \mu')^2}{2\sigma'^2} - \frac{1}{2}$$

where $\theta$ denotes the weights, $\mu$ and $\sigma$ are the mean and standard deviation of a weight, and $\mu'$ and $\sigma'$ are the mean and standard deviation after the update.
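
For concreteness, here is a small numerical sketch (the values of $\mu$, $\sigma$, $\mu'$, $\sigma'$ are made up) checking that closed-form expression against tfp.distributions.kl_divergence:

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

# Made-up parameters for a single weight.
mu, sigma = 0.3, 0.8        # first Gaussian (mu, sigma)
mu_p, sigma_p = 0.0, 1.0    # second Gaussian (mu', sigma')

# Closed-form Gaussian KL, as in the formula above.
kl_closed_form = (np.log(sigma_p / sigma)
                  + (sigma**2 + (mu - mu_p)**2) / (2.0 * sigma_p**2)
                  - 0.5)

# The same value computed by TFP.
kl_tfp = tfd.kl_divergence(tfd.Normal(loc=mu, scale=sigma),
                           tfd.Normal(loc=mu_p, scale=sigma_p))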

Best Answer

Each tfp.layers variational layer computes the KL divergence between its weight posterior and its prior and adds it to model.losses automatically, so kl = sum(model.losses) is precisely the KL term of the ELBO.
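
As a minimal sketch (layer sizes and the dummy batch are made up), you can see those KL terms land on model.losses as soon as the layers are called:

import tensorflow as tf
import tensorflow_probability as tfp

model = tf.keras.Sequential([
    tfp.layers.DenseFlipout(16, activation=tf.nn.relu),
    tfp.layers.DenseFlipout(10),
])

dummy_batch = tf.zeros([4, 8])    # 4 examples, 8 features (arbitrary)
_ = model(dummy_batch)            # calling the layers registers their KL terms

print(model.losses)               # one KL(q || p) tensor per variational layer
kl = tf.reduce_sum(model.losses)  # the `kl = sum(model.losses)` term from the question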

Those layers call tfp.distributions.kl_divergence through their default kernel_divergence_fn, which ends up computing the analytic KL value just as you've written it out.
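
If you want the per-example scaling used in the second snippet (kl / num_examples), you can also bake it into the layer by overriding kernel_divergence_fn; a sketch, with a hypothetical NUM_TRAIN_EXAMPLES:

import tensorflow_probability as tfp

tfd = tfp.distributions
NUM_TRAIN_EXAMPLES = 60000  # hypothetical dataset size

# The default is equivalent to lambda q, p, _: tfd.kl_divergence(q, p);
# dividing by the dataset size makes sum(model.losses) the per-example
# KL penalty of the -ELBO.
scaled_kl = lambda q, p, _: tfd.kl_divergence(q, p) / NUM_TRAIN_EXAMPLES

layer = tfp.layers.DenseFlipout(10, kernel_divergence_fn=scaled_kl)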

As you can see in the documentation, the prior defaults to a standard normal distribution, and the posterior is approximated with a learned mean-field normal distribution.
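
Spelling those defaults out explicitly looks roughly like this (equivalent to leaving the arguments out):

import tensorflow_probability as tfp

layer = tfp.layers.DenseFlipout(
    10,
    # Standard-normal prior over the kernel (the default).
    kernel_prior_fn=tfp.layers.default_multivariate_normal_fn,
    # Learned mean-field normal posterior over the kernel (the default).
    kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
)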