Neural Networks – Understanding Learned Loss Attenuation for Classification

loss-functions, neural-networks, uncertainty

In the paper "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", the authors propose loss functions that capture aleatoric uncertainty. My question relies heavily on an understanding of this paper, so I will go straight to the question without the details. They suggest the following loss objective for regression, where the network outputs two values [ŷ, σ²]:
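$$\mathcal{L}_{\text{NN}}(\theta) = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{2\hat\sigma_i^2}\lVert y_i - \hat y_i \rVert^2 + \frac{1}{2}\log \hat\sigma_i^2$$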

This allows the model to learn to output a large variance in order to attenuate the loss whenever there's aleatoric uncertainty.
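To see the attenuation numerically, here is a small sketch of the per-example term $\frac{(y-\hat y)^2}{2\hat\sigma^2} + \frac12\log\hat\sigma^2$. It assumes the network predicts $s = \log\hat\sigma^2$ instead of $\hat\sigma^2$ directly (a common choice for numerical stability); the function name `attenuated_loss` is made up for illustration.

```python
import math

def attenuated_loss(y, y_hat, log_var):
    """Per-example attenuated regression loss.

    Computes (y - y_hat)^2 / (2 * sigma^2) + 0.5 * log(sigma^2),
    where log_var = log(sigma^2) is what the (hypothetical) network outputs.
    """
    return 0.5 * math.exp(-log_var) * (y - y_hat) ** 2 + 0.5 * log_var

# A point the model cannot fit: residual r = 3.
r = 3.0
loss_unit_var = attenuated_loss(r, 0.0, 0.0)               # sigma^2 = 1   -> 4.5
loss_best_var = attenuated_loss(r, 0.0, math.log(r ** 2))  # sigma^2 = r^2 -> about 1.60
print(loss_unit_var, loss_best_var)
```

Setting the derivative with respect to $\hat\sigma^2$ to zero gives $\hat\sigma^2 = r^2$, so a badly fitted point is down-weighted, at the price of the $\frac12\log\hat\sigma^2$ penalty that stops the variance from growing without bound.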

Now, my question is: can the same principle be used for a classification loss? In the paper they end up defining a more complex loss function, but why not just do the same thing as for regression, replacing the mean squared error with a classification loss like this:

Doesn't the same logic apply?

Best Answer

The motivation for Bayesian methods is that they allow you to characterize the model in probabilistic terms. With a Gaussian likelihood, the model in the paper lets you treat each observation $i$ as its own Gaussian random variable with mean $\hat y_i$ and variance $\hat\sigma_i^2$. You can then answer probabilistic questions about the model's predictions, e.g. "What is the probability that $0 < y_j < 2$?"
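As a sketch of answering that question, assuming the network has produced a mean $\hat y_j$ and variance $\hat\sigma_j^2$ for observation $j$ (the values below are hypothetical), the interval probability is a difference of Gaussian CDFs:

```python
import math

def gaussian_cdf(x, mean, var):
    """CDF of N(mean, var) evaluated at x, via the error function."""
    return 0.5 * (1.0 + math.erf((x - mean) / math.sqrt(2.0 * var)))

def prob_in_interval(a, b, mean, var):
    """P(a < y < b) when y ~ N(mean, var)."""
    return gaussian_cdf(b, mean, var) - gaussian_cdf(a, mean, var)

# Hypothetical predicted mean and variance for observation j.
y_hat_j, var_j = 1.0, 1.0
p = prob_in_interval(0.0, 2.0, y_hat_j, var_j)
print(p)  # about 0.683, i.e. one standard deviation either side of the mean
```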

The first expression you write is, up to an additive constant, the cross-entropy (negative log-likelihood) of a normal distribution whose mean and variance are both estimated from the data.
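A quick numerical check of this claim, assuming the per-example form $\frac{(y-\hat y)^2}{2\hat\sigma^2} + \frac12\log\hat\sigma^2$ of the regression loss: the exact Gaussian negative log-likelihood differs from it only by the constant $\frac12\log 2\pi$, which does not affect training.

```python
import math

def regression_objective(y, y_hat, var):
    # Per-example term of the attenuated regression loss.
    return (y - y_hat) ** 2 / (2.0 * var) + 0.5 * math.log(var)

def gaussian_nll(y, mean, var):
    # -log N(y; mean, var), computed directly from the density.
    density = math.exp(-(y - mean) ** 2 / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)
    return -math.log(density)

y, y_hat, var = 0.7, 0.2, 2.5
gap = gaussian_nll(y, y_hat, var) - regression_objective(y, y_hat, var)
# gap equals 0.5 * log(2 * pi) for any y, y_hat and var > 0
print(gap, 0.5 * math.log(2.0 * math.pi))
```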

The second expression you write does not appear to correspond to a probability distribution. Even if we substitute the standard classification loss $-\sum_k y_{ik} \log p_{ik}$, where $p_{ik} \ge 0$, $\sum_k p_{ik} = 1$, and $y_i$ is a one-hot vector, the loss does not correspond to a probability distribution unless we fix $\hat\sigma_i = 1$ for all $i$.
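One way to see this, assuming the second loss divides the cross-entropy by $\hat\sigma_i^2$: exponentiating the negated data term for each possible label $k$ gives $p_{ik}^{1/\hat\sigma_i^2}$, and these values sum to one over the classes only when $\hat\sigma_i^2 = 1$. The $\log\hat\sigma_i$ regularizer adds the same constant for every class, so it cannot restore normalization for all inputs. The logits below are hypothetical.

```python
import math

def softmax(logits):
    # Numerically stable softmax.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def implied_mass(p, sigma2):
    """Sum over classes k of exp(-CE_k / sigma2), where CE_k = -log p_k is the
    cross-entropy if the true label were k. Equals sum_k p_k ** (1 / sigma2)."""
    return sum(math.exp(math.log(pk) / sigma2) for pk in p)

p = softmax([2.0, 0.5, -1.0])  # hypothetical logits for one example
for sigma2 in (0.5, 1.0, 2.0):
    # Only sigma2 = 1.0 gives a total of exactly 1.
    print(sigma2, implied_mass(p, sigma2))
```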

There is no particular reason that neural networks must be probability models (see e.g. networks). In the context of the paper, however, using the second loss would cost you the probabilistic interpretation of the model's predictions, because that loss does not define a probability model.

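The reason the authors introduce the "more complex loss function" is to add a source of noise that corrupts the model's predictions without changing the probability model for the labels: the noise is added to the logits before the softmax, so every noisy prediction is still a categorical distribution over the classes.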
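A minimal sketch of that idea as I read it: the network outputs logits plus a noise scale per logit, Gaussian noise with that scale is added to the logits, each noisy draw goes through a softmax, and the probabilities are averaged over Monte Carlo samples before taking the negative log. The function name, the number of samples, and the per-logit scale are illustrative assumptions, not the paper's code.

```python
import math
import random

def softmax(logits):
    # Numerically stable softmax.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mc_corrupted_logit_nll(logits, sigmas, label, num_samples=1000, seed=0):
    """Monte Carlo estimate of -log E_eps[softmax(logits + sigmas * eps)[label]],
    with eps ~ N(0, I). Each noisy draw is a proper categorical distribution,
    so the averaged probabilities also sum to one."""
    rng = random.Random(seed)
    avg = [0.0] * len(logits)
    for _ in range(num_samples):
        noisy = [z + s * rng.gauss(0.0, 1.0) for z, s in zip(logits, sigmas)]
        probs = softmax(noisy)
        avg = [a + pk / num_samples for a, pk in zip(avg, probs)]
    return -math.log(avg[label]), avg

# A confident but wrong prediction: the true label is class 1.
logits, label = [4.0, 0.0, 0.0], 1
nll_clean, _ = mc_corrupted_logit_nll(logits, [0.0, 0.0, 0.0], label)
nll_noisy, avg_noisy = mc_corrupted_logit_nll(logits, [3.0, 3.0, 3.0], label)
print(nll_clean, nll_noisy, sum(avg_noisy))
```

The learned noise lowers the loss on an example the network cannot fit, which is the classification analogue of loss attenuation, while the averaged prediction remains a valid distribution over the labels.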
