Transformers – Why Residual Connections are Needed in Neural Networks

attention, neural-networks, residual-networks, transformers

Residual connections are often motivated by the fact that very deep neural networks tend to "forget" some features of their input samples during training.

This problem is circumvented by adding the input $\mathbf{x}$ to the result of a typical feed-forward computation in the following way:

$$ \mathcal F(\mathbf{x}) + \mathbf{x} = \left[ W_2\, \sigma( W_1 \mathbf{x} + b_1 ) + b_2 \right] + \mathbf{x}.$$

This was schematically represented in [1] as:

[Figure: schematic of a residual block from [1]]
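For concreteness, here is a minimal NumPy sketch of the residual block above; the function name `residual_block`, the random weights $W_1, W_2$ and biases $b_1, b_2$, and the choice of ReLU for $\sigma$ are illustrative placeholders, not taken from [1].

```python
# A minimal NumPy sketch of the residual block F(x) + x written above.
# The weights W1, W2 and biases b1, b2 are arbitrary placeholders,
# and sigma is taken to be ReLU.
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_block(x, W1, b1, W2, b2):
    # F(x) = W2 * sigma(W1 x + b1) + b2, then the shortcut adds x back.
    return W2 @ relu(W1 @ x + b1) + b2 + x

dim = 4
rng = np.random.default_rng(0)
x = rng.standard_normal(dim)
W1, W2 = rng.standard_normal((dim, dim)), rng.standard_normal((dim, dim))
b1, b2 = rng.standard_normal(dim), rng.standard_normal(dim)
print(residual_block(x, W1, b1, W2, b2))
```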

On the other hand, it is also well known that transformer architectures contain residual connections, as the following picture shows:

[Figure: Transformer architecture with residual connections]

Question: Residual connections are motivated in the context of very deep network architectures, but attention blocks perform very little computation compared to the deep networks outperformed in [1]; so what motivates the presence of shortcut connections in the attention blocks of transformer architectures?

Best Answer

The reason for having residual connections in the Transformer is more technical than a matter of architecture design.

Residual connections mainly help mitigate the vanishing gradient problem. During back-propagation, the signal gets multiplied by the derivative of the activation function. In the case of ReLU, this means that in approximately half of the cases the gradient is zero. Without the residual connections, a large part of the training signal would get lost during back-propagation. Residual connections reduce this effect because the derivative of a sum is the sum of the derivatives, so each residual block also receives a gradient signal that is not affected by the vanishing gradient. The summation operations of the residual connections form a path in the computation graph along which the gradient does not get lost.
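To illustrate this, here is a small PyTorch sketch (assuming PyTorch is available; the `Block` module, the depth, and the width are arbitrary choices for the experiment, not part of the Transformer) that compares the gradient norm reaching the input of a deep stack of ReLU blocks with and without skip connections:

```python
# Compare the gradient norm that reaches the input of a deep stack of
# ReLU blocks with and without skip connections.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, residual):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.residual = residual

    def forward(self, x):
        out = self.ff(x)
        return out + x if self.residual else out

def input_grad_norm(residual, depth=32, dim=64):
    torch.manual_seed(0)
    blocks = nn.Sequential(*[Block(dim, residual) for _ in range(depth)])
    x = torch.randn(1, dim, requires_grad=True)
    blocks(x).sum().backward()
    return x.grad.norm().item()

# Typically the gradient reaching the input is orders of magnitude larger
# when the skip connections are present.
print("with residuals:   ", input_grad_norm(residual=True))
print("without residuals:", input_grad_norm(residual=False))
```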

Another effect of residual connections is that the information stays local in the Transformer layer stack. The self-attention mechanism allows arbitrary information flow between positions and can thus arbitrarily permute the input tokens. The residual connections, however, always "remind" the representation of what the original state was. To some extent, the residual connections guarantee that the contextual representations of the input tokens really represent those tokens.
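A hedged sketch of this "reminder" effect, again assuming PyTorch (the untrained `nn.MultiheadAttention` layer and the random inputs are only for illustration): with the skip connection, every position keeps an additive copy of its own embedding, however the attention weights mix the tokens.

```python
# With a residual connection, each position's representation keeps a copy of
# its own embedding, even though attention mixes information across tokens.
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, dim = 4, 8
x = torch.randn(1, seq_len, dim)  # one sequence of 4 token embeddings

attn = nn.MultiheadAttention(embed_dim=dim, num_heads=1, batch_first=True)
attn_out, weights = attn(x, x, x)

# Without the skip, position i is only a weighted mixture of the value vectors,
# determined by the attention weights. With the skip, its own embedding x[:, i]
# is always added back, anchoring the representation to the original token.
with_skip = attn_out + x
print("attention weights (each row mixes all positions):\n", weights[0])
print("cosine(without skip, x):", torch.cosine_similarity(attn_out, x, dim=-1))
print("cosine(with skip, x):   ", torch.cosine_similarity(with_skip, x, dim=-1))
```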
