多头注意力 入门
一个 attention 头只抓住 token 之间一种关系。真实语言有很多种 —— 主谓一致、代词指代、局部位置、语义相似。多头注意力(multi-head attention)把 h个 attention 并行算,各自在自己的子空间里跑,最后把结果混在一起。 4 个短主题:为啥要多头;Q/K/V 怎么拆成h 个头;每个头并行做 attention; 最后用 concat + 投影 W_O 合起来。
为啥要「多头」?
一个头只看一种关系。真实句子里同时有好多种。
在自注意力 primer 里,我们把一个 attention 头从头到尾走完了一遍。能工作。 在「the cat sat」上,cat 这一行把 ~67% 的权重放给了 sat —— 懂动词的名词正该这样。到这都很好。但句子里的关系不止一种。
- cat 想找自己的动词(sat)。
- cat 紧前面还有一个限定词(the)—— 一种位置关系。
- sat 有主语(cat),后面可能还有补语。
- 长距代词消解:「she」可能要往回看 50 个 token 才找到 「Mary」。
- 语义聚类:跟食物相关的词、跟情绪相关的词、跟时间相关的词。
一个 attention 头只能给每个 token 输出一组 softmax 权重。 分数矩阵的同一行要同时干所有的关系追踪。只有一个头时,你被逼着挑: 这个 token 是看自己的动词、看自己的限定词、还是看自己的指代词? 没法干净地同时做不止一种。
解法。并行跑 h 个 attention。每个有自己的 (W_Q、W_K、W_V) 投影矩阵 —— 每个住在自己的子空间里,可以专攻自己的模式。 一个头最后是句法头、另一个是局部位置头、再一个是长距指代头。 谁专攻什么,模型在训练里自己琢磨出来,不需要显式监督。
原始 Transformer 用 h = 8。GPT-3 用 h = 96。 最大的那些模型每层用 64–128 个头。每个头都很小(通常 d_k = d_model / h = 64),所以总算力跟一个大头差不多 —— 但同时编码多种关系的容量高得多。
解读上的一个注意。「第 i 个头是句法头」这种说法 是个有用的简化,不是字面真理。实际中模型把同一个模式散到多个头上, 头在训练中竞争、特化,很多头其实部分冗余(剪掉模型几乎不变)。 但「多个头抓不同关系」这个心智模型, 作为「为啥要多个头」的直觉是对的。
拆:d_model → h 个 d_k 的头
特征维度切成 h 份。每个头一份子空间,token 数都一样。
多头注意力最朴素的实现也最好画。把上一篇 primer 的 Q、K、V 拿过来 —— 每个是形状 (n × d_model):n 是序列长度,d_model 是输入 embedding 维度。按列切成 h 等份。 每一份是形状 (n × d_k),其中 d_k = d_model / h。
具体到小例子:d_model = 4、h = 2,所以 d_k = 2。原本 3 × 4 的 Q 矩阵变成两个 3 × 2 的矩阵,叫Q₁ 和 Q₂。K 和 V 用一样的方式拆 —— 我们最终拿到 6 个 每头一个的矩阵,而不是 3 个满宽的。
一个不那么显然的点。每个头只看到特征维度的自己那份, 但它仍然看到所有 token。切是沿着特征轴切的,从来不沿着序列轴切。 所以 head 1 不是「负责处理句子前半的那个头」。所有头看所有 token —— 区别只是每个头能用哪些 embedding 维度。
实现上的小注释。实际上没人真去把 Q 算满宽后再切。 W_Q 本身是 (d_model × d_model) 大小,但解读成 h 个 独立的 (d_model × d_k) 块。前向通常把 (n × d_model)用一次 view() 或 reshape() 变成 (n × h × d_k),再换轴成 (h × n × d_k), 把 h 当作批量维去跑后面几步。数学上跟「切开各算各的」一致, 但 GPU 能一次批量矩阵乘搞定。
为啥等分?因为输入 embedding 还没有区分。 训练开始之前,模型没有任何理由让 head 1 比 head 2 大。 等分是对称先验 —— 训练里每个头会在自己那块里自己找到自己的特化。
并行:每个头各做各的 attention
同一个操作,做 h 次。每个头,各自的子空间、各自的模式。
Q、K、V 拆完之后,每个头 i 各拿到自己的 (Q_i, K_i, V_i)。 每个头各自做完全一样的缩放点积注意力 —— 跟自注意力 primer 最后写下的那个公式 一模一样,只是在自己那块上做:
head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i
每个头三次矩阵乘加一次 softmax,跟之前一样。唯一变化是我们做 h次 —— 一头一次 —— 而且这一步里头之间互不交互。没有共享权重、 没有共享分数矩阵、没有信息交叉。它们在同样的输入 token 上,各算各的。
因为头之间独立,所以它们能 —— 在 GPU 上实际上 —— 并行跑。实现里把 Q、K、V 重整,让头维度当作批量维; 同一个缩放点积注意力 kernel 一次融合调用就处理掉所有 h 个头。 在同等总 d_model 下,多头注意力跟单头注意力代价差不多 —— 你白白多得了一份表达力。
每个头实际学到了什么?这个被研究得很透。简短版本: 训得好的 Transformer 里,你可以探一探每个头,会发现很多专攻 相当好解读的模式:
- 位置头。有些头纯粹关注上一个 token、下一个 token, 或者句首。纯位置、不看内容。
- 句法头。有些从动词指向主语、代词指向先行词、 从句指向中心名词。
- 语义头。有些强烈关注主题相似的 token —— 「食物」类的找到其他食物类,即使隔很远。
- 归纳头(induction heads)。一个著名类别: 识别重复模式、按类比复制 token。(Anthropic 的可解释性研究表明 这是上下文学习的核心。)
再次重提 §1 的注意。这些特化是涌现的、不是设计出来的。 没人告诉模型「第 5 个头应该当位置头」—— 训练自己走到了那里。 很多头是冗余的;著名的 Voita 等人结果是:你可以剪掉训好的 Transformer 差不多一半的头,质量几乎不变。多头结构是一种有用的归纳偏置 —— 给模型学到多种模式的容量 —— 而不是某种固定的「每头一个含义」。
合:concat 完再乘 W_O
把各个头拼回去,再用最后一个学过的矩阵把它们混在一起。
并行 attention 那一步算完,我们有 h 个输出矩阵,每个头一个, 各自形状 (n × d_k)。为了把它们传给 Transformer 下一层, 我们需要把它们叠回一个 (n × d_model) 的张量,跟输入同形。 两步就够。
- Concatenate(拼接)。把 h 个头的输出 沿着特征轴并排叠起来。h 个 (n × d_k) 变成一个 (n × h · d_k) = (n × d_model)。形状跟 Q/K/V/输入 embedding 一样了 —— 好。但维度只是叠起来,没混。head 1 的输出占据前 d_k个维度;head 2 占接下来 d_k;依此类推。
- 乘 W_O 投影。一个 (d_model × d_model) 的学过的 权重矩阵。把拼接结果乘 W_O,每个输出维度都可以是每个头贡献的混合。这一步才是各个头终于「对话」的地方。
没有 W_O,各个头永远互不交流 —— head 1 的信息永远住在维度 0–63、 head 2 的住在 64–127,层接一层。W_O 投影让模型学到应该放大哪些头、压制哪些头、 以及怎么跨维度组合。是个小矩阵(d_model² 个参数)在悄悄干一件重要的活。
完整公式。把 4 节合起来,整个多头注意力块是:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O where head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i and Q_i = Q · W_Q^i, K_i = K · W_K^i, V_i = V · W_V^i
每层 4 个权重矩阵:W_Q、W_K、W_V、W_O。每个都是 (d_model × d_model)。 每层的 attention 一共 4 · d_model² 个参数 —— GPT-2 small 里 d_model = 768,每层 attention 大约 240 万参数。叠 12 层, 光 attention 就占了模型 1.17 亿参数里的约 2800 万。围着 attention 的 前馈 MLP 比这更多 —— 但那是下一篇 primer 的事了。
现代变种。这里讲的「原味」多头注意力是教科书版。 真部署里有几种优化省显存、省推理代价:
- Multi-Query Attention(MQA)。所有头共享一份 K 和 V —— 只有 Q 是每头一份。把 KV 缓存(推理时最大头的显存代价) 缩小到原来的 1/h。PaLM 用的就是这个。
- Grouped-Query Attention(GQA)。一种折中: 头被分成 g 组,每组共享一份 K 和 V。 现代绝大多数开源 LLM(Llama 2/3、Mistral)用 GQA —— 几乎跟 MQA 一样快, 但相比完整多头质量几乎不掉。
- FlashAttention。不是数学上的改动 —— 是 kernel 级重写, 把分数矩阵、缩放、softmax、加权和融成一个流式计算,避免把完整 n × n的 attention 矩阵写到显存。结果一字不差,但显存带宽用量大幅下降。 每个现代训练栈的标配。
走到这里。18 篇 primer 下来,我们已经追完了一个 attention 块里数据的全过程:token 怎么变成 embedding,embedding 怎么变成 Q/K/V, Q/K/V 怎么跨头切分,每个头怎么产出一个 attention 模式, 头怎么合回每个 token 一个向量。下一篇 primer 还剩什么:把 attention 包成一个完整 Transformer 块的残差连接、layer norm、前馈网络。 难的部分已经在身后了。