Solved – How to train a CNN with multiple heads

classification, conv-neural-network, machine-learning, neural-networks

I'm trying to understand how the multi-digit classification in this paper on the Google Street View data works. The authors try to detect multiple digits within a picture without localization (up to 5 digits, plus the length of the sequence).

They state that they use 6 classifiers on top of the convolutional layers:

Each of the variables above is discrete, and when applied to the street number transcription problem, each has a small number of possible values: L has only 7 values (0, . . . , 5, and “more than 5”), and each of the digit variables has 10 possible values. This means it is feasible to represent each of them with a softmax classifier that receives as input features extracted from X by a convolutional neural network.

and further:

Each of the softmax models (the model for L and each Si) can use exactly the same backprop learning rule as when training an isolated softmax layer, except that a digit classifier softmax model backprops nothing on examples for which that digit is not present.

As far as I understand it, this means they use some convolutional layers for feature detection, and these features are basically the inputs to 6 independent classifiers. These classifiers (e.g. some fully connected layers) are trained, but the output of the conv layers stays fixed (backprop stops at the first FC layer).

But how are the conv layers trained?

Best Answer

This page shows a nice graph: https://github.com/potterhsu/SVHNClassifier, and there's a TensorFlow model here: https://github.com/potterhsu/SVHNClassifier/blob/master/model.py

So basically, as far as the model structure goes:

  • there's just one stack of convolutional layers, which feeds into all of the softmax layers
  • each 'softmax' layer comprises a fully-connected layer (Dense in TensorFlow parlance), followed by a softmax
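As a rough sketch of that structure in TensorFlow/Keras (not the repo's exact architecture; the conv layer sizes here are made up, and the 54x54 input is just the crop size the paper uses):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(54, 54, 3))

# One shared stack of convolutional layers (much shallower than the real model).
x = tf.keras.layers.Conv2D(32, 5, activation="relu", padding="same")(inputs)
x = tf.keras.layers.MaxPooling2D(2)(x)
x = tf.keras.layers.Conv2D(64, 5, activation="relu", padding="same")(x)
x = tf.keras.layers.MaxPooling2D(2)(x)
x = tf.keras.layers.Flatten()(x)

# Six independent Dense + softmax heads on the same features:
# L has 7 classes (0..5 and "more than 5"), each digit slot has 10.
length_head = tf.keras.layers.Dense(7, activation="softmax", name="length")(x)
digit_heads = [
    tf.keras.layers.Dense(10, activation="softmax", name=f"digit_{i + 1}")(x)
    for i in range(5)
]

model = tf.keras.Model(inputs, [length_head] + digit_heads)
```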

As far as backprop goes, it will run down the entire stack, for each digit. But there are 5 available digit outputs, right? E.g., it could output 31256. Say the target number is 432: what should we do with the two additional digit classifiers? The answer is: no backprop happens through those two additional digit classifiers, for which there is no target in this case.

And at prediction time, L for such an example will be 3, so the prediction output from the network will simply ignore the output of the two additional digit classifiers.
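Here's a hedged sketch of that decoding step (plain NumPy, names hypothetical): argmax over the L head, then keep only the first L digit predictions. Note the paper itself maximizes the joint log-probability over all heads; a per-head argmax is just a simple greedy decode.

```python
import numpy as np

def decode(length_probs, digit_probs):
    """Greedy decode: pick L by argmax, keep only the first L digits.

    length_probs: array of shape [7], softmax over L (index 6 = "more than 5").
    digit_probs:  list of 5 arrays of shape [10], one per digit slot.
    """
    L = int(np.argmax(length_probs))
    digits = [int(np.argmax(p)) for p in digit_probs]
    return "".join(str(d) for d in digits[:L])

# e.g. for the house number 432, L's argmax should be 3, and the outputs of
# digit heads 4 and 5 are simply dropped.
```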

But otherwise, backprop is just standard backprop, through all layers.

As for how to backprop through only some of the digit heads in practice, a couple of approaches:

  • get the output from your network, and feed back the exact same numbers as the target for the digit heads that aren't being used: that way there'll be no gradient for those, or
  • use the value of L to mask the loss function, conceptually something like: loss = digit_one_loss * (L >= 1) + digit_two_loss * (L >= 2) + ...

You'll need to figure out a way to turn (L >= 1), (L >= 2) and so on into numbers with a value of 0 or 1.
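A hedged sketch of that second approach, assuming the Keras-style softmax heads from the earlier snippet (the function and argument names here are hypothetical): cast the boolean comparison (L >= i + 1) to a 0/1 float and multiply it into each digit head's per-example loss, so masked slots contribute zero loss and therefore zero gradient.

```python
import tensorflow as tf

def masked_digit_loss(digit_targets, lengths, digit_probs):
    """Cross-entropy over the five digit heads, with slots beyond the true
    length L masked out so they receive no gradient.

    digit_targets: int32 tensor [batch, 5]; unused slots must hold some valid
                   class index (e.g. 0) -- the mask zeroes their loss anyway.
    lengths:       int32 tensor [batch], the true sequence length L.
    digit_probs:   list of 5 tensors [batch, 10] of softmax probabilities.
    """
    total = 0.0
    for i, probs in enumerate(digit_probs):
        per_example = tf.keras.losses.sparse_categorical_crossentropy(
            digit_targets[:, i], probs
        )
        # (L >= i + 1) as a 0/1 float mask, as described above.
        mask = tf.cast(lengths >= i + 1, tf.float32)
        total += tf.reduce_mean(per_example * mask)
    return total
```

Because the mask zeroes the loss for unused slots, the gradient flowing back into the shared conv stack from those heads is exactly zero. That also answers the original question: the conv layers are trained by ordinary backprop, receiving the summed gradients of whichever heads are active for each example.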
