自己注意 入門
Transformer の各トークンは、他のすべてのトークンを眺めて、今この瞬間、 自分にとって重要なのはどれかを決める。その決定が自己注意(self-attention) —— 現代の LLM の中核となる 操作だ。短い 4 トピック:Query / Key / Value のメンタルモデル;スコア行列 Q · Kᵀ;あの有名な √d 除数と softmax を伴うスケール付き内積注意(scaled dot-product attention); そして最終的な V の加重和 —— 小例で端から端まで歩く。
Query、Key、Value
1 つの埋め込みから 3 つの射影 —— 各トークンが注意で果たす 3 つの異なる役割。
各トークンは Transformer の層に 1 つのベクトル(埋め込み + 位置情報)として入る。 注意を行うために、その 1 つのベクトルは 3 つの学習済み線形射影で3 つのベクトルに展開される:Query(Q)、Key(K)、Value(V)。名前は情報検索からの借用で、 類比は驚くほどよく当てはまる。
- Query(Q)。このトークンが探しているもの。 発する検索クエリのようなもの:「私は動詞 —— 近くに主語になりたい名詞は?」
- Key(K)。このトークンが自分について公告するもの —— 他のトークンが関連性判断に使う、タグのような要約。 掲示板のカード:「私は名詞、三人称、有情物」。
- Value(V)。マッチが成立した場合に渡す内容。 Q が K と合えば、要求側のトークンの表現に V が混ぜ込まれる。
3 つの射影は、入力埋め込み x に学習済みの 3 つの重み行列を掛けて作る:Q = x · W_Q、K = x · W_K、V = x · W_V。注意で学習されるすべてはこの 3 つの重み行列に宿る —— 勾配降下が「各トークンが探すもの」と「各トークンが提供するもの」を 言語モデリングに役立つ形に揃えていく。
なぜ 1 つの埋め込みを 3 視点に分けるのか? 3 つの役割は根本的に異なるからだ。 動詞が主語に探したいもの、動詞が下流の依存先に公告したいもの、 動詞が実際に下流に運ぶ情報、これらは別物。1 ベクトルに束ねれば、 埋め込み空間の同じ方向で 3 仕事を同時にやらせることになる。 3 行列 = 3 仕事、きれいに分離。
次元についての細部。元の Transformer では入力埋め込みは 512 次元、 Q・K・V もすべて 512 次元。実用される多頭注意では(後の Transformer primer で詳述)、 Q/K/V は h ヘッドに分割され各ヘッド d_k = d_model / h 次元、 並列に各自の注意を計算する。この primer では単頭注意のみ、d_k = 4という非常に小さい設定で、行列を読みやすく保つ。
本 primer の走る例。全編で同じ玩具文を使う:「the cat sat」。3 トークン、Q/K/V 各 4 次元。 各行列 12 数字、封筒の裏に書ける小ささ、注意の全工程を端から端まで 見せるに足る大きさ。
スコア行列:Q · Kᵀ
各 Query は各 Key と出会う。内積は両者の適合度を測る。
魔法のステップ。各トークン i に対して、「i が他のトークン j をどれだけ強く見たいか」という数値が欲しい。 線形代数で 2 ベクトルの類似度を測る最も素朴な方法を使う —— 内積。
S[i, j] = Q[i] · K[j]。トークン i の Query 行と トークン j の Key 行を要素ごとに掛けて足す。 この 1 数値が注意スコア(attention score)。 大きく正 = よい一致、ほぼ 0 = 一致なし、負 = 不一致。
全ての (i, j) ペアで計算すれば n × n 行列 ——スコア行列 S が出来上がる。行列形式では 簡潔に S = Q · Kᵀ:Q は n × d_k(query を行に積む)、Kᵀ は d_k × n(key を列に積む)、 積は n × n。
玩具例では 3 トークン、スコア行列は 3 × 3。i 行目は 「トークン i が見ているもの」。demo の 「cat」行は[the=2, cat=4, sat=6]:cat は「the」にやや、自分にもまずまず、 だが最も「sat」に関心がある。健全な言語モデルが出すべき結果 —— 動詞を意識する名詞は自分の動詞に強く注意するはず。
なぜ内積?いくつか理由がある。第一に、意味のある類似度尺度として 最も安い:次元ごとに 1 回の積和。第二に、微分可能で勾配が綺麗に流れる。 第三に、Q と K を学習行列で回転させることで、モデルは任意の類似度関数を 実質的に使える —— 射影空間での内積は驚くほど柔軟。 第四に、内積は行列積に帰着し、GPU は行列積を朝飯として食べる。
コスト。「注意は二次的」と言われる O(n²)。 長さ n = 1000 ならスコア行列は 100 万要素。n = 32,000(長文脈)なら 10 億。メモリと計算は n² で増える。 効率的注意の工夫すべて —— flash attention、疎注意、スライディング ウィンドウ —— はこの行列を全部実体化しないためのもの。
スケール化と Softmax:スコアから重みへ
生スコアは任意の大きさ。2 つの簡単な操作で確率分布に変える。
スコア行列は、各トークンが各トークンを見たい強度を教えてくれるが、 数値に上下界はない —— 3 も 300 も −5 もある。 加重和の重みとして使うには、正でかつ和が 1 であってほしい。2 操作で済む。
- √d_k でスケール化。各スコアを key 次元の平方根で割る。d_k = 4 なら 2 で、d_k = 64(実 Transformer ヘッドのサイズ) なら 8 で割る。なぜ? しないと高次元では内積の分散が d_k と共に増え、 スコアが softmax の飽和域に押し込まれ、勾配が消える。 √d_k で割れば分散を一定に保てる。小修正、大安定性。
- Softmax。確率統計 primer の softmax と同じ。各行独立に:exp(x_i) / Σ exp(x_j)。1 行の出力は確率分布 —— 非負・和 1。
合わせて:A = softmax(Q · Kᵀ / √d_k)。 この 1 行がスケール付き内積注意の全公式 —— 残るは最後の 1 ステップ (加重和、次節)。A が注意行列 —— 各トークンに 1 行の重み、value をどう混ぜるかの指示。
走る例では、「cat」行はおよそ[the=0.09, cat=0.24, sat=0.67]。パーセントとして読む: cat は注意の 67% を sat に、24% を自分に、9% を the に置く。元の 2:4:6 のスコア比が softmax で 9:24:67 に鋭化。exp() は最大値を不釣り合いに増幅するから。
もう一つの仕掛け:mask。言語モデリングでは、 トークンは未来のトークンに注意できない —— 訓練中はカンニングになる。 コツ:softmax の前に、禁止する位置に大きな負数(よく −∞)を加える。 exp(−∞) = 0、それらの位置の重みは厳密にゼロ。これが因果マスク(causal mask) —— エンコーダ(マスクなし、各トークンが全列を見る —— BERT が採用)と デコーダ(因果マスク、各トークンは過去のみ —— GPT が採用)の違い。
数値安定性。実装は exp(x) / Σ exp(x) をそのまま 計算しない —— 大スコアでオーバーフローする。代わりに行の最大値を引く:softmax(x) = softmax(x − max(x))、数学的に同じだが浮動小数点の有限範囲内に 留まる。どのライブラリもやっているので、普段ほぼ意識しなくていい。
V の加重和
注意の重み、Value 行、混ぜる。出力は文脈化されたトークン。
ついに全部揃った。各トークン i に対し、注意は合計 1 の重み行を作り出した。 他の各トークンをどれだけ混ぜるかの指示だ。最後の一歩は実際に混ぜること —— 混ぜるのは Value ベクトルだ。
output[i] = Σ A[i, j] · V[j]。「cat」行なら0.09 · V[the] + 0.24 · V[cat] + 0.67 · V[sat]。 結果は新しい 4 次元ベクトル —— 元の V 行と同じ形だが、 全 value を注意通りに混ぜたもの。
行列形式で書くと、自己注意の全操作はこの 1 行:
Attention(Q, K, V) = softmax( Q · Kᵀ / √d_k ) · V
以上。全部これ。Q、K、V それぞれ n × d_k。 スコア行列 Q · Kᵀ は n × n。スケール化 + softmax で n × nの注意行列。V(n × d_v)を掛けると出力は n × d_v: 各トークンの新しい表現。3 行列積 + 1 softmax。 この公式 1 つに多頭、残差接続、layer norm、外側の MLP を加えれば Transformer。
「cat」に何が起きたか。注意の前、「cat」の表現は 「cat」という単語の埋め込みのみに基づく —— 静的なアイデンティティ。 注意の後、その表現はおよそ 0.67 · V[sat] に自分と「the」の小貢献を加えたもの。 言い換えれば、「cat」は今や動詞が「sat」だと知っている。 この情報を次の層へ運ぶ。注意が文脈付き表現を与える、 とは正にこのこと。埋め込み primer 末尾の静的埋め込みの限界? 解決済み。
全トークン同時進行。例は「cat」を歩いたが、 スコア行列の各行・注意行列の各行は独立だ。出力は n トークン全てで 並列に計算される —— これを行列積で書く意味そのもの。RNN は本質的に逐次 (時刻 t は t − 1 の隠れ状態に依存)。 注意は本質的に並列(全位置を一度に計算)。 Transformer が GPU で速く訓練できる理由はこれ。
たどり着いた地点。1 層の自己注意は n 個の入力ベクトル を受け、n 個の出力ベクトル(長さ不変、各々が他の部分で文脈化)を返す。 12、96、175 層(間に小さい前向きネットワークを挟む)積めば GPT-2、GPT-3、GPT-4。 Transformer を 1 文で言えばこれ。次の primer、主菜は、 「小さい前向きネットワーク」、残差の足場、多頭構造、layer norm を埋める —— だが、その全ての中核となる操作は、いま歩いたこれだ。