The attention mechanism

What makes transformers work

Before 2017, most sequence models processed data step by step, with each position learning about earlier positions only through a running hidden state. Then the transformer architecture made self-attention its core operation, allowing every position to attend directly to every other position. This change reshaped natural language processing and now powers everything from ChatGPT to image generators.

The problem with sequential processing

Recurrent Neural Networks process sequences one element at a time, maintaining a hidden state that accumulates information. To understand the word 'it' in 'The cat sat on the mat because it was tired,' the network must propagate information about 'cat' through every intermediate step.

This sequential bottleneck causes two problems. First, long-range dependencies degrade as information passes through many transformations. Second, sequential processing prevents parallelization - each step must wait for the previous step, making training slow.

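The vanishing gradient problem is particularly severe in RNNs. At each time step, gradients are multiplied by the recurrent weight matrix (and by the derivative of the activation function). If the largest singular value of this matrix is less than 1, gradients decay exponentially with the number of steps. A per-step factor of 0.99 leaves 0.99^100 ≈ 0.37 of the gradient after 100 time steps, which is still usable, but a factor of 0.9 leaves only 0.9^100 ≈ 0.00003. At that point the signal from distant positions is effectively gone, making it nearly impossible to learn long-range dependencies.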
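To make the arithmetic concrete, the sketch below treats the per-step shrinkage as a single scalar factor. That is a simplification (in a real RNN the factor comes from a product of weight matrices and activation derivatives), and the name gradient_scale is made up for illustration.

```python
def gradient_scale(per_step_factor, num_steps):
    """Factor by which a gradient is scaled after num_steps multiplications."""
    return per_step_factor ** num_steps


# Compare a few per-step factors over 100 time steps.
for factor in (0.99, 0.9, 0.5):
    print(f"{factor}^100 = {gradient_scale(factor, 100):.3e}")
```

Small differences in the per-step factor compound dramatically over 100 steps.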

LSTMs (Long Short-Term Memory) and GRUs (Gated Recurrent Units) mitigate this with gating mechanisms that control information flow. The cell state in an LSTM provides a highway for gradients to flow unchanged, allowing longer-range learning. However, these architectures are still fundamentally sequential, limiting parallelization.

The computational bottleneck is severe: processing a sequence of length n with an RNN requires O(n) sequential operations. Modern GPUs excel at parallel computation but sit idle waiting for sequential steps. A transformer layer, by contrast, processes all positions simultaneously, so the number of sequential operations per layer is O(1) regardless of sequence length.

Attention: Direct connections

Attention mechanisms create direct connections between any two positions in a sequence. Instead of routing through intermediate steps, the network can directly look at relevant context. For understanding 'it,' the model directly attends to 'cat' regardless of distance.

Attention weights visualization: how each word in "The cat sat on the mat" attends to the other words. Each row is the attending (query) token, each column is the token being attended to, and values are percentages that sum to 100 across a row.

        The   cat   sat   on    the   mat
The     40    12    12    12    12    12
cat     15    35    25     5     5    15
sat      5    40    30    10     5    10
on      12    12    12    40    12    12
the     12    12    12    12    40    12
mat     10    20    15    20    15    20

The key insight is that relevance is learned, not hard-coded. The network learns which positions to attend to based on content. A pronoun learns to attend to nouns; a verb learns to attend to its subject. These patterns emerge automatically from training data.

Think of attention as a soft lookup table. Traditional lookup tables have a key and return a single value. Attention has a query that compares against multiple keys, returning a weighted combination of all values based on how well each key matches. This soft matching allows the network to aggregate information from multiple sources.
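The contrast between a hard and a soft lookup can be sketched in a few lines of NumPy. The keys, values, and the helper name soft_lookup are invented for this illustration; the point is only that the result blends all values, weighted by how well each key matches the query.

```python
import numpy as np


def soft_lookup(query, keys, values):
    """Return a softmax-weighted blend of values, plus the weights used."""
    scores = keys @ query                      # one similarity score per key
    weights = np.exp(scores - scores.max())    # subtract max for stability
    weights /= weights.sum()                   # weights now sum to 1
    return weights @ values, weights


keys = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
values = np.array([[10.0], [20.0], [50.0]])
query = np.array([4.0, 0.0])   # matches keys 0 and 2 equally, key 1 not at all

blended, weights = soft_lookup(query, keys, values)
print("weights:", np.round(weights, 3))
print("blended value:", blended)
```

A dictionary would return exactly one of the three values; here the two matching keys share almost all of the weight and the output lands between their values.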

Before transformers, attention was used to help RNN decoders focus on relevant parts of RNN encoder outputs - this was the attention mechanism in the original sequence-to-sequence models for translation. The transformer's innovation was using attention not just between encoder and decoder, but within each layer (self-attention), entirely replacing recurrence.

Queries, keys, and values

Self-attention uses three learned projections for each position: a query representing 'what am I looking for,' a key representing 'what do I contain,' and a value representing 'what information do I provide.' Attention computes compatibility between queries and keys to determine how much each value contributes.

Think of it like a search engine. Your query describes what you want. Each document has keywords, which serve as keys. When query matches keys, you retrieve the document content, which is the value. In self-attention, every position simultaneously queries every other position.

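The queries, keys, and values are linear projections of the input: Q = XW_Q, K = XW_K, V = XW_V, where X is the input sequence (shape: seq_len × d_model) and W_Q, W_K, W_V are learned weight matrices (shape: d_model × d_k; the value projection may use a separate width d_v, which is usually set equal to d_k). With multi-head attention, d_k is d_model divided by the number of heads. The original transformer used d_k = 64 (d_model = 512 split across 8 heads).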

Why three separate projections? They allow the network to express different aspects of each position. The query projection emphasizes what information the position needs. The key projection emphasizes what information the position provides. The value projection determines what information is actually transmitted. These can be quite different.

import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """
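    Single-head scaled dot-product self-attention.
    X: input sequence [seq_len, d_model]
    W_q, W_k: projection matrices [d_model, d_k]
    W_v: projection matrix [d_model, d_v] (often d_v == d_k)
    Returns (output [seq_len, d_v], weights [seq_len, seq_len]).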
    """
    # Project inputs to queries, keys, values
    Q = X @ W_q  # [seq_len, d_k]
    K = X @ W_k  # [seq_len, d_k]
    V = X @ W_v  # [seq_len, d_v]
    
    # Compute attention scores
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # [seq_len, seq_len]
    
    # Softmax to get attention weights
    weights = softmax(scores, axis=-1)
    
    # Weighted sum of values
    output = weights @ V  # [seq_len, d_v]
    
    return output, weights

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

Scaled dot-product attention

The attention formula is: Attention(Q, K, V) = softmax(QK^T / √d_k) V. The dot product QK^T computes similarity between each query and all keys. Softmax normalizes these scores to weights summing to 1. The weighted sum of values produces the output.

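Why divide by √d_k? If the components of a query and a key are roughly independent with zero mean and unit variance, their dot product has variance d_k, so typical scores grow like √d_k as dimensionality increases. Large scores push softmax into saturated regions where one weight is close to 1 and gradients vanish. Dividing by √d_k brings the variance of the scores back to about 1. This small detail is crucial for stable training.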
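The effect of the scaling is easy to check empirically. The sketch below assumes query and key components drawn independently from a standard normal distribution, which is an idealization of what a trained model produces; dot_product_std is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)


def dot_product_std(d_k, num_samples=10_000):
    """Standard deviation of q.k before and after dividing by sqrt(d_k)."""
    q = rng.standard_normal((num_samples, d_k))
    k = rng.standard_normal((num_samples, d_k))
    raw = np.sum(q * k, axis=-1)               # one dot product per sample
    return raw.std(), (raw / np.sqrt(d_k)).std()


for d_k in (4, 64, 512):
    raw_std, scaled_std = dot_product_std(d_k)
    print(f"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:.2f}")
```

The raw spread grows with the square root of d_k, while the scaled spread stays near 1 for every dimension.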

The matrix multiplication QK^T produces a seq_len × seq_len matrix of attention scores. For a sequence of 1000 tokens, this is a million-element matrix. For 100,000 tokens, it's 10 billion elements. This quadratic scaling is why context length is a major constraint for transformers.
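A rough sense of the memory involved comes from multiplying the element count by the storage size. The sketch below assumes 2 bytes per element (16-bit floats) and counts a single attention matrix for one head in one layer; both are assumptions for illustration, and attention_matrix_bytes is a made-up name.

```python
def attention_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to store one seq_len x seq_len attention matrix."""
    return seq_len * seq_len * bytes_per_element


for seq_len in (1_000, 32_000, 100_000):
    gigabytes = attention_matrix_bytes(seq_len) / 1e9
    print(f"{seq_len:>7} tokens -> {gigabytes:.3f} GB per head per layer")
```

Multiplying the sequence length by 100 multiplies the memory by 10,000, which is the quadratic scaling described above.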

Memory is often the bottleneck rather than computation. The attention matrix must be stored for the backward pass. FlashAttention, introduced in 2022, avoids materializing the full attention matrix by computing attention in blocks, dramatically reducing memory usage and improving speed through better GPU memory access patterns.
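The block-wise idea can be sketched in NumPy using a running ("online") softmax. This only illustrates the arithmetic, assuming a single head and no masking; the real FlashAttention is a fused GPU kernel, and blockwise_attention is a hypothetical name.

```python
import numpy as np


def blockwise_attention(Q, K, V, block_size=2):
    """Exact softmax attention, visiting keys/values one block at a time."""
    seq_len, d_k = Q.shape
    running_max = np.full((seq_len, 1), -np.inf)  # largest score seen per query
    running_sum = np.zeros((seq_len, 1))          # softmax denominator so far
    acc = np.zeros((seq_len, V.shape[-1]))        # unnormalized weighted values

    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / np.sqrt(d_k)       # only [seq_len, block] in memory
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        rescale = np.exp(running_max - new_max)   # is 0 on the first block
        p = np.exp(scores - new_max)
        running_sum = running_sum * rescale + p.sum(axis=-1, keepdims=True)
        acc = acc * rescale + p @ V_blk
        running_max = new_max
    return acc / running_sum


# Compare against the straightforward version that builds the full matrix.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
scores = Q @ K.T / np.sqrt(4)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
full = (weights / weights.sum(axis=-1, keepdims=True)) @ V
blocked = blockwise_attention(Q, K, V, block_size=2)
print(np.allclose(full, blocked))
```

The block-wise version never holds more than a seq_len × block_size slice of scores, yet it produces the same output as the full computation rather than an approximation.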

Multi-head attention

A single attention head produces just one set of attention weights per position, so it tends to capture one type of relationship at a time. Multi-head attention runs multiple attention operations in parallel, each with different learned projections, then concatenates and projects the results. This allows the model to jointly attend to information from different representation subspaces.

In practice, the largest GPT-3 model uses 96 attention heads in each of its 96 layers. Different heads tend to learn different roles: some track syntactic relationships, others semantic similarity, others positional patterns. The combination captures richer context than any single head could.

Multi-head attention doesn't increase computation significantly. Instead of using d_k = d_model for a single head, we use d_k = d_model / h for h heads. The total dimensionality remains the same, but split across heads. Each head operates in a smaller subspace, specializing in different relationship types.

Research has shown that some heads learn interpretable patterns. Some track syntactic relations such as those found in a dependency parse. Others perform coreference resolution, linking pronouns to their referents. Some heads appear to contribute very little: removing them has almost no measurable effect on model performance.

def multi_head_attention(X, W_qs, W_ks, W_vs, W_o, num_heads=8):
    """
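    Multi-head attention.
    X: input sequence [seq_len, d_model]
    W_qs, W_ks, W_vs: lists of num_heads matrices, each [d_model, d_k]
    W_o: output projection [d_model, d_model]
    Computes attention per head, then concatenates and projects the results.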
    """
    d_model = X.shape[-1]
    d_k = d_model // num_heads
    
    heads = []
    for i in range(num_heads):
        # Each head has its own projection matrices of shape [d_model, d_k]
        Q = X @ W_qs[i]  # [seq_len, d_k]
        K = X @ W_ks[i]
        V = X @ W_vs[i]
        
        # Scaled dot-product attention for this head
        scores = Q @ K.T / np.sqrt(d_k)
        weights = softmax(scores, axis=-1)
        head_output = weights @ V
        heads.append(head_output)
    
    # Concatenate heads and apply output projection
    concatenated = np.concatenate(heads, axis=-1)  # [seq_len, d_model]
    output = concatenated @ W_o  # [seq_len, d_model]
    
    return output

Positional encoding

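Self-attention on its own is blind to order: shuffle the input tokens and the outputs are shuffled in exactly the same way, with nothing else changed (formally, it is permutation equivariant). In effect it treats the sequence as a set. But position matters enormously in language: 'dog bites man' differs from 'man bites dog.' Positional encodings inject position information into the representation.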
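This position-blindness can be checked directly: without positional information, shuffling the input rows only shuffles the output rows in the same way (strictly speaking, permutation equivariance). The sketch below uses random weights and a made-up function name, attention_no_positions.

```python
import numpy as np


def attention_no_positions(X, W_q, W_k, W_v):
    """Single-head self-attention with no positional information."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))                       # 5 tokens, d_model = 8
W_q, W_k, W_v = (rng.standard_normal((8, 4)) for _ in range(3))
perm = rng.permutation(5)                             # a random reordering

out = attention_no_positions(X, W_q, W_k, W_v)
out_shuffled = attention_no_positions(X[perm], W_q, W_k, W_v)

# Each token gets the same output vector wherever it sits in the sequence.
print(np.allclose(out_shuffled, out[perm]))
```

Because each token's output is unchanged by reordering, the model has no way to tell 'dog bites man' from 'man bites dog' unless position is added to the input or to the attention scores.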

The original transformer uses sinusoidal encodings: PE(pos, 2i) = sin(pos / 10000^(2i/d)); PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). Different frequencies for different dimensions create unique patterns for each position while allowing the model to learn relative positioning.

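Why sinusoids? They have a useful property: for any fixed offset k, PE(pos+k) can be expressed as a linear function of PE(pos), namely a rotation of each sine/cosine pair. This lets the model learn to attend to relative positions, such as looking three positions back, regardless of absolute position. The wavelengths form a geometric progression (set by the 10000^(2i/d) term), so fast-varying dimensions separate neighboring positions while slow-varying dimensions keep distant positions distinguishable even in long sequences.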
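Both the encoding and the linearity property can be written out in a short sketch. The function names sinusoidal_encoding and shift_matrix are invented here; the shift matrix is built from the angle-addition identities for sine and cosine.

```python
import numpy as np


def sinusoidal_encoding(num_positions, d_model):
    """Return the [num_positions, d_model] encoding and its frequencies."""
    positions = np.arange(num_positions)[:, None]
    freqs = 1.0 / 10000 ** (np.arange(0, d_model, 2) / d_model)
    angles = positions * freqs                 # [num_positions, d_model / 2]
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(angles)               # even dimensions
    pe[:, 1::2] = np.cos(angles)               # odd dimensions
    return pe, freqs


def shift_matrix(freqs, k):
    """Block-diagonal matrix M with PE(pos + k) = M @ PE(pos)."""
    d_model = 2 * len(freqs)
    M = np.zeros((d_model, d_model))
    for i, w in enumerate(freqs):
        c, s = np.cos(w * k), np.sin(w * k)
        M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return M


pe, freqs = sinusoidal_encoding(num_positions=50, d_model=16)
M = shift_matrix(freqs, k=3)
# The same matrix maps every position to the one three steps later.
print(np.allclose(pe[3:], pe[:-3] @ M.T))
```

The matrix depends only on the offset k, not on the position itself, which is what makes relative offsets easy for a linear layer to pick up.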

Models such as BERT and GPT-2 use learned positional embeddings instead, simply learning a unique vector for each position. The drawback is that positions beyond the trained maximum length have no embedding at all. Rotary Position Embeddings (RoPE) encode position in the attention computation itself, which tends to generalize better across sequence lengths.

RoPE rotates the query and key vectors based on position before computing attention. Queries and keys at positions i and j interact through rotation by angle (i-j)θ, making attention naturally depend on relative position. This elegant approach has become standard in modern LLMs like LLaMA.
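A minimal sketch of the rotation, assuming adjacent dimensions are paired and the frequencies follow the same 10000-based schedule as the sinusoidal encoding. Real implementations differ in how they pair dimensions, and the function name rope is just a label for this example.

```python
import numpy as np


def rope(x, position, base=10000.0):
    """Rotate each consecutive pair of x by an angle proportional to position."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # one frequency per pair
    angles = position * theta
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin      # standard 2D rotation
    out[1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same offset (2 positions apart) at two different absolute positions.
score_near_start = rope(q, 3) @ rope(k, 1)
score_further_on = rope(q, 10) @ rope(k, 8)
print(np.isclose(score_near_start, score_further_on))
```

The two scores agree because rotating both vectors by a common extra angle leaves their dot product unchanged, so only the offset between positions matters.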

ALiBi (Attention with Linear Biases) takes another approach: instead of adding positional encodings to tokens, it adds a bias to each attention score that grows linearly more negative with the distance between query and key. This simple modification enables length extrapolation: models can handle sequences longer than those seen in training.
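The bias itself is just a distance matrix scaled by a per-head slope. The sketch below assumes a causal model and the geometric slope schedule described for head counts that are powers of two (for 8 heads: 1/2, 1/4, ..., 1/256); alibi_bias is a hypothetical helper name.

```python
import numpy as np


def alibi_bias(seq_len, num_heads):
    """Per-head linear distance penalties, shape [num_heads, seq_len, seq_len]."""
    # Assumed slope schedule: 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n heads.
    slopes = 2.0 ** (-8.0 * np.arange(1, num_heads + 1) / num_heads)
    positions = np.arange(seq_len)
    distance = positions[:, None] - positions[None, :]   # query index - key index
    distance = np.tril(distance)     # future keys set to 0; a causal mask hides them
    return -slopes[:, None, None] * distance[None, :, :]


bias = alibi_bias(seq_len=4, num_heads=8)
print(bias[0])    # head with the steepest slope (1/2)
print(bias[-1])   # head with the gentlest slope (1/256)
```

The bias is added to the scaled query-key scores before the softmax. Heads with steep slopes focus on nearby tokens, while heads with gentle slopes can still reach far back.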

The transformer block

A transformer block combines multi-head attention with a feed-forward network, connected by residual connections and layer normalization. The attention layer enables positions to communicate; the feed-forward network processes each position independently, adding nonlinear transformation capacity.

Residual connections add the input to the output: output = x + sublayer(x). This creates gradient highways that ease training of deep networks. Layer normalization stabilizes activations, typically applied before each sublayer in modern architectures.

The feed-forward network is surprisingly important: with the usual 4x expansion it contains roughly twice as many parameters as the attention layer in the same block. In GPT-3, the FFN expands from 12288 dimensions to 49152 (4x expansion), then back to 12288. This expansion allows the network to represent more complex functions at each position.
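The parameter comparison is simple arithmetic. The sketch below ignores biases and assumes standard multi-head attention with four d_model × d_model projections (query, key, value, output); the function names are made up for illustration.

```python
def attention_params(d_model):
    """Weights in W_Q, W_K, W_V and W_O, each d_model x d_model."""
    return 4 * d_model * d_model


def ffn_params(d_model, expansion=4):
    """Weights in the up-projection and down-projection of the FFN."""
    hidden = expansion * d_model
    return d_model * hidden + hidden * d_model


d_model = 12288   # model width quoted above for GPT-3
print(f"attention: {attention_params(d_model):,}")
print(f"ffn:       {ffn_params(d_model):,}")
print(f"ratio:     {ffn_params(d_model) / attention_params(d_model):.1f}")
```

Under these assumptions the FFN holds two thirds of the weights in each block.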

Research suggests the FFN acts as a key-value memory, storing factual associations. The first layer's weights act as keys matching input patterns; the second layer's weights are values retrieved when patterns match. This explains why larger FFNs improve factual knowledge in language models.

Pre-norm versus post-norm refers to where layer normalization is applied. Post-norm (original transformer) applies LayerNorm after adding the residual. Pre-norm applies it before the sublayer. Pre-norm is more stable to train and has become standard, though post-norm can achieve slightly better final performance with careful tuning.
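The difference between the two orderings is easiest to see side by side. This sketch omits the learned gain and bias of layer normalization and uses simple stand-in functions in place of real attention and feed-forward sublayers; all names are illustrative.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize each position to zero mean and unit variance (no learned scale)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def post_norm_block(x, attention, ffn):
    """Original transformer ordering: normalize after adding the residual."""
    x = layer_norm(x + attention(x))
    x = layer_norm(x + ffn(x))
    return x


def pre_norm_block(x, attention, ffn):
    """Pre-norm ordering: normalize the sublayer input, leave the residual path alone."""
    x = x + attention(layer_norm(x))
    x = x + ffn(layer_norm(x))
    return x


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))                      # 4 positions, d_model = 8
W_mix = rng.standard_normal((8, 8)) * 0.1
mix = lambda h: h @ W_mix                            # stand-in for attention
ffn = lambda h: np.maximum(h @ W_mix, 0.0)           # stand-in for the FFN

post = post_norm_block(x, mix, ffn)
pre = pre_norm_block(x, mix, ffn)
print(post.shape, pre.shape)
```

In the pre-norm version the input x reaches the output through additions only, so gradients have an unnormalized path through the whole stack. This is one common explanation for its training stability.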

Causal masking for generation

For language generation, we cannot let positions attend to future tokens: that would be cheating. Causal masking sets the attention scores for future positions to negative infinity before the softmax, so their attention weights come out as exactly zero. The result is a triangular attention pattern where each position only sees itself and earlier positions.

This masking enables efficient parallel training: we compute all positions simultaneously, but each position's output only depends on earlier positions. During inference, we generate one token at a time, but training processes entire sequences in parallel.
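A minimal sketch of the mask in NumPy. It operates on an already computed score matrix, and causal_attention_weights is a made-up name for this example.

```python
import numpy as np


def causal_attention_weights(scores):
    """Softmax over each row, with future positions masked out."""
    seq_len = scores.shape[-1]
    # True above the diagonal: key position is later than query position.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    masked = np.where(future, -np.inf, scores)
    masked = masked - masked.max(axis=-1, keepdims=True)   # stable softmax
    weights = np.exp(masked)                               # exp(-inf) is 0
    return weights / weights.sum(axis=-1, keepdims=True)


# With all-equal scores, each position spreads attention evenly over its past.
weights = causal_attention_weights(np.zeros((4, 4)))
print(np.round(weights, 2))
```

All four rows are computed in one pass, which is what makes training parallel, yet each row depends only on positions up to and including its own.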

Encoder vs decoder architectures

The original transformer used both encoder and decoder stacks. The encoder processes input with bidirectional attention - every position sees every other position. The decoder generates output with causal attention, also attending to encoder outputs through cross-attention.

GPT models use decoder-only architecture - just causal self-attention, no encoder. BERT uses encoder-only architecture - bidirectional attention for understanding. T5 uses the full encoder-decoder for sequence-to-sequence tasks. Architecture choice depends on the application.

Decoder-only models like GPT have become dominant for generative AI. They're conceptually simpler - just predict the next token - and scale well. The same architecture works for chat, code generation, reasoning, and more. Pretraining is straightforward: maximize likelihood of the next token on massive text corpora.

Encoder-only models like BERT excel at understanding tasks: classification, extraction, similarity. Bidirectional context helps comprehension but makes generation awkward. BERT fills in masked tokens but can't naturally extend text. For embeddings and understanding, encoders remain valuable.

Computational complexity

Self-attention has O(n²) complexity in sequence length - every position attends to every other position. For long sequences, this becomes prohibitive. A 32,000 token context requires computing attention over a billion pairs.

Efficient attention variants reduce this cost. Sparse attention attends only to selected positions. Linear attention approximates softmax attention with linear complexity. FlashAttention keeps the exact quadratic computation but avoids storing the full attention matrix and optimizes memory access patterns for GPU efficiency. Together, these techniques enable transformers to handle increasingly long contexts.

Sliding window attention limits each position to attending within a fixed window, reducing complexity to O(n × w) where w is window size. Longformer combines sliding windows with global attention on special tokens. BigBird adds random attention connections for theoretical guarantees.
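A sliding window is just a band-shaped mask. The sketch below assumes a symmetric window (each position sees half_width neighbors on either side); causal models would keep only the left half. The name sliding_window_mask is invented for this example.

```python
import numpy as np


def sliding_window_mask(seq_len, half_width):
    """Boolean mask: True where query i may attend to key j."""
    positions = np.arange(seq_len)
    distance = positions[:, None] - positions[None, :]
    return np.abs(distance) <= half_width


mask = sliding_window_mask(seq_len=6, half_width=1)
print(mask.astype(int))
print("allowed pairs:", int(mask.sum()), "of", mask.size)
```

The number of allowed pairs grows linearly with sequence length for a fixed window, instead of quadratically.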

Ring attention distributes long sequences across multiple devices, with each device computing attention for its segment of queries while passing blocks of keys and values to its neighbors around a ring. This enables million-token contexts by spreading the memory cost of attention across devices.

Why attention works so well

Several properties make attention powerful. Direct connections enable arbitrary-range dependencies. Soft attention is differentiable, allowing end-to-end training. The learned query-key matching discovers relevant patterns automatically. And parallelization across positions enables efficient GPU utilization.

Attention is also unusually inspectable. We can visualize attention weights to see what the model focuses on, although the weights are not always a faithful explanation of its output. This reveals emergent behaviors: heads that track syntax, heads that resolve coreference, heads that attend to semantically similar words.

Attention is all you need.

Vaswani et al., 2017
How Things Work - A Visual Guide to Technology