Solved – What loss function for multi-class, multi-label classification tasks in neural networks

cross-entropy, keras, loss-functions, neural-networks, python

I'm training a neural network to classify a set of objects into $n$ classes. Each object can belong to multiple classes at the same time (multi-class, multi-label).

I read that for multi-class problems it is generally recommended to use softmax with categorical cross-entropy as the loss function rather than MSE, and I understand more or less why.

For my multi-label problem it of course wouldn't make sense to use softmax, since each class probability should be independent of the others. So my final layer is just sigmoid units that squash their inputs into the probability range 0..1, one for every class.

Now I'm not sure what loss function to use for this. Looking at the definition of categorical cross-entropy, I believe it would not apply well here, as it only takes into account the output of the neurons that should be 1 and ignores the others.

Binary cross-entropy sounds like a better fit, but I only ever see it mentioned for binary classification problems with a single output neuron.

I'm using Python and Keras for training, in case it matters.

Best Answer

If you are using Keras, just put sigmoids on your output layer and use binary_crossentropy as your cost function.
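A minimal sketch of what that looks like in Keras; the layer sizes, input_dim, and n_classes here are hypothetical placeholders, not from the question:

```python
from tensorflow import keras

input_dim, n_classes = 100, 10  # hypothetical problem dimensions

model = keras.Sequential([
    keras.layers.Input(shape=(input_dim,)),
    keras.layers.Dense(64, activation="relu"),
    # One sigmoid unit per class -> independent probabilities in 0..1
    keras.layers.Dense(n_classes, activation="sigmoid"),
])

# binary_crossentropy is computed element-wise, i.e. once per output unit,
# then averaged, which matches the multi-label setup described above.
model.compile(optimizer="adam", loss="binary_crossentropy")
```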

If you are using TensorFlow, you can use sigmoid_cross_entropy_with_logits. In my case this built-in loss function was not converging, so I ended up using an explicit sigmoid cross-entropy loss, $-\big(y \cdot \ln(\text{sigmoid}(\text{logits})) + (1-y) \cdot \ln(1-\text{sigmoid}(\text{logits}))\big)$. You can write your own, as in the sketch below.
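For illustration, here is a sketch of that explicit loss in TensorFlow; the clipping epsilon is an assumption added for numerical stability, not part of the original answer:

```python
import tensorflow as tf

def explicit_sigmoid_cross_entropy(labels, logits):
    """Element-wise -(y*ln(p) + (1-y)*ln(1-p)) with p = sigmoid(logits)."""
    eps = 1e-7  # assumed guard against log(0)
    p = tf.clip_by_value(tf.sigmoid(logits), eps, 1.0 - eps)
    return -(labels * tf.math.log(p) + (1.0 - labels) * tf.math.log(1.0 - p))

# Built-in, numerically stable equivalent:
# tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
```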

Sigmoid, unlike softmax, does not give a probability distribution over the $n_{classes}$ outputs; it gives independent per-class probabilities.
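To make the contrast concrete, softmax normalizes across all classes while the sigmoid acts on each logit independently:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}, \qquad \text{sigmoid}(z_i) = \frac{1}{1 + e^{-z_i}}$$

so the softmax outputs sum to 1 across the $n$ classes, whereas each sigmoid output lies in $(0, 1)$ on its own.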

If, on average, each row is assigned only a few labels, then you can use softmax_cross_entropy_with_logits: with this loss the classes are mutually exclusive, but their probabilities need not be. All that is required is that each row of labels is a valid probability distribution; if they are not, the computation of the gradient will be incorrect.
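A small hypothetical example of such soft labels; here a row with two active labels is normalized to weight 1/2 each so it forms a valid distribution:

```python
import tensorflow as tf

# 2 rows, 3 classes; each row of `labels` sums to 1 (hypothetical values)
labels = tf.constant([[0.5, 0.5, 0.0],   # two labels, weighted 1/2 each
                      [0.0, 0.0, 1.0]])  # a single label
logits = tf.constant([[2.0, 1.5, 0.1],
                      [0.3, 0.2, 2.5]])

# Per-row cross-entropy between softmax(logits) and the label distribution
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
```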