Multi-Head Attention Primer
One head of attention captures one kind of relationship between tokens. Real language has many: subject-verb agreement, pronoun-antecedent, local positional, semantic similarity. Multi-head attention runs h attention computations in parallel, each in its own subspace, and mixes the results. Four short topics: why multi-head; the split of Q/K/V into h heads; parallel per-head attention; and the final combine via concat + a projection W_O.
Why “Multi-Head”?
One head sees one kind of relationship. Real sentences have many at once.
In the self-attention primer, we built one attention head end-to-end. It worked. On “the cat sat”, the row for cat ended up putting ~67% of its weight on sat — exactly what a verb-aware noun should do. So far so good. But sentences are richer than one relationship.
- Cat wants to find its verb (sat).
- Cat also has a determiner right before it (the) — a positional relationship.
- Sat has a subject (cat) and possibly a complement downstream.
- Long-range pronoun resolution: “she” may need to look 50 tokens back to find “Mary”.
- Semantic clustering: words about food, words about emotion, words about time.
A single attention head can only emit one set of softmax weights per token. The same row of the score matrix has to do all the relationship-tracking at once. With one head you're forced to pick: does this token attend to its verb, or its determiner, or its co-reference? You can't cleanly do more than one.
The fix. Run h attention computations in parallel. Each gets its own set of (W_Q, W_K, W_V) projection matrices — so each lives in its own subspace and can specialize on its own pattern. One head ends up being a syntactic head; another a local-positional head; another a long-range anaphora head. The model figures out who specializes in what during training, with no explicit supervision.
The original Transformer used h = 8. GPT-3 used h = 96. The very largest models use 64–128 heads per layer. Each head is small (d_k = d_model / h, typically 64 or 128: 512 / 8 = 64 in the original Transformer, 12288 / 96 = 128 in GPT-3), so the total compute is roughly the same as a single head of width d_model, but the capacity to encode many simultaneous relationships is much higher.
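A back-of-envelope sketch of the compute claim. It counts only the multiply-adds in the score matrices and weighted sums (projections and softmax are ignored; the function name and the sequence length are hypothetical). Each head does n²·d_k work and h · d_k = d_model, so the total does not depend on h.

```python
def attention_core_macs(n: int, d_model: int, h: int) -> int:
    """Multiply-adds for the score matrices and weighted sums of h heads of width d_model // h."""
    d_k = d_model // h
    scores = n * n * d_k        # Q_i @ K_i^T: (n x d_k) @ (d_k x n)
    weighted_sum = n * n * d_k  # A_i @ V_i:   (n x n) @ (n x d_k)
    return h * (scores + weighted_sum)

n, d_model = 1024, 512          # hypothetical sequence length; d_model of the original Transformer
one_wide_head = attention_core_macs(n, d_model, h=1)
eight_heads = attention_core_macs(n, d_model, h=8)
print(one_wide_head == eight_heads)  # True, because h * d_k == d_model
```

What does grow with h is the number of n × n score matrices, and with it the softmax work and the memory for attention weights.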
The interpretive caveat. “Head i is the syntactic head” is a useful simplification, not a literal claim. In practice, the model spreads any given pattern across multiple heads, heads compete and specialize during training, and many heads turn out to be partly redundant (you can prune them and the model barely changes). Still, the mental model of “heads capture different relationship types” is the right intuition for why we have multiple heads at all.
Split: d_model → h Heads of d_k
Slice the feature dimensions h ways. Each head gets its own subspace but still sees every token.
The simplest implementation of multi-head attention is also the easiest to picture. Take the Q, K, V matrices from the previous primer — each is shape (n × d_model), where n is sequence length and d_model is the input embedding dimension. Slice each by columns into h equal pieces. Each piece is shape (n × d_k) with d_k = d_model / h.
Concretely on our running example: d_model = 4, h = 2, so d_k = 2. The Q matrix that was 3 × 4 becomes two matrices each 3 × 2, called Q₁ and Q₂. The same split applies to K and V — we end up with six per-head matrices instead of three full-width ones.
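A minimal NumPy sketch of the split on the running example. The numbers in Q are random placeholders, not the values from the previous primer.

```python
import numpy as np

n, d_model, h = 3, 4, 2          # running example: 3 tokens, d_model = 4, 2 heads
d_k = d_model // h               # 2

rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d_model))   # placeholder values standing in for the 3 x 4 Q matrix

# Slice along the feature axis (columns): head i gets columns i*d_k .. (i+1)*d_k - 1
# and keeps every row, i.e. every token.
Q1, Q2 = (Q[:, i * d_k:(i + 1) * d_k] for i in range(h))
print(Q1.shape, Q2.shape)        # (3, 2) (3, 2)

# np.split performs the same column split in one call.
assert all(np.array_equal(a, b) for a, b in zip(np.split(Q, h, axis=1), (Q1, Q2)))
```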
The non-obvious bit. Each head sees only its own slice of the feature dimensions, but it still sees every token. Slicing happens along the feature axis, never along the sequence axis. So head 1 isn't “the head that processes the first half of the sentence.” All heads look at all tokens — what differs is which embedding dimensions each head is allowed to use.
Implementation note. In practice no one stores h separate per-head matrices. W_Q is a single (d_model × d_model) matrix that is interpreted as h side-by-side (d_model × d_k) column blocks, one per head. The forward pass computes Q at full width, reshapes it from (n × d_model) into (n × h × d_k) with a single view() or reshape() call, permutes the axes to (h × n × d_k), and runs the next steps with h as a batch dimension. The math is identical to “split and run separately,” but a GPU can do it in one batched matrix multiply.
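The same idea in NumPy, with random placeholder weights: reshape and transpose stand in for PyTorch's view()/reshape() and permute. The check confirms that each head's slice of the full-width Q equals X projected through the matching (d_model × d_k) column block of W_Q.

```python
import numpy as np

n, d_model, h = 3, 4, 2
d_k = d_model // h
rng = np.random.default_rng(1)
X = rng.standard_normal((n, d_model))          # input embeddings (random placeholders)
W_Q = rng.standard_normal((d_model, d_model))  # one full-width projection matrix

Q = X @ W_Q                                         # (n, d_model), computed once at full width
Q_batched = Q.reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k): head axis first

# Head i's slice equals projecting X through the i-th (d_model x d_k) column block of W_Q.
for i in range(h):
    W_Q_i = W_Q[:, i * d_k:(i + 1) * d_k]
    assert np.allclose(Q_batched[i], X @ W_Q_i)
print(Q_batched.shape)   # (2, 3, 2)
```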
Why equal slices? Because the input embedding is not yet differentiated. The model has no reason to make head 1 bigger than head 2 before training starts. Equal slices is the symmetric prior — and during training each head discovers its own specialization within its slice.
Parallel: Each Head Does Its Own Attention
Same operation, h times. Each head, its own subspace, its own pattern.
Once Q, K, V are split, each head i has its own (Q_i, K_i, V_i). Each then runs the exact same scaled dot-product attention we wrote down at the end of the self-attention primer — just on its own slice:
head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i
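A direct NumPy transcription of the per-head formula, run once per head in a plain loop. The function names are illustrative, the inputs are random placeholders, and there is no causal mask.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the row max avoids overflow; the result is mathematically unchanged.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(Q_i, K_i, V_i):
    """head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i for a single head."""
    d_k = Q_i.shape[-1]
    scores = Q_i @ K_i.T / np.sqrt(d_k)   # (n, n) score matrix for this head
    weights = softmax(scores, axis=-1)    # each row sums to 1
    return weights @ V_i, weights         # (n, d_k) output and (n, n) attention pattern

n, d_model, h = 3, 4, 2
d_k = d_model // h
rng = np.random.default_rng(2)
Q, K, V = (rng.standard_normal((n, d_model)) for _ in range(3))

heads = []
for i in range(h):
    cols = slice(i * d_k, (i + 1) * d_k)  # this head's feature columns; all rows (tokens)
    out_i, w_i = attention_head(Q[:, cols], K[:, cols], V[:, cols])
    heads.append(out_i)
print([o.shape for o in heads])   # [(3, 2), (3, 2)]
```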
Two matrix multiplies (Q_i · K_iᵀ, then the weighted sum with V_i), one scaling by 1/√d_k, and one softmax per head, exactly as before. The only thing that changes is that we do it h times, once per head, and the heads do not interact during this step. No shared weights, no shared score matrix, no cross-talk. They're entirely independent computations on the same input tokens.
Because the heads are independent, they can (and on a GPU, do) run in parallel. The implementation reshapes Q, K, V to treat the head dimension as a batch dimension, so one batched scaled dot-product attention call processes all h heads at once. In matrix-multiply FLOPs, multi-head attention costs roughly the same as single-head attention with the same total d_model; the main extra cost is keeping h score matrices of size n × n instead of one. You buy a lot more expressiveness for very little.
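A sketch of the batched version in NumPy: the head axis is the batch dimension of a single matmul, and the result is checked against the per-head loop. This illustrates the batching idea only, not a fused GPU kernel; sizes and inputs are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(M, h):
    n, d_model = M.shape
    return M.reshape(n, h, d_model // h).transpose(1, 0, 2)   # (h, n, d_k)

def batched_attention(Q, K, V, h):
    Qh, Kh, Vh = split_heads(Q, h), split_heads(K, h), split_heads(V, h)
    d_k = Qh.shape[-1]
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n): one matmul over the head batch
    return softmax(scores, axis=-1) @ Vh                  # (h, n, d_k)

rng = np.random.default_rng(3)
n, d_model, h = 5, 8, 4
Q, K, V = (rng.standard_normal((n, d_model)) for _ in range(3))
batched = batched_attention(Q, K, V, h)

# Reference: run each head separately with an explicit loop.
d_k = d_model // h
for i in range(h):
    c = slice(i * d_k, (i + 1) * d_k)
    w = softmax(Q[:, c] @ K[:, c].T / np.sqrt(d_k))
    assert np.allclose(batched[i], w @ V[:, c])
print(batched.shape)   # (4, 5, 2)
```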
What does each head actually learn? This has been studied extensively. The short version: in well-trained Transformers, you can probe each head and find that many specialize on remarkably interpretable patterns:
- Positional heads. Some heads attend purely to the previous token, or to the next token, or to the start of the sentence. Pure position, ignoring content.
- Syntactic heads. Some attend from a verb to its subject, or from a pronoun to its antecedent, or from a clause to its head noun.
- Semantic heads. Some attend strongly to topically similar tokens — “food” tokens find other food tokens, even across long distances.
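- Induction heads. A famous class: heads that complete repeated patterns. If the sequence earlier contained “A B” and the current token is A again, an induction head attends to the B that followed the earlier A and copies it forward as a prediction. (Anthropic interpretability research presented evidence that these heads are a major mechanism behind in-context learning.)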
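To make the positional and induction bullets concrete, here are hand-constructed, idealized attention patterns (one-hot rows, hypothetical helper names). Real heads learn soft, approximate versions of these; nothing here comes from a trained model.

```python
import numpy as np

def previous_token_pattern(n: int) -> np.ndarray:
    """Idealized positional head: token t puts all its weight on token t - 1.
    Token 0 has no predecessor, so here it attends to itself."""
    A = np.zeros((n, n))
    A[0, 0] = 1.0
    for t in range(1, n):
        A[t, t - 1] = 1.0
    return A

def induction_pattern(tokens: list) -> np.ndarray:
    """Idealized induction head: at position t, find the most recent earlier j with
    tokens[j] == tokens[t] and attend to j + 1, the token that followed it.
    Positions with no earlier match attend to themselves. Never looks ahead of t."""
    n = len(tokens)
    A = np.zeros((n, n))
    for t in range(n):
        matches = [j for j in range(t) if tokens[j] == tokens[t]]
        A[t, matches[-1] + 1 if matches else t] = 1.0
    return A

tokens = ["the", "cat", "sat", "on", "the"]
A = induction_pattern(tokens)
# The second "the" attends to "cat", the token after the first "the",
# so copying that token's value predicts "cat" as the next token.
print(tokens[int(A[-1].argmax())])   # cat
```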
And the caveat from §1 again. These specializations are emergent, not designed. The model wasn't told “head 5 should be the positional head” — training arrived there. Many heads are redundant; the famous Voita et al. result is that you can prune ~half the heads from a trained Transformer and lose almost no quality. The multi-head structure is a useful inductive bias — it gives the model the capacity to learn diverse patterns — not a fixed per-head meaning.
Combine: Concat, Then Project Through W_O
Stack the head outputs back together, then mix them with a final learned matrix.
After the parallel attention step, we have h output matrices, one per head. Each is shape (n × d_k). To pass these on to the next layer of the Transformer, we need to fold them back into a single tensor of the original input shape (n × d_model). Two operations get us there.
- Concatenate. Stack the h head outputs side by side along the feature axis. h matrices of size (n × d_k) become one matrix of size (n × h · d_k) = (n × d_model). Same shape as Q/K/V/the input embedding — good. But the dimensions are stacked, not mixed. Head 1's output occupies the first d_k dimensions; head 2's the next; etc.
- Project through W_O. A learned (d_model × d_model) weight matrix. Multiplying the concat-result by W_O lets every output dimension be a mixture of every head's contribution. This is where the heads finally interact.
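A NumPy sketch of concat followed by W_O, with random placeholder head outputs. The second computation shows an equivalent reading: multiplying the concatenation by W_O is the same as projecting each head through its own (d_k × d_model) row block of W_O and summing, which is exactly where the heads' contributions get mixed.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d_model, h = 3, 4, 2
d_k = d_model // h
heads = [rng.standard_normal((n, d_k)) for _ in range(h)]   # stand-ins for per-head outputs
W_O = rng.standard_normal((d_model, d_model))

concat = np.concatenate(heads, axis=-1)  # (n, h*d_k) = (n, d_model); head i occupies columns i*d_k:(i+1)*d_k
out = concat @ W_O                       # (n, d_model): every output column mixes every head

# Equivalent view: project each head by its own row block of W_O, then sum.
out_sum = sum(heads[i] @ W_O[i * d_k:(i + 1) * d_k, :] for i in range(h))
assert np.allclose(out, out_sum)
print(concat.shape, out.shape)   # (3, 4) (3, 4)
```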
Without W_O, the heads would not talk to each other inside this block: with d_k = 64, head 1's output would sit in dimensions 0–63, head 2's in 64–127, and so on, and the block's output would just be those slices side by side until some later linear layer mixed them. The W_O projection lets the model learn, within the block itself, which heads' outputs to amplify, which to suppress, and how to combine them across dimensions. It's a single (d_model × d_model) matrix doing a quietly important job.
The full formula. Putting all four sections together, the entire multi-head attention block is:
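MultiHead(X) = Concat(head_1, ..., head_h) · W_O
where head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i
and Q_i = X · W_Q^i, K_i = X · W_K^i, V_i = X · W_V^i
Here X is the (n × d_model) input embedding matrix (in self-attention the same X feeds all three projections), and W_Q^i is the i-th (d_model × d_k) column block of W_Q, so Q_i is exactly the i-th column slice of Q = X · W_Q. The same holds for K and V.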
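Putting the pieces together, a minimal NumPy sketch of a full self-attention block in which one input X (n × d_model) feeds all three projections. It assumes no bias terms, no masking, and no dropout, all of which real implementations commonly add; the function name is illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    n, d_model = X.shape
    d_k = d_model // h

    def split(M):  # (n, d_model) -> (h, n, d_k)
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    heads = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)) @ V   # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)          # (n, d_model), head i in columns i*d_k:(i+1)*d_k
    return concat @ W_O                                            # mix heads

rng = np.random.default_rng(6)
n, d_model, h = 3, 4, 2
X = rng.standard_normal((n, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O, h)
print(out.shape)   # (3, 4): same shape as X
```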
Four weight matrices per layer: W_Q, W_K, W_V, W_O. Each is (d_model × d_model). That's 4 · d_model² parameters per attention layer, ignoring biases. For GPT-2 small, with d_model = 768, that's about 2.4M parameters per layer's attention. Stack 12 layers and attention alone is ~28M of the model's ~124M total. The feedforward MLPs that surround attention have even more, but that's the next primer.
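The parameter arithmetic, spelled out. Biases are ignored, and the feedforward line assumes GPT-2 small's hidden width of 4 · d_model = 3072.

```python
d_model, n_layers = 768, 12                          # GPT-2 small
attn_per_layer = 4 * d_model * d_model               # W_Q, W_K, W_V, W_O, each d_model x d_model
attn_total = n_layers * attn_per_layer
ffn_total = n_layers * 2 * d_model * (4 * d_model)   # 768 x 3072 up-projection + 3072 x 768 down-projection
print(attn_per_layer, attn_total, ffn_total)         # 2359296 28311552 56623104
```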
Modern variants. Vanilla multi-head attention as described here is the textbook setup. Real deployment introduces optimizations that save memory and inference cost:
- Multi-Query Attention (MQA). All heads share a single K and V — only Q is per-head. Shrinks the KV cache (the dominant memory cost for inference) by a factor of h. Used in PaLM.
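- Grouped-Query Attention (GQA). A compromise: heads are partitioned into g groups, and each group shares one K and one V. With g = 1 it reduces to MQA; with g = h it is ordinary multi-head attention. Many modern open-source LLMs use GQA (for example Llama 2 70B, the Llama 3 family, and Mistral 7B): it's almost as fast as MQA but loses very little quality vs full multi-head.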
- FlashAttention. Not a change to the math: a kernel-level rewrite that fuses the score matrix, scaling, softmax, and weighted sum into one streaming computation, avoiding writing the full n × n attention matrix to memory. Same result up to floating-point rounding, dramatically less memory traffic. Standard in most modern training stacks.
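A sketch of grouped-query attention in NumPy, assuming per-head queries of shape (h, n, d_k) and g shared key/value groups with g dividing h. Consecutive query heads share a group; the helper name and sizes are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Qh, Kg, Vg):
    """Qh: (h, n, d_k) per-head queries. Kg, Vg: (g, n, d_k) shared keys/values.
    g == h is standard multi-head attention; g == 1 is multi-query attention."""
    h, g = Qh.shape[0], Kg.shape[0]
    assert h % g == 0, "g must divide h"
    # Query head i uses group i // (h // g): repeat each group h // g times along the head axis.
    K = np.repeat(Kg, h // g, axis=0)
    V = np.repeat(Vg, h // g, axis=0)
    d_k = Qh.shape[-1]
    return softmax(Qh @ K.transpose(0, 2, 1) / np.sqrt(d_k)) @ V   # (h, n, d_k)

rng = np.random.default_rng(5)
h, g, n, d_k = 8, 2, 6, 4
Qh = rng.standard_normal((h, n, d_k))
Kg, Vg = rng.standard_normal((g, n, d_k)), rng.standard_normal((g, n, d_k))
out = grouped_query_attention(Qh, Kg, Vg)
# A KV cache would store only Kg and Vg: g / h of what full multi-head attention stores.
print(out.shape, Kg.size / (h * n * d_k))   # (8, 6, 4) 0.25
```

np.repeat copies the shared groups purely for readability; the saving that matters is in the cache, which holds only Kg and Vg.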
Where we are. Eighteen primers in, we've now traced the entire data flow inside one attention block: how tokens become embeddings, how embeddings become Q/K/V, how Q/K/V split across heads, how each head produces an attention pattern, and how the heads combine back into one vector per token. What's left for the next primer: the residual connections, layer norm, and feedforward network that wrap attention into a full Transformer block. The hard part is now behind us.