Training Transformers: self-attention weights vs embedding layer

attention, machine learning, natural language, transformers, word embeddings

I have been trying to wrap my head around transformers. While I have found many good resources that explain the self-attention mechanism, I've yet to find a good answer on how it really works with respect to training.

With respect to the embedding layer, my understanding is that an input of words or pixels is first tokenized and then projected using a learned linear transformation from the token space to a new embedding space (commonly, position is also projected). From what I gather, there are a few ways to actually train the embedding, such as masking out words or image patches and trying to fill in the blank. My intuition is that it is this training that leads to an embedding which projects the token for queen to a vector that is more "similar" to the vector for king than to the vector for dog.

I believe I understand the mechanism of self-attention:

$Q = XW_Q$

$K = XW_K$

$V = XW_V$

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$

The product of $Q$ and $K^\top$ is a vectorized way of computing the dot products, which measure the similarity between two vectors. I think of this as a distillation step where all information is weighted by relevance, and the final step is to apply those weights to the embeddings, i.e. to $V$.
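To check that I've got the mechanics right, here is a minimal NumPy sketch of the computation above (the dimensions and random weights are made up purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 4, 8, 8                 # 4 tokens, toy dimensions
X = rng.normal(size=(n, d_model))         # one embedding per row

W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V

scores = Q @ K.T / np.sqrt(d_k)                        # pairwise scaled dot products
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)         # row-wise softmax
output = weights @ V                                   # weighted sum of the value vectors
print(output.shape)                                    # (4, 8)
```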

Where I get confused is that the weights $W_Q$ and $W_K$ also seem to serve a similar purpose to the embedding. In fact, my first question is: are $W_Q$ and $W_K$ the same weights as the initial embedding typically found at the beginning of transformers? It seems like $W_V$ must just be the weights from the layer right before SA, so $V$ is just the activations? If not, why does training SA or MHSA result in weights that encode words which often appear together similarly? I guess I just don't see what guarantees this property. Is it because pre-training tasks for language and image transformers back prop through all layers?

Am I wrong to assume that $W_K$ and $W_V$ must be trained in the same manner as the embedding in order to learn transformations that would actually make the projection of vector king onto vector queen larger than the projection of vector king onto vector dog?

If anyone could give some intuition on how training the weights for SA and MHSA actually works it would be greatly appreciated.

Best Answer

I'm more familiar with NLP, so let me explain in that context.

With respect to the embedding layer, my understanding is that an input of words or pixels is first tokenized and then projected using a learned linear transformation from the token space to a new embedding space

This is correct. But to ensure that we’re on the same page, I’ll give an example. Given an input comprised of tokens $w_1,w_2,w_3$, the embedding layer maps each token into a vector representation, i.e. an embedding. Let $x_1,x_2,x_3$ denote the embeddings of $w_1,w_2,w_3$ respectively. This mapping process can be achieved by applying a linear transformation to one-hot encodings of the tokens. The parameter of this transformation is the embedding matrix whose values are initialised randomly and learned during training.
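To make this concrete, here is a minimal sketch (with an arbitrary vocabulary size and embedding dimension chosen just for illustration) showing that applying the embedding matrix to a one-hot encoding is nothing more than a row lookup:

```python
import numpy as np

vocab_size, d_model = 10, 4
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, d_model))   # embedding matrix, learned during training

token_id = 7                                 # hypothetical id of token w_2 in the vocabulary
one_hot = np.zeros(vocab_size)
one_hot[token_id] = 1.0

x_2 = one_hot @ E                            # linear transformation of the one-hot encoding...
assert np.allclose(x_2, E[token_id])         # ...is exactly row token_id of E
```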

are $W_Q,W_K$ the same weights as the initial embedding typically found at the beginning of transformers?

No, they aren’t embedding matrices, and neither is $W_V$. They’re 3 distinct matrices corresponding to 3 separate linear transformations. If the input embeddings are organised such that $x_i^\top$ is the $i$-th row of the matrix $X$, then

$$ \begin{align} Q &= XW_Q,\\ K &= XW_K,\\ V &= XW_V \end{align} $$

are 3 linear transformations applied to $X$ to get the query, the key, and the value vectors organised in matrices $Q,K,V$ respectively. There is no non-linear activation function here. Matrices $W_Q,W_K,W_V$ are also initialised randomly and learned during training. Then, the attention matrix $A$ is defined as

$$ A=\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) $$

where $d_k$ is the dimension of the query and the key vectors. This definition states that the attention $A_{ij}$ between $w_i$ and $w_j$ is proportional to the exponential of the scaled dot product between the query (row) vector $q_i$ and the key (row) vector $k_j$. So, the dot product is not merely an intuition but rather the definition of attention.

The weight matrices $W_Q,W_K,W_V$ function as linear transformations applied to $X$. Without these transformations, $Q=K=V=X$, which is a special case achieved when $W_Q=W_K=W_V=I$. In other words, the linear transformations intuitively make SA more expressive. I'm not aware of papers that study them systematically.
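As a quick illustration of that special case (reusing made-up toy dimensions), setting the three matrices to the identity leaves the embeddings untouched:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8
X = rng.normal(size=(4, d_model))     # 4 token embeddings

I = np.eye(d_model)                   # identity "projections", i.e. no learned transformation
Q, K, V = X @ I, X @ I, X @ I

assert np.allclose(Q, X) and np.allclose(K, X) and np.allclose(V, X)
# Attention would then just compare the raw embeddings with each other.
```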

why does training SA or MHSA result in weights that encode words which often appear together similarly?

The premise of this question is debatable unless the training objective explicitly encourages such similarity. Merely employing SA or MHSA doesn't automatically produce similar representations for collocated words. Furthermore, by the time the Transformer paper introducing MHSA came out, NLP had largely moved on from representing a word as a single vector (so-called static word embeddings, like word2vec) to contextual word embeddings, which are produced by a sequential encoder such as an LSTM. The contextual embedding of a word is a function of the whole input, so the word analogy task commonly used to evaluate static word embeddings (which seems to be what you have in mind) isn't sensible for, and is no longer used to evaluate, contextual embeddings.

I guess I just don’t see what guarantees this property. Is it because pre-training tasks for language and image transformers back prop through all layers?

As I said above, there’s no guarantee that collocated tokens would have similar representations by the mere use of SA or MHSA.

Am I wrong to assume that $W_K$ and $W_V$ must be trained in the same manner as the embedding in order to learn transformations that would actually make the projection of vector king onto vector queen larger than the projection of vector king onto vector dog?

Again, here you seem to conflate static and contextual word embeddings. Nowadays we no longer care that much about the static embeddings produced by the embedding layer, i.e. $x_1,x_2,x_3$ in the example; thus, word analogy evaluation à la word2vec is no longer used. We care more about the contextual embeddings output by the sequential encoder, e.g. the Transformer encoder (which uses MHSA) in the BERT language model. What these contextual embeddings encode is an active research area.
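To make the static vs. contextual distinction concrete, here is a sketch using the Hugging Face transformers library (assuming `bert-base-uncased` can be downloaded; the sentences are made up): the same surface word receives a different vector depending on its context, which a static embedding cannot do.

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

def contextual_embedding(sentence: str, word: str) -> torch.Tensor:
    """Return BERT's contextual embedding of `word` within `sentence`."""
    enc = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**enc).last_hidden_state[0]               # (seq_len, 768)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())
    return hidden[tokens.index(word)]

v1 = contextual_embedding("the river bank was muddy", "bank")
v2 = contextual_embedding("she deposited cash at the bank", "bank")
print(torch.cosine_similarity(v1, v2, dim=0))   # < 1: the two "bank" vectors differ with context
```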

If anyone could give some intuition on how training the weights for SA and MHSA actually works it would be greatly appreciated.

The weight matrices $W_Q,W_K,W_V$ of the SA layers are trained via backpropagation, just like other network parameters including the embedding matrix. There's no other technique. It's all just backpropagation.
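As a minimal, self-contained sketch (toy dimensions, random data, and a made-up reconstruction objective, purely to show the gradient flow), here is a PyTorch example in which the embedding matrix and the $W_Q,W_K,W_V$ of a single SA layer are all updated by the same backpropagation step:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySelfAttentionLM(nn.Module):
    def __init__(self, vocab_size=100, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)     # the embedding matrix
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, vocab_size)          # project back to the vocabulary

    def forward(self, token_ids):
        X = self.embed(token_ids)                          # (batch, seq, d_model)
        Q, K, V = self.W_Q(X), self.W_K(X), self.W_V(X)
        A = torch.softmax(Q @ K.transpose(-2, -1) / X.size(-1) ** 0.5, dim=-1)
        return self.out(A @ V)                             # (batch, seq, vocab)

model = TinySelfAttentionLM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

tokens = torch.randint(0, 100, (2, 5))                     # two random toy "sentences"
logits = model(tokens)
loss = F.cross_entropy(logits.view(-1, 100), tokens.view(-1))   # toy objective: reconstruct the input
loss.backward()      # gradients reach W_Q, W_K, W_V *and* the embedding matrix
optimizer.step()
```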