How Padding Masking is Considered in Transformer Attention Head


For purely educational purposes, my goal is to implement the basic Transformer architecture from scratch. So far I have focused on the encoder for classification tasks and assumed that all samples in a batch have the same length, so I didn't need any masking.

However, now I want to support masking. I like to think that I understand the purpose of, e.g., the target mask, which prevents the decoder from "peeking into the future". I generate this mask as follows:

import torch

source_batch = torch.tensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0]
], dtype=torch.long)

batch_size, seq_len = source_batch.shape

def generate_tgt_mask(size):
    # -inf strictly above the main diagonal (future positions), 0 on and below it
    return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

print(generate_tgt_mask(seq_len))

yielding:

tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

which matches the expected outcome according to the PyTorch docs. The mask has shape (L, L), where L is the length of the source or target sequence; again, this matches the docs.
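Since the padding mask further down is boolean, it may help to see that the same causal mask can be derived from a boolean tensor. A minimal sketch, assuming the convention that True marks a position that must not be attended to; `causal_bool_mask` and `to_additive` are hypothetical helpers, not PyTorch API:

```python
import torch

def causal_bool_mask(size):
    # True strictly above the diagonal: query i must not see key j > i
    return torch.triu(torch.ones(size, size), diagonal=1).bool()

def to_additive(bool_mask):
    # Turn "True = blocked" into 0 / -inf so it can be added to the scores
    return torch.zeros(bool_mask.shape).masked_fill(bool_mask, float('-inf'))

additive = to_additive(causal_bool_mask(6))
reference = torch.triu(torch.full((6, 6), float('-inf')), diagonal=1)
print(torch.equal(additive, reference))  # True
```

Keeping masks boolean until the last moment makes it easy to combine several of them with `|` before converting once.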

I use this mask in my implementation of scaled dot-product attention as follows, which should be in line with many other implementations I've seen:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Implements scaled dot-product attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)

        # Perform Q*K^T (* is the dot product here)
        # We have to use torch.matmul since we work with batches!
        out = torch.matmul(Q, K.transpose(1, 2))  # => shape: (B, L, L)

        # Divide by the scaling factor sqrt(d_k)
        out = out / (Q.shape[-1] ** 0.5)

        # Optional: src_mask/tgt_mask (shape: (L, L); masked positions are -inf)
        if mask is not None:
            out += mask.unsqueeze(0)  # Broadcast: same mask for all samples in the batch

        # Push through the softmax over the key dimension
        out = F.softmax(out, dim=-1)

        # Optional: dropout on the attention weights (only active in training mode)
        if dropout is not None:
            out = F.dropout(out, p=dropout, training=self.training)

        # Multiply with values V
        out = torch.matmul(out, V)

        return out
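One way to sanity-check that the causal mask really prevents peeking into the future is to change only the last token and confirm that the outputs at all earlier positions stay the same. A minimal standalone sketch, assuming self-attention (Q = K = V = x) on random inputs; `masked_attention` is a hypothetical helper that mirrors the forward pass above without dropout:

```python
import torch
import torch.nn.functional as F

def masked_attention(Q, K, V, mask=None):
    # Same steps as Attention.forward above, minus dropout
    scores = Q @ K.transpose(1, 2) / Q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores + mask  # (L, L) broadcasts against (B, L, L)
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
B, L, H = 2, 6, 8
x = torch.randn(B, L, H)
causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)
out1 = masked_attention(x, x, x, causal)

# Replace only the last token; earlier positions must not notice
x2 = x.clone()
x2[:, -1] = torch.randn(B, H)
out2 = masked_attention(x2, x2, x2, causal)

print(torch.allclose(out1[:, :-1], out2[:, :-1]))  # True
```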

So far so good, or at least I like to think so. My problem now is the mask that handles padding (e.g., src_key_padding_mask). According to several tutorials that use nn.Transformer, this mask can be generated as follows:

pad_token_index = 0

src_key_padding_mask = (source_batch != pad_token_index)

print(src_key_padding_mask)

yielding:

tensor([[ True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False]])

It has shape (N, L), which again matches the docs.
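For comparison, a common way to use an (N, L) padding mask in a from-scratch attention is to mask only the key (column) dimension, broadcasting it as (N, 1, L) against the (N, L, L) scores; this is how the `key_padding_mask` of `nn.MultiheadAttention` acts. Note that for boolean masks PyTorch's `key_padding_mask`/`src_key_padding_mask` use True for positions to *ignore*, i.e. `source_batch == pad_token_index`, which is the inverse of the tutorial mask above. A minimal sketch, where `padding_to_additive` is a hypothetical helper and the scores are random stand-ins:

```python
import torch

pad_token_index = 0
source_batch = torch.tensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0],
])

def padding_to_additive(tokens, pad_idx):
    # True where the token is padding (PyTorch's "ignore" convention)
    is_pad = tokens == pad_idx                                   # (N, L)
    mask = torch.zeros(is_pad.shape).masked_fill(is_pad, float('-inf'))
    return mask.unsqueeze(1)                                     # (N, 1, L): keys only

pad_mask = padding_to_additive(source_batch, pad_token_index)
scores = torch.randn(3, 6, 6)            # stand-in for Q K^T / sqrt(d)
weights = torch.softmax(scores + pad_mask, dim=-1)

print(weights[0])  # columns 3-5 are 0 in every row, and no row is NaN
```

Because every query row still has at least one valid key, no row is fully masked. In the Attention class above, `mask.unsqueeze(0)` assumes an (L, L) mask, so a mask that already carries the batch dimension like this one would be added without that extra `unsqueeze(0)`; a causal (L, L) mask can be added on top, since (L, L) and (N, 1, L) broadcast to (N, L, L).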

What I'm missing now is: how do I incorporate this matrix into my Attention implementation?

Intuitively, I would assume that the mask contains -inf for each position associated with a padding token. For example, for the first sequence in my example batch above, I would expect the mask to look like this:

tensor([[0.,   0.,   0.,   -inf, -inf, -inf],
        [0.,   0.,   0.,   -inf, -inf, -inf],
        [0.,   0.,   0.,   -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf]])
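The per-sample matrix above, which masks both padded queries (rows) and padded keys (columns), can be built from the tutorial's True-for-real-token mask with a broadcasted logical AND. A minimal sketch, assuming padding index 0 as in the example batch; note that the resulting mask contains rows that are entirely -inf:

```python
import torch

source_batch = torch.tensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0],
])
valid = source_batch != 0                      # (N, L), True = real token

# (N, L, 1) & (N, 1, L) -> (N, L, L): True only where query AND key are real
pair_valid = valid.unsqueeze(2) & valid.unsqueeze(1)
mask = torch.zeros(pair_valid.shape).masked_fill(~pair_valid, float('-inf'))

print(mask[0])  # same layout as the 6x6 matrix above
```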

And indeed, some (but not all) from-scratch implementations of the Transformer architecture create the padding mask like this. Applying this matrix to the scores should mask the last 3 rows entirely, i.e., I would expect them to be all 0.

However, once pushed through the softmax, the last 3 rows all contain the value 1/6. For example, for the source_batch above I get:

tensor([[[0.1989, 0.4297, 0.3714, 0.0000, 0.0000, 0.0000],
         [0.4334, 0.2225, 0.3440, 0.0000, 0.0000, 0.0000],
         [0.2880, 0.2284, 0.4836, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       ...
       (the other 2 samples of the batch are not shown)

What am I missing here? I'm pretty sure it's something trivial, but I just can't see it right now.
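The uniform 1/6 rows depend on how the mask is applied. If a row is filled entirely with -inf, softmax computes 0/0 and returns NaN; if the masked scores are instead overwritten with a large finite negative number (a common trick, e.g. `masked_fill(mask, -1e9)`), every entry of a fully masked row is equal and softmax returns 1/L, here 1/6. A small sketch of both cases for one row of length 6, with random scores as stand-ins:

```python
import torch

row = torch.randn(6)

# Every score is -inf: softmax yields NaN
all_inf = torch.full((6,), float('-inf'))
print(torch.softmax(row + all_inf, dim=-1))

# Every score overwritten with -1e9: softmax yields 1/6 everywhere
blocked = torch.ones(6, dtype=torch.bool)
print(torch.softmax(row.masked_fill(blocked, -1e9), dim=-1))
```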

Best Answer

Your implementation is correct. It doesn't matter that the rows corresponding to padding tokens receive uniform attention, because the next module that uses the attention's output (the variable out in your code) should ignore these padding positions anyway. For example, if the next module is a linear layer followed by a cross-entropy loss (common in sequence tagging), you have to mask the padding positions when you compute the mean loss. The loss then excludes these invalid positions, so their attention values don't matter.
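As a sketch of that last step, assuming a sequence-tagging setup with per-token logits of shape (N, L, C) and random stand-in logits and labels (both hypothetical here), `F.cross_entropy` with `ignore_index` leaves the padding positions out of the mean, so overwriting their logits with arbitrary values does not change the loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, L, C = 3, 6, 4
pad_token_index = 0
source_batch = torch.tensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0],
])
is_pad = source_batch == pad_token_index

logits = torch.randn(N, L, C)              # stand-in for the tagging head output
labels = torch.randint(0, C, (N, L))       # stand-in gold tags
labels[is_pad] = -100                      # -100 is cross_entropy's default ignore_index

def tagging_loss(logits, labels):
    # cross_entropy expects (N, C, L) for per-token inputs; -100 targets are skipped
    return F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=-100)

loss = tagging_loss(logits, labels)

# Garbage at padding positions leaves the loss unchanged
noisy = logits.clone()
noisy[is_pad] = 1e3
print(torch.allclose(loss, tagging_loss(noisy, labels)))  # True
```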
