Attention as Soft Dictionary Lookup
The Dictionary Metaphor
By now, the scaled dot-product attention formula, $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$, may be burned into our brains, yet this metaphor from Kevin Murphy’s PML book offers a fresh interpretation:
“We can think of attention as a soft dictionary look up, in which we compare the query $q$ to each key $k_i$, and then retrieve the corresponding value $v_i$.” — Chapter 15, p. 513
In a real dictionary, we look up values by key and retrieve only the value whose key exactly matches the query (dict[query]). The attention mechanism, however, lets us retrieve values from multiple keys and return their weighted sum for the given query.
The more compatible a key is with the query, the higher the “attention weight” we assign to that key’s value. The $i$th key’s attention weight is $\frac{\exp(\alpha(q, k_i))}{\sum_{j=1}^{m}\exp(\alpha(q, k_j))}$, so the attention weights of all $m$ keys sum to 1. The term inside the exponential, $\alpha(q, k_i)$, is the “attention score”. For masked tokens, we can set the attention scores to a large negative number (e.g., $-10^6$) so that their attention weights turn out close to 0.
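To make the contrast with a hard lookup concrete, here is a minimal, dependency-free sketch. The function name soft_lookup, the toy keys and values, and the plain (unscaled) dot-product score are all assumptions chosen for illustration, not something taken from Murphy's book.

```python
import math

def soft_lookup(query, keys, values, mask=None):
    """Return (weights, output) for one query over m key-value pairs."""
    # attention scores: plain dot product between the query and each key
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # masked positions get a large negative score so their weight is ~0
    if mask is not None:
        scores = [s if keep else -1e6 for s, keep in zip(scores, mask)]
    # softmax (subtract the max for numerical stability)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # weighted sum of the values
    dim = len(values[0])
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, output

# hard lookup: exactly one value comes back
table = {"cat": [1.0, 0.0], "dog": [0.0, 1.0]}
hard = table["cat"]

# soft lookup: every value contributes in proportion to its key's weight
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
query = [2.0, 0.0]

weights, output = soft_lookup(query, keys, values)
# mask out the third key-value pair
masked_weights, masked_output = soft_lookup(
    query, keys, values, mask=[True, True, False]
)
```

With the mask applied, the third value effectively drops out of the weighted sum, while the remaining weights are renormalized to sum to 1.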
The most popular attention score function is the dot product between $q$ and $k_i$, scaled by the square root of the key dimension, $\sqrt{d_k}$, which keeps the scores from growing with $d_k$ and saturating the softmax. Scaled dot-product attention is used by the original transformer (Vaswani et al., 2017) and many of its descendants. You can replace $\alpha(q, k_i)$ with other similarity functions.
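The sketch below illustrates both points: the score function is a pluggable argument, and dividing by $\sqrt{d_k}$ flattens the resulting weights. The helper names and the random Gaussian toy vectors are assumptions made for this example.

```python
import math
import random

def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot_score(q, k):
    return sum(a * b for a, b in zip(q, k))

def scaled_dot_score(q, k):
    return dot_score(q, k) / math.sqrt(len(k))

def attention_weights(query, keys, score_fn):
    # score_fn plays the role of alpha(q, k_i) and can be swapped freely
    return softmax([score_fn(query, key) for key in keys])

random.seed(0)
d_k = 64
query = [random.gauss(0, 1) for _ in range(d_k)]
keys = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(5)]

raw = attention_weights(query, keys, dot_score)
scaled = attention_weights(query, keys, scaled_dot_score)
print("largest weight, unscaled:", round(max(raw), 3))
print("largest weight, scaled:  ", round(max(scaled), 3))
```

Dividing all scores by a constant greater than 1 can only shrink the largest softmax weight, so the scaled version spreads attention over more keys.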
Hugging Face Implementation
I have heard that candidates in MLE interviews are increasingly asked to code components of the transformer architecture from scratch. I find the implementation in the Hugging Face book the easiest to follow and perhaps the closest to the day-to-day coding style of NLP practitioners. Let’s begin by reviewing key concepts such as self-attention and multi-headed attention, and then code up a basic TransformerForSequenceClassification model that relies on the transformer encoder.
Self-Attention
The transformer encoder uses self-attention between each input token and all other input tokens to create contextual token embeddings. Before encoding, homonyms (such as “flies” in “time flies like an arrow” and “fruit flies like a banana”) have the same initial embedding. In the first sentence, however, the token “flies” attends most strongly to “time” and “arrow”, so its contextual embedding will be close to theirs, whereas in the second sentence, “flies” attends most strongly to “fruit” and “banana”, so its embedding will be close to theirs instead.
After encoding, tokens take on new embeddings informed by the tokens they attend to. As the psychologist Lev Vygotsky put it, “Through others we become ourselves.”
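A toy illustration of this effect follows. The 2-D embeddings are made up purely for the example, and the raw embeddings are used directly as queries, keys, and values with no learned projections, so treat it as a sketch of the idea rather than what a trained model computes.

```python
import math

# made-up 2-D embeddings, chosen only for illustration
emb = {
    "time": [1.0, 0.0], "arrow": [0.9, 0.1],
    "fruit": [0.0, 1.0], "banana": [0.1, 0.9],
    "flies": [0.5, 0.5], "like": [0.2, 0.2],
    "an": [0.1, 0.1], "a": [0.1, 0.1],
}

def self_attention(tokens):
    """Scaled dot-product self-attention with Q = K = V = raw embeddings."""
    x = [emb[t] for t in tokens]
    dim = len(x[0])
    out = []
    for q in x:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in x]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * v[i] for w, v in zip(weights, x)) for i in range(dim)])
    return out

s1 = "time flies like an arrow".split()
s2 = "fruit flies like a banana".split()

# same initial embedding, different context
flies_1 = self_attention(s1)[s1.index("flies")]
flies_2 = self_attention(s2)[s2.index("flies")]
print(flies_1, flies_2)
```

Although “flies” starts from the same vector in both sentences, its output leans toward the first dimension in the “time ... arrow” sentence and toward the second in the “fruit ... banana” sentence.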
Multi-Headed Attention
The original transformer paper introduced multi-head attention (MHA), where each “head” can capture a different notion of similarity (e.g., semantic or syntactic). The query $Q$, key $K$, and value $V$ matrices are split along the embedding dimension, $d_{model}$, and each head receives a split of size $d_{model} / h$. The outputs of the $h$ heads are concatenated into a single tensor and passed through a final linear layer $W^O$: $\text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h)W^O$.
“By dividing the hidden_dim, each head indeed sees the entire sequence (seq_len) but only a “slice” or portion of each token’s embedding dimension (head_dim). This design enables the model to parallelly attend to information from different representation subspaces at different positions, enriching the model’s ability to capture diverse relationships within the data.” (Natural Language Processing with Transformers)
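The shape bookkeeping behind this “slice” view can be sketched without any libraries. The sizes assume bert-base-uncased (hidden size 768, 12 heads), and split_heads and concat_heads are hypothetical helpers. Note that the MultiHeadAttention class later in this post gives each head its own linear projection down to head_dim rather than slicing literally, so this only illustrates the dimensions involved.

```python
hidden_dim = 768   # hidden size of bert-base-uncased
num_heads = 12
head_dim = hidden_dim // num_heads

seq_len = 5
# one toy "embedding" per token: the values are just the dimension indices
tokens = [list(range(hidden_dim)) for _ in range(seq_len)]

def split_heads(x, num_heads):
    """(seq_len, hidden_dim) -> (num_heads, seq_len, head_dim)."""
    head_dim = len(x[0]) // num_heads
    return [
        [row[h * head_dim:(h + 1) * head_dim] for row in x]
        for h in range(num_heads)
    ]

def concat_heads(heads):
    """(num_heads, seq_len, head_dim) -> (seq_len, hidden_dim)."""
    seq_len = len(heads[0])
    return [
        [value for head in heads for value in head[t]]
        for t in range(seq_len)
    ]

heads = split_heads(tokens, num_heads)
restored = concat_heads(heads)
print(len(heads), len(heads[0]), len(heads[0][0]))
```

Every head keeps all seq_len tokens but only head_dim of the features, and concatenating the heads along the last dimension restores the original hidden_dim.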
We can use the tokenizer of a pre-trained model, say bert-base-uncased, to encode the input text as token ids. Note that the nn.Embedding lookup table below is still randomly initialized; only the tokenizer and the config come from the checkpoint.
from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig

# load tokenizer from model checkpoint
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# load config associated with given model
config = AutoConfig.from_pretrained(model_ckpt)

# input text
text = "time flies like an arrow"

# tokenize input text (no [CLS]/[SEP] tokens are added here)
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)

# nn.Embedding is a lookup table to find embeddings of each input_id
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)

# look up embeddings by id
input_embs = token_emb(inputs.input_ids)
For simplicity, we can use the input embeddings (input_embs) directly as $Q$ (query), $K$ (key), and $V$ (value). In practice, we usually apply three different linear projections to the input. The scaled_dot_product_attention function below takes query, key, and value as inputs and returns the output embeddings of the tokens.
# init Q, K, V with input_embs
query = key = value = input_embs

def scaled_dot_product_attention(query, key, value):
    # get key dim: hidden_dim
    dim_k = key.size(-1)  # last dim

    # attention scores: scaled dot products between queries and keys
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)

    # attention weights: softmax over the keys, so each row sums to 1
    weights = F.softmax(scores, dim=-1)

    # output embeddings: weighted sum of the values
    return torch.bmm(weights, value)
The code below implements MHA with 12 heads. We can instantiate a MultiHeadAttention object with the pre-trained model config and call it on input_embs to get the attention outputs that represent the input sequence.
Note that we don’t need to call a method such as encode (or whatever you name it) to get the output: calling a model object on some input data (hidden_state) invokes the forward function and returns the output x.
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # init 3 independent linear layers
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        attn_outputs = scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state)
        )
        return attn_outputs


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # embedding size is 768 in the case of "bert-base-uncased"
        embed_dim = config.hidden_size

        # conventionally, hidden_size is divisible by num_heads
        num_heads = config.num_attention_heads

        # with 12 heads, each head gets 768 // 12 = 64 dimensions
        head_dim = embed_dim // num_heads

        # create a list of attention heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )

        # final linear layer
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # concat output from each head on the last dim
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        # pass through final linear layer
        x = self.output_linear(x)
        return x

# use model config from the beginning
multihead_attn = MultiHeadAttention(config)

# attention outputs concatenated from 12 heads
attn_output = multihead_attn(input_embs)
Token Embeddings
While we’re at it, let’s finish coding up the rest of the transformer encoder. The attention outputs are passed through a feedforward network (FFN).
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x
Before the attention and FFN sublayers, we apply layer normalization so that each input has zero mean and unit variance. Moreover, to preserve information across layers and alleviate the vanishing gradient problem, we can apply skip connections, which pass the unprocessed tensor (x) to the next layer and add it to the processed tensor.
class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        # apply layer norm on input
        hidden_state = self.layer_norm_1(x)
        # apply attention with skip connection
        x = x + self.attention(hidden_state)
        # apply feedforward with skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        # return processed
        return x
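The two ingredients of the encoder layer above, layer normalization and the skip connection, can be traced by hand on a single token vector. This is a simplified sketch: it omits the learnable scale and bias that nn.LayerNorm also applies, and sublayer is a hypothetical stand-in for the attention or feed-forward module.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one token's feature vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    # hypothetical stand-in for attention or the feed-forward network
    return [0.1 * v for v in x]

x = [2.0, 4.0, 6.0, 8.0]
normed = layer_norm(x)

# pre-layer-norm residual block: x + sublayer(layer_norm(x))
out = [a + b for a, b in zip(x, sublayer(normed))]

mean_after = sum(normed) / len(normed)
var_after = sum((v - mean_after) ** 2 for v in normed) / len(normed)
print(mean_after, var_after, out)
```

Because the unprocessed x is added back, the original signal always has a direct path to the next layer, whatever the sublayer does to its normalized copy.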
Positional Encoding
Token embeddings from TransformerEncoderLayer are agnostic to positional information, which can be injected via positional encoding. Each position (absolute or relative) in the input sequence is represented by a unique embedding, which is either learned or fixed (such as the sinusoidal functions below):
- Even-indexed dimensions: $PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$
- Odd-indexed dimensions: $PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$
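The two formulas can be turned into a small lookup table. The following is a minimal sketch in plain Python; the function name sinusoidal_encoding and the sizes max_len=50 and d_model=16 are arbitrary choices for illustration.

```python
import math

def sinusoidal_encoding(max_len, d_model):
    """Return a max_len x d_model table of fixed positional encodings."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for dim in range(0, d_model, 2):
            # dim is the even dimension index, i.e. 2i in the formulas
            angle = pos / (10000 ** (dim / d_model))
            table[pos][dim] = math.sin(angle)
            if dim + 1 < d_model:
                table[pos][dim + 1] = math.cos(angle)
    return table

pe = sinusoidal_encoding(max_len=50, d_model=16)
print(pe[0][:4])
print(pe[1][:4])
```

Each row is the encoding of one position, and it would be added to that position's token embedding before the first encoder layer.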
Check out the Machine Learning Mastery tutorial and Lilian Weng’s blogpost for details. The code below uses learned positional embeddings, with the table size taken from the pre-trained model’s config (max_position_embeddings).
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        # look up token and position embeddings
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        # define layernorm and dropout layers
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.drop_out = nn.Dropout()

    def forward(self, input_ids):
        # length of the input sequence
        seq_length = input_ids.size(1)
        # position id: [0 to seq_length - 1], on the same device as input_ids
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        # look up embeddings by id
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # add up token and position embeddings
        embeddings = token_embeddings + position_embeddings
        # pass through layer norm and dropout
        embeddings = self.layer_norm(embeddings)
        embeddings = self.drop_out(embeddings)
        return embeddings
The code below is the final encoder of our vanilla transformer.
class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        # stack num_hidden_layers encoder layers (12 for bert-base-uncased)
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x)
        return x
Sequence Classification
A common use case of the transformer is sequence classification, which maps input embeddings (token + positional) to logits over the class labels; applying a softmax to the logits yields class probabilities.
class TransformerForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        # select the hidden state of the first token
        # (the [CLS] token when special tokens are added)
        x = self.encoder(x)[:, 0, :]
        # apply dropout on embedding
        x = self.dropout(x)
        # pass through classification layer
        x = self.classifier(x)
        return x

# specify number of labels
config.num_labels = 3
encoder_classifier = TransformerForSequenceClassification(config)
encoder_classifier(inputs.input_ids).size()  # torch.Size([1, 3])
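The linear classifier above returns one raw score (logit) per label, not a probability. A softmax over the last dimension turns the logits into probabilities, which in PyTorch would be F.softmax(logits, dim=-1). Here is the same step in plain Python, with made-up logits standing in for the model output:

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical logits for one sequence and three labels
logits = [1.2, -0.3, 0.4]
probs = softmax(logits)

# the predicted label is the index with the highest probability
predicted_label = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted_label)
```

Since the encoder here is randomly initialized, real logits from encoder_classifier are meaningless until the model is trained.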
Resources
- Probabilistic Machine Learning: An Introduction (2022) by Kevin P. Murphy
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
- Natural Language Processing with Transformers (2022) by Hugging Face
- A Gentle Introduction to Positional Encoding in Transformer Models (Part I, Part II)
- The Transformer Family Version 2.0 by Lilian Weng