Solved – What stops the network from learning the same weights in the multi-head attention mechanism

attention · deep learning · neural networks

I have been trying to understand the transformer network, and specifically the multi-head attention part. As I understand it, multiple attention-weighted linear combinations of the input features are calculated.

My question is: what stops the network from learning the same weights or linear combination for each of these heads, i.e. basically making the multiple heads redundant? Can that happen? I am guessing it has to happen, for example, in the trivial case where the translation depends only on the word in the current position.

I also wonder whether we actually use the full input vector for each of the heads. So, imagine my input vector is of length 256 and I am using 8 heads. Would I divide my input into $256 / 8 = 32$-length vectors, perform attention on each of these, and concatenate the results, or do I use the full vector for each head and then combine the results?

Best Answer

We observe this kind of redundancy in virtually all neural network architectures, starting from simple fully-connected networks (see the diagram below), where the same inputs are fed to multiple hidden units. Nothing prohibits the network from ending up with the same weights here either.

[Diagram: a simple fully-connected network, with each input connected to every hidden unit.]

We fight this with random initialization of the weights. You usually need to initialize all the weights randomly, except in some special cases where initializing with zeros or other fixed values has proved to work better. The optimization algorithms are deterministic, so if two heads start from exactly the same initial conditions and see the same inputs, there is no reason whatsoever for them to end up with different weights; random initialization is what breaks this symmetry.
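You can see the effect in a toy experiment. The sketch below (my own hypothetical `TwoHeads` module, not code from the paper or the posts linked above) sums the outputs of two parallel linear "heads": if both heads start with identical weights they receive identical gradients and never diverge, while the default random initialization lets them learn different things.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TwoHeads(nn.Module):
    """Two parallel linear 'heads' whose outputs are summed."""
    def __init__(self, dim, identical_init):
        super().__init__()
        self.head1 = nn.Linear(dim, dim, bias=False)
        self.head2 = nn.Linear(dim, dim, bias=False)
        if identical_init:
            # Copy head1's weights into head2 -> perfectly symmetric start.
            with torch.no_grad():
                self.head2.weight.copy_(self.head1.weight)

    def forward(self, x):
        return self.head1(x) + self.head2(x)

def train(model, steps=100):
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(64, 16)
    y = torch.randn(64, 16)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        opt.step()
    # Largest difference between the two heads' weight matrices after training.
    return (model.head1.weight - model.head2.weight).abs().max().item()

print(train(TwoHeads(16, identical_init=True)))   # ~0.0: the heads never diverge
print(train(TwoHeads(16, identical_init=False)))  # clearly > 0: symmetry broken by random init
```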

The same seems to be true for the original attention paper, but to convince yourself you can also check this great "annotated" paper with PyTorch code (or a Keras implementation if you prefer) and this blog post. Unless I missed something in the paper and the implementations, the weights are treated the same way in each case, so there are no extra measures to prevent redundancy. In fact, if you look at the code in the "annotated Transformer" post, in the MultiHeadedAttention class you can see that all the weights in the multi-head attention layer are created with the same kind of nn.Linear layers.
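This also answers your second question. A minimal sketch in the spirit of the annotated Transformer (not a verbatim copy of its MultiHeadedAttention class): each nn.Linear projection sees the full 256-dimensional input, but the projected result is split into $h = 8$ heads of size $d_k = 256 / 8 = 32$, attention runs independently per head, and the heads are concatenated and mixed by a final linear layer. The only thing that distinguishes one head from another is where its randomly initialized weights happen to start.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=256, h=8):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        # Same kind of nn.Linear everywhere; only the random init differs.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        # Each projection sees the *full* d_model-dim input ...
        q = self.q_proj(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.h, self.d_k).transpose(1, 2)
        # ... but attention is computed independently per d_k-dim head.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        out = scores.softmax(dim=-1) @ v                      # (B, h, T, d_k)
        out = out.transpose(1, 2).reshape(B, T, self.h * self.d_k)
        return self.out_proj(out)                             # concat + mix

x = torch.randn(2, 10, 256)           # batch of 2, sequence length 10
print(MultiHeadAttention()(x).shape)  # torch.Size([2, 10, 256])
```

So, in this formulation, you do not slice the raw 256-dim input into 8 pieces before projecting; every head's projection takes the whole vector, but each head only attends over its own 32-dimensional slice of the projected result.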
