多頭注意 入門
1 つの注意ヘッドはトークン間の一種類の関係しか捉えない。 実際の言語には多くの関係がある —— 主述の一致、代名詞-先行詞、局所位置、意味類似。多頭注意(multi-head attention)は h 個の注意計算を並列に、 各々の部分空間で実行し、結果を混ぜる。短い 4 トピック:なぜ多頭;Q/K/V を h ヘッドに分割; 各ヘッドの並列注意計算; 最後に concat + 射影 W_O で結合。
なぜ「多頭」?
1 ヘッドは 1 種類の関係しか見ない。実文には同時に複数ある。
自己注意 primer では、1 つの注意ヘッドを端から端まで走り抜けた。動いた。 「the cat sat」で cat 行は重みの約 67% を sat に置いた —— 動詞を意識する名詞のあるべき姿。ここまでは良し。だが文の関係は 1 種類ではない。
- cat は自分の動詞(sat)を探したい。
- cat の直前には限定詞(the)もある —— 位置関係。
- sat には主語(cat)、後ろに補語もあり得る。
- 長距離の代名詞解消:「she」は 50 トークン遡って 「Mary」を見つける必要があるかもしれない。
- 意味的クラスタリング:食、感情、時間に関する語。
1 つの注意ヘッドは各トークンに 1 組の softmax 重みしか出せない。 スコア行列の同一行が全ての関係追跡を同時にやらされる。 1 ヘッドでは選ばざるを得ない:このトークンは動詞、限定詞、共参照、どれを見るか。 2 つ以上を綺麗にはできない。
解決。注意計算を h 個並列に。各々が独自の (W_Q、W_K、W_V) 射影行列を持つ —— 各々が独自の部分空間に住み、 独自のパターンに特化できる。あるヘッドは統語ヘッド、別のは局所位置ヘッド、 さらに別のは長距離照応ヘッド。誰が何に特化するかは訓練でモデルが 勝手に決める。明示的な監督は要らない。
元の Transformer は h = 8。GPT-3 は h = 96。 最大級のモデルは層あたり 64–128 ヘッド。各ヘッドは小さい(よく d_k = d_model / h = 64)、総計算量は 1 大ヘッドとほぼ同じ —— だが多種の関係を同時に符号化する容量は遥かに高い。
解釈上の注意。「ヘッド i は統語ヘッド」は便利な簡略化、 字義通りの主張ではない。実際、モデルは同一パターンを複数ヘッドに分散させ、 ヘッド同士が訓練中に競合・特化し、多くのヘッドは部分的に冗長 (剪定してもモデルはほぼ変わらない)。それでも「ヘッドは異なる関係型を担う」 というメンタルモデルは「なぜ多ヘッドが必要か」の直感として正しい。
分割:d_model → h ヘッドの d_k
特徴次元を h に切る。各ヘッドは自分の部分空間、トークン数は同じ。
多頭注意の最も素朴な実装は、最も描きやすい。前 primer の Q、K、V を取る —— 各々の形状は (n × d_model):n は系列長、d_model は入力埋め込み次元。列方向に h 等分に切る。 各部分は (n × d_k)、d_k = d_model / h。
走る例で具体的に:d_model = 4、h = 2、よって d_k = 2。3 × 4 だった Q 行列が 2 つの 3 × 2 行列、すなわちQ₁ と Q₂ になる。K と V も同じ分割 —— 全幅 3 行列ではなく、 ヘッド毎 6 行列を持つことになる。
少し非自明な点。各ヘッドは特徴次元の自分の取り分しか見ないが、全トークンは見える。分割は特徴軸方向、決して系列軸方向ではない。 つまり head 1 は「文の前半を処理するヘッド」ではない。 全ヘッドが全トークンを見る —— 違いは「どの埋め込み次元を使えるか」だけ。
実装メモ。実際は誰も Q を全幅で計算してから切ったりしない。 W_Q 自体が (d_model × d_model) 形状だが、h 個の独立な (d_model × d_k) ブロックと解釈する。前向きは通常(n × d_model) を view() や reshape()で (n × h × d_k) に、その後軸入替で (h × n × d_k)、h をバッチ次元に扱う。数学的には「切って別々」と同じだが、 GPU は 1 回のバッチ行列積で済ませる。
なぜ等分割?入力埋め込みはまだ分化していないから。 訓練前にモデルが head 1 を head 2 より大きくする理由はない。 等分割は対称な事前情報 —— 訓練で各ヘッドが自身の取り分内で特化を見出す。
並列:各ヘッドが独自に注意を行う
同じ操作を h 回。各ヘッドが自分の部分空間、自分のパターン。
Q、K、V を分割した後、各ヘッド i は自分の (Q_i, K_i, V_i)を持つ。各々が、自己注意 primer の末尾で書き下したのと全く同じ スケール付き内積注意を、自分の取り分の上で実行する:
head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i
各ヘッド毎に 3 行列積 + 1 softmax、以前と全く同じ。唯一の変化は、 これを h 回 —— ヘッド毎に 1 回 —— やること、そしてこのステップで ヘッド同士は相互作用しないこと。重み共有なし、スコア行列共有なし、 相互通信なし。同じ入力トークン上で、各々が独立に計算。
ヘッドが独立だから、GPU 上で実際に並列に動く。 実装は Q、K、V を整形してヘッド次元をバッチ次元として扱い、 同じスケール付き内積注意のカーネル 1 回の融合呼び出しで hヘッド全てを処理する。同一の総 d_model 下では、 多頭注意は単頭注意とほぼ同じコスト —— 表現力をタダで上乗せした形。
各ヘッドは何を学ぶか?これは盛んに研究されている。短く言えば: よく訓練された Transformer では、ヘッド毎にプローブをかけると、 多くが驚くほど解釈可能なパターンに特化していると分かる:
- 位置ヘッド。純粋に直前トークン、直後トークン、文頭に 注意するヘッド。位置のみ、内容無視。
- 統語ヘッド。動詞から主語へ、代名詞から先行詞へ、 節から中心名詞へ。
- 意味ヘッド。話題的に類似するトークンに強く注意 —— 「食」関連トークンが、距離があっても他の食関連トークンを見つける。
- 誘導ヘッド(induction heads)。有名なクラス: 反復パターンを認識し、類比でトークンを複写する。 (Anthropic の解釈可能性研究では文脈内学習の中核とされた。)
§1 の注意の再掲。これらの特化は創発、設計ではない。 モデルに「ヘッド 5 を位置ヘッドに」と教えた者はいない —— 訓練がそこへ 辿り着いた。多くのヘッドは冗長。著名な Voita 等の結果:訓練済み Transformer のヘッドのほぼ半数を剪定しても品質はほぼ落ちない。 多頭構造は有用な帰納バイアス —— 多様なパターンを学ぶ容量を モデルに与えるもの —— であり、ヘッド毎に固定された意味があるわけではない。
結合:concat してから W_O で射影
ヘッド出力を積み直し、最後の学習済み行列で混ぜる。
並列注意のあと、h 個の出力行列、ヘッド毎 1 つ、各々 (n × d_k) 形状を持つ。これらを Transformer の次層へ渡すために、 元の入力形状 (n × d_model) の単一テンソルに畳み戻す。2 操作で済む。
- Concat(連結)。h 個のヘッド出力を特徴軸方向に 並べる。h 個の (n × d_k) が (n × h · d_k)= (n × d_model) へ。Q/K/V/入力埋め込みと同形 —— 良し。 だが次元は積まれただけ、混ざってはいない。head 1 の出力は最初の d_k 次元、head 2 は次の d_k、というふうに。
- W_O で射影。学習済みの (d_model × d_model)重み行列。連結結果に W_O を掛けると、各出力次元が各ヘッドの寄与の 混合になり得る。ここがヘッド同士がついに「会話」する場所。
W_O がなければヘッドは永遠に対話しない —— head 1 の情報は永久に 次元 0–63、head 2 は 64–127、層を重ねても同じ。 W_O 射影によりモデルはどのヘッドを増幅すべきか、抑制すべきか、 次元を跨いでどう組み合わせるかを学習する。小さな行列 (d_model² パラメータ)が静かに重要な仕事をしている。
完全な式。4 節を合わせれば、多頭注意ブロック全体は:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O where head_i = softmax( Q_i · K_iᵀ / √d_k ) · V_i and Q_i = Q · W_Q^i, K_i = K · W_K^i, V_i = V · W_V^i
層あたり 4 重み行列:W_Q、W_K、W_V、W_O。各々 (d_model × d_model)。 各層の注意で 4 · d_model² パラメータ —— GPT-2 small で d_model = 768 なら層あたり約 240 万。12 層積めば、注意だけで モデルの 1.17 億パラメータのうち約 2800 万。注意を囲む前向き MLP は さらに多い —— が、それは次の primer の話。
現代の変種。ここで述べた「素」の多頭注意は教科書設定。 実運用ではメモリと推論コストを節約する最適化が入る:
- Multi-Query Attention(MQA)。全ヘッドが 1 つの K と V を共有、Q だけがヘッド毎。KV キャッシュ(推論時の主要メモリコスト) を 1/h に縮める。PaLM が採用。
- Grouped-Query Attention(GQA)。妥協案:ヘッドを gグループに分け、各グループが K と V を共有。 現代のオープンソース LLM のほとんど(Llama 2/3、Mistral)が GQA を採用 —— MQA とほぼ同速度、完全な多頭と比べ品質ほぼ無損失。
- FlashAttention。数学的変更ではない —— カーネルレベルの 書き直しで、スコア行列、スケーリング、softmax、加重和を 1 つのストリーム 計算に融合、完全な n × n 注意行列をメモリに書かない。 結果は同一、メモリ帯域使用量は劇的に削減。あらゆる現代訓練スタックで標準。
たどり着いた地点。18 個目の primer で、私たちは注意ブロック 内のデータの全行程を追った:トークン → 埋め込み、埋め込み → Q/K/V、 Q/K/V → ヘッド分割、各ヘッド → 注意パターン、ヘッド → トークン毎 1 ベクトル。 次の primer に残るのは:注意を完全な Transformer ブロックに包む 残差接続、layer norm、前向きネットワーク。難所は既に背後にある。