自注意力 入门

Transformer 里的每个 token 都要看看其他所有 token,自己决定一句话:对我此刻而言,哪几个最关键?这个决定就是自注意力(self-attention):现代每一个 LLM 的核心操作。 4 个短主题:Query / Key / Value 心智模型;分数矩阵 Q · Kᵀ;带着那个著名 √d 除数和 softmax 的缩放点积注意力(scaled dot-product attention); 以及最终的 V 的加权和 —— 在一个小例子上从头走到尾。

01

Query、Key、Value

从一个 embedding 投影出三个 —— 每个 token 在 attention 里能扮演的三种角色。

每个 token 进入 Transformer 一层时,只是一个向量 —— embedding 加上位置信息。 为了做 attention,这一个向量被三次线性投影,扇出成三个独立的向量:Query(Q)、Key(K)、Value(V)。 名字借自信息检索,这个类比也意外地贴切。

  • Query(Q)。这个 token 在找什么。可以想成它发出去的一个搜索请求: 「我是个动词 —— 周围有哪个名词想当我的主语?」
  • Key(K)。这个 token 对外宣告了什么—— 一种像标签一样的自我总结, 别的 token 拿这个来判断它跟自己相不相关。想成公告板上的一张卡片: 「我是名词、第三人称、有生命。」
  • Value(V)。如果匹配成功,这个 token 贡献的内容。 Q 配上了 K,发起请求的那一方就把这一份 V 混进自己的表示里。
从一个 embedding —— 三个投影embeddings —— 每个 token 一行embeddingQueryKeyValuethe0.20.70.10.40.30.60.10.20.50.20.40.10.10.40.20.5cat0.80.30.50.20.70.40.50.30.60.70.30.20.50.60.30.1sat0.10.60.90.40.20.50.80.40.40.30.70.60.70.20.40.83 个 token,每个一行 4 维 embedding。
1 / 4
每个 token:x · W_Q → Query(我在找什么);x · W_K → Key(我对外提供什么);x · W_V → Value(匹配上我会贡献什么)。

三个投影是把输入 embedding x 乘以三个学到的权重矩阵得到的:Q = x · W_QK = x · W_KV = x · W_V。attention 学到的所有东西都在这三个权重矩阵里 —— 梯度下降会把它们塑造成「每个 token 找什么」和「每个 token 提供什么」 在语言建模任务下对得最齐的样子。

为啥要把一个 embedding 拆成三个视图?因为这三个角色本质上不一样。 一个动词想在自己主语身上看什么,跟它想对自己依赖的下游 宣告什么,跟它实际带给下游的信息,是三件不同的事。如果硬塞进一个向量, 模型就得让 embedding 空间里同一个方向同时干三件活。三个矩阵 = 三件活、各干各的。

一个关于维度的细节。原始 Transformer 里,输入 embedding 是 512 维,Q、K、V 也都是 512 维。 多头注意力(实际用的形式,后面 Transformer 那篇细讲)里,Q/K/V 被切成 h 个头、 每个头 d_k = d_model / h 维,各算各的 attention、互不打扰。 本篇 primer 只看单头注意力,用一个很小的 d_k = 4,矩阵小了才看得清。

本篇的小例子。全篇都用同一个玩具句子:「the cat sat」。3 个 token,Q/K/V 每个向量 4 维。每个矩阵 12 个数字, 小到能写在信封背面,又大到能把 attention 从头到尾每一步都演完。

02

分数矩阵:Q · Kᵀ

每个 Query 都遇到每个 Key。点积衡量它们配得有多好。

魔法步骤来了。对每个 token i,我们想要一个数,说出 「i 想看每个 j 的力度有多大」。线性代数里衡量两个向量相似度 最朴素的办法就拿来用 —— 点积

S[i, j] = Q[i] · K[j]。把 token i 的 Query 行、 token j 的 Key 行,逐元素相乘、相加。这一个数就是注意力分数(attention score)。又正又大 = 匹配得好;接近 0 = 不匹配; 是负的 = 不合。

分数矩阵 —— S = Q · Kᵀ,在 "the cat sat" 上Q(3×4)与 Kᵀ(4×3)—— 两个操作数Q (3×4)thecatsat110111110111·Kᵀ (4×3)thecatsat110111011011左:Q,每个 token 一行。右:Kᵀ,每个 token 一列。
1 / 4
每个 token 的 Query 和每个 token 的 Key 做点积。3×3 的 S[i,j] = token i 想看 token j 的强度。

对每一对 (i, j) 都算一遍,得到一个 n × n 矩阵 ——分数矩阵 S。写成矩阵形式就是漂亮紧凑的S = Q · Kᵀ:Qn × d_k(query 一行一行堆),Kᵀd_k × n(key 一列一列堆),乘积是 n × n

在小例子上,3 个 token,分数矩阵 3 × 3。第 i 行就是 「token i 在看什么」。demo 里的 「cat」行是[the=2, cat=4, sat=6]:cat 对 「the」一般、对自己还行, 但感兴趣的是「sat」。一个健康的语言模型本就该这样 —— 懂动词的名词应该强烈关注它的动词。

为啥用点积?几个理由。一,它是最便宜的有意义的相似度度量: 每个维度一次乘加。二,可微,梯度流得顺。三,Q 和 K 各自经过学过的矩阵变换, 模型实际上可以用任意它想要的相似度函数 —— 在投影后空间里的点积灵活得很。 四,点积归到矩阵乘法,GPU 把矩阵乘当早饭吃。

代价。这就是「attention 是二次方」里那个有名的 O(n²)。 长度 n = 1000 时,分数矩阵一百万个元素。n = 32,000(一个长上下文)时就是十亿。显存和算力都按 走。 高效 attention 里所有的招数 —— flash attention、稀疏 attention、sliding window —— 都是为了把这个矩阵实例化出来。

03

缩放与 Softmax:从分数到权重

原始分数可以任意大小。两个简单操作把它变成概率分布。

分数矩阵告诉我们,每个 token 想看每个 token 的力度有多大, 但这些数没有上下界 —— 可能是 3300、也可能是 −5。 要把它当加权和的权重用,我们希望它正、希望它和为 1。两步就够。

  • 除以 √d_k 缩放。把每个分数都除以 key 维度的开方。d_k = 4 时除以 2,d_k = 64(真 Transformer 头的尺寸)时除以 8。 为什么?不缩放的话,高维下点积的方差会随 d_k 增长,把分数推到 softmax 饱和的区域,梯度就消失了。除以 √d_k 让方差稳住。小补丁,大稳定性。
  • Softmax。就是概率统计 primer 里的那个 softmax。每一行独立做:exp(x_i) / Σ exp(x_j)。一行的输出就是一个概率分布 —— 非负、和为 1。
缩放点积注意力 —— "cat" 这一行原始分数 —— [2, 4, 6]2the4cat6sat分数矩阵的 "cat" 行:[2, 4, 6]。还不是权重。
1 / 4
先除以 √d_k 让原始分数别太大(免得 softmax 饱和),然后 softmax 把它变成 token 上的概率分布。

合起来:A = softmax(Q · Kᵀ / √d_k)。这一行就是缩放点积注意力的 全部公式,只差最后一步(加权和,下一节讲)。A注意力矩阵—— 每个 token 一行权重,告诉它怎么把 value 混在一起。

在小例子上,「cat」这一行大约是[the=0.09, cat=0.24, sat=0.67]。当百分比读:cat 把67% 的注意力放在 sat 上、24% 放在自己身上、9% 放在 the 上。原本 2:4:6 的分数比,经 softmax 锐化成更果决的 9:24:67,因为 exp() 对最大那一项放大得格外猛。

再加一道:mask。做语言建模的时候,一个 token 不应该能看到未来的 token —— 训练时那是作弊。技巧:在 softmax 之前, 把禁止的位置加上一个很大的负数(常用 −∞)。exp(−∞) = 0, 那些位置的权重就严格为 0。这就是因果 mask(causal mask), 也是 encoder(没 mask,每个 token 能看到整句 —— BERT 用) 和 decoder(有因果 mask,每个 token 只看过去 —— GPT 用)的区别。

数值稳定。实际实现不会直接算 exp(x) / Σ exp(x) —— 分数大就溢出了。而是先减掉这一行的最大值:softmax(x) = softmax(x − max(x)), 数学上一模一样、但留在浮点能表示的范围内。每个库都这么干,平时几乎不用想它。

04

V 的加权和

拿注意力权重,拿 Value 行,混。输出是一个带上下文的 token。

终于齐了。对每个 token i,注意力机制给出了一行和为 1 的权重, 告诉 i 该把其他每个 token 各混多少进来。最后这一步, 就是真正去做这件「混」—— 混的是 Value 向量。

output[i] = Σ A[i, j] · V[j]。「cat」这一行就是0.09 · V[the] + 0.24 · V[cat] + 0.67 · V[sat]。结果是一个新的 4 维向量 —— 形状跟原来的 V 行一样,但按注意力把所有 value 混在了一起。

V 的加权和 —— "cat" 的输出权重(来自 softmax)与 V 矩阵the0.100.400.200.50× 0.09cat0.500.600.300.10× 0.24sat0.700.200.400.80× 0.67§3 的 softmax 权重:[0.09, 0.24, 0.67]。V 是 3 行 × 4 维。
1 / 4
每个 token 的 V 行先乘上自己的 softmax 权重再相加。结果是一个带上下文的向量 —— 「cat」这下携带了「sat」的信息。

写成矩阵形式,整个自注意力操作就一行:

Attention(Q, K, V) = softmax( Q · Kᵀ / √d_k ) · V

就这。全部就这一行。Q、K、V 各自是 n × d_k。 分数矩阵 Q · Kᵀ 是 n × n。缩放 + softmax 之后得到 n × n的注意力矩阵。再乘 V(n × d_v),输出又是 n × d_v: 每个 token 一个新表示。三次矩阵乘 + 一次 softmax,就这些。 这一个公式,加上多头、残差连接、layer norm、外面包一层 MLP,就是 Transformer。

「cat」刚才发生了什么。注意力之前,「cat」的表示 只来自「cat」这个词的 embedding —— 一个静态身份。注意力之后, 它的表示约等于 0.67 · V[sat],再加上自己和「the」的小贡献。 换句话说:「cat」现在知道动词是「sat」,并把这条信息往后传。 人们说注意力给你带上下文的表示,就是这个意思。 词嵌入 primer 末尾那个静态 embedding 的限制?这下解决了。

每个 token 同时进行。例子走的是「cat」,但分数矩阵的每一行、 注意力矩阵的每一行,都是各自独立的。n 个 token 的输出是并行算出来的 —— 把这件事写成矩阵乘的全部意义就在这。RNN 天生是顺序的(t 时刻的隐藏状态要等 t−1)。 注意力天生是并行的(所有位置一把算)。Transformer 在 GPU 上能训那么快, 全部理由就在这。

我们走到哪了。一层自注意力,接收 n 个输入向量, 吐出 n 个输出向量,长度不变,每个都被序列里其他部分上下文化。 把 12 层、96 层、175 层(中间夹小前馈网络)叠起来,你就得到了 GPT-2、GPT-3、GPT-4。 Transformer 一句话讲完。下一篇主菜会补上「小前馈网络」、残差脚手架、多头结构、 layer norm —— 但所有这一切的核心操作,就是我们刚刚走完的这个。