Attention as Soft Dictionary Lookup

The Dictionary Metaphor 🔑📚

By now, the scaled dot-product attention formula may well be burned into our brains 🧠, $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$, yet this brilliant metaphor from Kevin Murphy's PML book gives it a refreshed interpretation:

"We can think of attention as a soft dictionary look up, in which we compare the query $q$ to each key $k_i$, and then retrieve the corresponding value $v_i$." (Chapter 15, p. 513)

In a regular dictionary, we look up a value by its key and only grab the value whose key exactly matches the query (dict[query]). The attention mechanism, by contrast, grabs the values of all keys and returns their weighted sum for a given query.
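
To make the contrast concrete, here is a minimal sketch of a hard dict lookup versus the soft, weighted lookup that attention performs. The vectors and values are made up purely for illustration.

    import torch
    import torch.nn.functional as F

    # hard lookup: the query must match exactly one key, and we get back one value
    hard_dict = {"cat": 1.0, "dog": 2.0}
    hard_dict["cat"]                           # 1.0

    # soft lookup: compare the query to every key and blend all values
    query = torch.tensor([1.0, 0.0])           # toy query vector
    keys = torch.tensor([[0.9, 0.1],           # key associated with value 1.0
                         [0.1, 0.9]])          # key associated with value 2.0
    values = torch.tensor([1.0, 2.0])

    weights = F.softmax(keys @ query, dim=-1)  # similarities -> attention weights
    soft_value = weights @ values              # weighted sum of all values (~1.31)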

Image Source: Probabilistic Machine Learning: An Introduction, Chapter 15

The more compatible a key is with the query, the higher the "attention weight" we assign to that key's value. The $i$th key's attention weight is $\frac{\exp(\alpha(q, k_i))}{\sum_{j=1}^{m}\exp(\alpha(q, k_j))}$, so the attention weights across all keys sum to 1. The quantity $\alpha(q, k_i)$ inside the exponential is the "attention score". For masked tokens, we can set the attention scores to a large negative number (e.g., $-10^6$) so that their attention weights come out close to 0.

The most popular attention score function is the dot product between $q$ and $k_i$, scaled by the square root of the key dimension: $\alpha(q, k_i) = \frac{q^\top k_i}{\sqrt{d_k}}$. This scaled dot-product attention is used by the original transformer (Vaswani et al., 2017) and many of its descendants, but you can also replace $\alpha(q, k_i)$ with other similarity functions.
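
As a quick numerical illustration of the two paragraphs above (a toy sketch with random tensors, separate from the implementation we build later), the snippet below computes scaled dot-product scores for one query against $m$ keys, masks one key with a large negative score, and normalizes the scores into weights that sum to 1:

    import torch
    import torch.nn.functional as F
    from math import sqrt

    d_k, m = 4, 3                        # toy key dimension and number of keys
    q = torch.randn(d_k)                 # a single query
    K = torch.randn(m, d_k)              # m keys
    V = torch.randn(m, d_k)              # m values

    scores = (K @ q) / sqrt(d_k)         # scaled dot-product attention scores
    scores[2] = -1e6                     # mask the 3rd key (e.g., a padding token)
    weights = F.softmax(scores, dim=-1)  # attention weights; the masked one is ~0
    output = weights @ V                 # soft lookup: weighted sum of the values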

Hugging Face Implementation 🦜🤗

I hear that candidates in more and more MLE interviews are asked to code components of the transformer architecture from scratch. I find the implementation in the Hugging Face book the easiest to follow and perhaps closest to the day-to-day coding style of NLP practitioners. Let's begin by reviewing key concepts such as self-attention and multi-headed attention, and then code up a basic TransformerForSequenceClassification model built on the transformer encoder.

Self-Attention

The transformer encoder uses self-attention between each input token and all other input tokens to create contextual token embeddings. Before encoding, the homonym "flies" has the same initial embedding in "time flies like an arrow" and "fruit flies like a banana". In the first sentence, however, the token "flies" attends most strongly to "time" and "arrow", so its contextual embedding moves close to theirs, whereas in the second sentence, "flies" attends most strongly to "fruit" and "banana", so its embedding moves close to those instead.

After encoding, each token attains a new embedding informed by the tokens it attends to; as the developmental psychologist Lev Vygotsky put it, "Through others we become ourselves."

The bertviz package can visualize the attention weights between tokens.
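
If you want to reproduce such a visualization yourself, a possible sketch (assuming bertviz's head_view API and loading BERT with output_attentions=True) looks like this:

    from bertviz import head_view
    from transformers import AutoModel, AutoTokenizer

    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    model = AutoModel.from_pretrained(model_ckpt, output_attentions=True)

    inputs = tokenizer("time flies like an arrow", return_tensors="pt")
    outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

    # renders an interactive attention-head view (in a Jupyter notebook)
    head_view(outputs.attentions, tokens)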

Multi-Headed Attention

The original transformer paper pioneered multi-head attention (MHA), where each "head" can capture a different notion of similarity (e.g., semantic, syntactic). The query $Q$, key $K$, and value $V$ matrices are split along the embedding dimension, $d_{model}$, and each slice of size $d_{model} / h$ is fed to its own head. The outputs of all heads are concatenated and passed through a final linear layer $W^O$: $\text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)W^O$.

"By dividing the hidden_dim, each head indeed sees the entire sequence (seq_len) but only a 'slice' or portion of each token's embedding dimension (head_dim). This design enables the model to parallelly attend to information from different representation subspaces at different positions, enriching the model's ability to capture diverse relationships within the data." (Natural Language Processing with Transformers)

To turn the input text into token IDs, we can use the tokenizer of a pre-trained model, say bert-base-uncased, and then look up a dense embedding for each ID in an embedding layer sized by the model's config.

    from math import sqrt
    import torch
    from torch import nn
    import torch.nn.functional as F
    from transformers import AutoTokenizer, AutoModel, AutoConfig

    # load tokenizer from model checkpoint
    model_ckpt = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

    # load config associated with given model
    config = AutoConfig.from_pretrained(model_ckpt)

    # input text
    text = "time flies like an arrow"

    # tokenize input text
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

    # nn.Embedding is a lookup table to find embeddings of each input_id
    token_emb = nn.Embedding(config.vocab_size, config.hidden_size)

    # look up embeddings by id
    input_embs = token_emb(inputs.input_ids)

For simplicity, we initialize $Q$ (query), $K$ (key), and $V$ (value) with the same input embeddings (input_embs) and skip the linear projections here; in practice, each of them gets its own linear projection (as in the AttentionHead class further below). The scaled_dot_product_attention function below takes query, key, and value as inputs and returns the output embeddings of the tokens.

    # for simplicity, initialize Q, K, V with the same input embeddings
    query = key = value = input_embs

    def scaled_dot_product_attention(query, key, value):
        # dimension of the key vectors
        dim_k = key.size(-1)  # last dim

        # attention scores: scaled dot products between queries and keys
        scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)

        # attention weights: softmax over the scores so each row sums to 1
        weights = F.softmax(scores, dim=-1)

        # output embeddings: weighted sum of the values
        return torch.bmm(weights, value)

The code below implements MHA with 12 heads. We can instantiate a MultiHeadAttention object with the pre-trained model config and call it on input_embs to get the attention outputs that represent the input sequence.

Note that we don't need to call a dedicated method such as encode (or whatever we might name it) to get the output: calling a model object on some input data (hidden_state) invokes its forward method and returns the output x.

    class AttentionHead(nn.Module):
        def __init__(self, embed_dim, head_dim):
            super().__init__()
            # init 3 independent linear layers
            self.q = nn.Linear(embed_dim, head_dim)
            self.k = nn.Linear(embed_dim, head_dim)
            self.v = nn.Linear(embed_dim, head_dim)

        def forward(self, hidden_state):
            attn_outputs = scaled_dot_product_attention(
                self.q(hidden_state), self.k(hidden_state), self.v(hidden_state)
            )
            return attn_outputs


    class MultiHeadAttention(nn.Module):
        def __init__(self, config):
            super().__init__()
            # embedding size is 768 in the case of "bert-base-uncased"
            embed_dim = config.hidden_size

            # conventionally, hidden_size is divisible by num_heads
            num_heads = config.num_attention_heads

            # with 12 heads, each head gets 768 // 12 = 64 dimensions
            head_dim = embed_dim // num_heads

            # create a list of attention heads
            self.heads = nn.ModuleList(
                [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
            )

            # final linear layer
            self.output_linear = nn.Linear(embed_dim, embed_dim)

        def forward(self, hidden_state):
            # concat output from each head on the last dim
            x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
            # pass through final linear layer
            x = self.output_linear(x)
            return x


    # use model config from the beginning
    multihead_attn = MultiHeadAttention(config)

    # attention outputs concatenated from 12 heads
    attn_output = multihead_attn(input_embs)
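
As a quick sanity check (assuming the bert-base-uncased config and the 5-token example sentence from above), the MHA output has the same shape as the input embeddings:

    input_embs.size()   # torch.Size([1, 5, 768])
    attn_output.size()  # torch.Size([1, 5, 768])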

Feed-Forward Network

While we're at it, let's finish coding up the rest of the transformer encoder. The attention outputs are passed through a feed-forward network (FFN).

    class FeedForward(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
            self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
            self.gelu = nn.GELU()
            self.dropout = nn.Dropout(config.hidden_dropout_prob)

        def forward(self, x):
            x = self.linear_1(x)
            x = self.gelu(x)
            x = self.linear_2(x)
            x = self.dropout(x)
            return x

Before self-attention and the FFN, we apply layer normalization to ensure each input has zero mean and unit variance. Moreover, to preserve information throughout the layers and alleviate the vanishing-gradient problem, we apply skip connections, which pass a tensor to the next layer unprocessed (x) and add it to the processed tensor.

    class TransformerEncoderLayer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
            self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
            self.attention = MultiHeadAttention(config)
            self.feed_forward = FeedForward(config)

        def forward(self, x):
            # apply layer norm on input
            hidden_state = self.layer_norm_1(x)
            # apply attention with skip connection
            x = x + self.attention(hidden_state)
            # apply feedforward with skip connection
            x = x + self.feed_forward(self.layer_norm_2(x))
            # return processed tensor
            return x
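
This "pre layer normalization" arrangement (normalizing before attention and the FFN) tends to train more stably than the post layer normalization of the original paper. A quick shape check, assuming the config and input_embs from earlier:

    encoder_layer = TransformerEncoderLayer(config)
    encoder_layer(input_embs).size()  # torch.Size([1, 5, 768])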

Positional Encoding

Token embeddings from TransformerEncoderLayer are agnostic to positional information, which can be injected via positional encoding. Each position (absolute or relative) in the input sequence is represented by a unique embedding, which is either learned or fixed (such as the sinusoidal waves below):

  • Even embedding dimensions: $PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$
  • Odd embedding dimensions: $PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$

Check out the Machine Learning Mastery tutorial and Lilian Weng's blogpost for more details. The Embeddings class further below uses a learnable positional embedding layer whose size (max_position_embeddings) comes from the pre-trained model's config.
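
For reference, here is a minimal sketch of the fixed sinusoidal encoding defined by the two formulas above; the function name is mine and it assumes an even d_model.

    import torch

    def sinusoidal_positional_encoding(seq_len, d_model):
        position = torch.arange(seq_len).unsqueeze(1)    # (seq_len, 1)
        two_i = torch.arange(0, d_model, 2)              # even dimension indices 2i
        div_term = torch.pow(10000.0, two_i / d_model)   # 10000^(2i / d_model)
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position / div_term)     # even dimensions
        pe[:, 1::2] = torch.cos(position / div_term)     # odd dimensions
        return pe                                        # (seq_len, d_model)

In contrast, the Embeddings module below learns its positional embeddings jointly with the rest of the model: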

    class Embeddings(nn.Module):
        def __init__(self, config):
            super().__init__()
            # look up token and position embeddings
            self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings, config.hidden_size
            )
            # define layernorm and dropout layers
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
            self.drop_out = nn.Dropout()

        def forward(self, input_ids):
            # length of the input sequence
            seq_length = input_ids.size(1)
            # position ids: [0 to seq_length - 1]
            position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)
            # look up embeddings by id
            token_embeddings = self.token_embeddings(input_ids)
            position_embeddings = self.position_embeddings(position_ids)
            # add up token and position embeddings
            embeddings = token_embeddings + position_embeddings
            # pass through layer norm and dropout
            embeddings = self.layer_norm(embeddings)
            embeddings = self.drop_out(embeddings)
            return embeddings

The code below is the final encoder of our vanilla transformer.

    class TransformerEncoder(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.embeddings = Embeddings(config)
            # stack the encoder layer num_hidden_layers (12) times
            self.layers = nn.ModuleList(
                [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
            )

        def forward(self, x):
            x = self.embeddings(x)
            for layer in self.layers:
                x = layer(x)
            return x

Sequence Classification

A common use case of the transformer encoder is sequence classification, which maps the input embeddings (token + positional) to class logits, from which we can derive label probabilities.

    class TransformerForSequenceClassification(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.encoder = TransformerEncoder(config)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        def forward(self, x):
            # select the first token's hidden state ([CLS] when special tokens are added)
            x = self.encoder(x)[:, 0, :]
            # apply dropout on the embedding
            x = self.dropout(x)
            # pass through classification layer
            x = self.classifier(x)
            return x

    # specify number of labels
    config.num_labels = 3
    encoder_classifier = TransformerForSequenceClassification(config)
    encoder_classifier(inputs.input_ids).size()
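
The classifier returns raw logits of shape [batch_size, num_labels] rather than probabilities; to get class probabilities (meaningless here until the untrained weights are fine-tuned), we can apply a softmax:

    logits = encoder_classifier(inputs.input_ids)  # torch.Size([1, 3])
    probs = F.softmax(logits, dim=-1)              # probabilities over the 3 labels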

Resources

  1. Probabilistic Machine Learning: An Introduction (2023) by Kevin P. Murphy
  2. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
  3. Natural Language Processing with Transformers (2022) by Hugging Face
  4. A Gentle Introduction to Positional Encoding in Transformer Models (Part I, Part II)
  5. The Transformer Family Version 2.0 by Lilian Weng