Transformer Block 入門
注意は主役だが、それだけでは働かない。各 Transformer ブロックは注意を 他の 3 部品で包む —— 正規化(normalization)、残差接続(residual)、位置毎前向き MLP(position-wise FFN) —— そして N 個のブロックを積んでモデルになる。短い 4 トピック:LayerNorm と RMSNorm;残差接続と 深いスタックが学習可能になる理由;ほとんどのパラメータを担う FFN; そしてこれらが完全な ブロック へ組み上がり、最後に 完全なモデルへ積み上がる過程。
LayerNorm と RMSNorm
各トークンの特徴ベクトルを安定したスケールに保ち、訓練の爆発を防ぐ。
深いネットでは、活性化ベクトルの大きさが層ごとに大きく漂う。 逆方向の勾配も同じ。どこかでスケールを抑えないと、深いスタックは 学習不可能になる:値が爆発してオーバーフロー、または 0 に縮んで勾配が消える。 正規化(normalization)が答え。
LayerNorm。2016 年に提案、元の Transformer が採用。 各トークンの特徴ベクトル x ∈ R^d_model について、d_model 次元方向の平均と標準偏差を計算し、標準化:
μ = mean(x) // トークン毎 1 スカラー σ = std(x) // トークン毎 1 スカラー LN(x) = γ · (x − μ) / σ + β
γ と β は次元毎の学習可能なスケールとバイアス(長さ d_modelのベクトル)。鍵となる性質:これがトークン毎に独立に行われる。 トークン間で統計量は流れない —— BatchNorm との違いはここで、 可変長系列に適している理由でもある。
RMSNorm。2019 年。現代オープン LLM の既定。観察は: LayerNorm の平均減算の貢献は小さい —— 「再中心化」はモデル容量に 有用な変化をもたらさない。RMSNorm は単にそれを捨てる:
rms = sqrt( mean(x²) ) // 中心化なし RMSNorm(x) = γ · x / rms // β もなし
以上。平均減算なし、β バイアスもなし。ベクトル毎の演算数が減り、 実測で約 7–10% 高速。経験的にモデルの学習・汎化は LayerNorm とほぼ同じ。 norm は毎ブロックで使うので、ブロック毎の節約は積み重なる —— 前向き + 後向き全体で結構な節約に。
誰がどれを使うか。
- LayerNorm:元の Transformer、BERT、GPT-2、GPT-3、T5。
- RMSNorm:Llama 1、2、3。Qwen。DeepSeek。Mistral。PaLM (T5 系)。Gemma。現代オープン LLM のほぼ全て。
pre-norm と post-norm。関連だが別の問題:正規化は attention / FFN の前(pre-norm)か、残差の後(post-norm)か? 2017 年の元論文は post-norm。2020 年頃、複数の論文で pre-norm の方が 遥かに安定だと示された:勾配が残差を素直に流れ、backprop 開始直後に norm 層で潰されない。現代の decoder LLM はみな pre-norm。 最終的なレイアウトは §4 で見る。
残差接続
入力を出力に足し戻す。この 1 手で深いスタックが学習可能になる。
残差接続(または「skip connection」)は深層学習で最も素朴な発想の一つで、 最も影響の大きい一つ。ベクトル x を f(x) に写すブロックを取り、 その出力を以下に書き換える:
y = f(x) + x
変更はこの 1 行だけ。ブロックは変換を行い、その後元の入力が足し戻される。 機能的にはブロックは入力からの差分を学習する形になる —— 完全な出力を学ぶのではなく。続いて 2 つのことが起こる、いずれも重要。
1. 恒等関数が易しい既定値。もし f がすべて 0 を 出力するなら、ブロックは恒等関数:y = x。ゼロ初期化されたブロックを 積めば、入力から出力まで恒等写像になる。オプティマイザは各層に小さな有用な 差分を学習する —— 各層が完全な出力変換をゼロから学ぶ必要がない。 遥かに親切な出発点。
2. 勾配の高速道。残差 y = f(x) + x に連鎖律を 適用すると dy/dx = (df/dx) + 1。この「+ 1」が残差経路の貢献。df/dx が小さくても —— 深いスタックで消えても —— 残差経路の勾配は常に 1、上流の勾配がそのまま伝播される。 プレーンな幹は導関数の積の連鎖、残差経路は 1 の和の連鎖。
残差接続(ResNet、2015)以前は、視覚ネットは ~10 層を超えると不安定、 ~20 層を超えるとほぼ不可能だった。残差以後、ResNet-152 は何の問題もなく学習。 Transformer もこの性質を継承:GPT-3 は 96 層、Llama 70B は 80 層、 それでも学習できるのは各サブ層が残差ブロックだから。
Transformer ブロック 1 つあたり 2 残差。1 つの Transformer ブロックには2つのサブ層 —— attention と FFN —— があり、各々が 独自の残差を持つ:
x ← x + Attention( Norm(x) ) // 残差 1 x ← x + FFN( Norm(x) ) // 残差 2
変数 x は再利用される —— ブロックはその場で 2 度更新する。 残差の視点では attention と FFN は各々小さな修正を生むだけ。 モデルの骨格は恒等写像、各ブロックがその上に小さな専門化された調整を重ねる。
なぜ norm と残差の両方が必要?異なる問題を解決するから。 norm は各段階の活性化のスケールを制御する。残差は勾配の流れと既定としての恒等を守る。両者で学習可能性の骨格をなす —— どちらか単独では深いスタックは保てない。
位置毎前向き(FFN)
各トークンに独立に適用される 2 層 MLP。モデルのパラメータの大半が実は宿る場所。
attention はトークン間で情報を移動する。前向きネットワーク —— FFN(または MLP)サブ層 —— は各トークンが必要な情報を手にしてから 自分一人で行う処理。Transformer ブロックの後半、位置毎に適用: 同じ MLP が全トークンで走り、位置間で情報は流れない。
FFN(x) = W_down · σ( W_up · x )
- W_up:形状 (d_model, 4 · d_model)。 トークンベクトルをより広い「hidden」空間に射影 —— 通常モデル次元の 4 倍。 GPT-3 の d_model = 12288 なら hidden は 49152。
- σ:要素ごとの非線形。元の Transformer と GPT-2 は GELU。 Llama、Qwen、Mistral は SiLU(Swish とも)。具体的選択は 「何かしらある」ことより重要ではない。
- W_down:形状 (4 · d_model, d_model)。 下方射影で出力を入力と同形に戻す —— ブロック積み上げの必須条件。
なぜ position-wise?トークン間の作業は attention が済ませた。 FFN は各トークンの文脈付きベクトルを有用な形に変換すればよい。 位置間で同じ MLP を共有することで、FFN は 1 つの固定関数を学んで全位置に適用 —— 安く、汎化が良く、任意長の系列で動く。
なぜ 4 倍?2017 年からの経験的なスイートスポット。直感: attention の表現力は d_model に縛られる、FFN は attention にできない トークン毎の計算を行うのに十分な余地が必要。4 倍は大半のタスクに十分、 大モデルではもう少し広めが助かり、狭めると劣化。現代の一部の変種、 例えば Llama は微妙に異なる形状(SwiGLU 系は約 2/3 × 4 = 8/3倍モデル次元、行列を 2 つから 3 つへ)を採るが、構造は同じ。
パラメータはどこにあるか。d_model = 768、4 · d_model = 3072 での 1 ブロックの概算:
- Attention:W_Q、W_K、W_V、W_O —— 各々 768 × 768、 合計 4 · 768² ≈ 240 万。
- FFN:W_up が 768 × 3072、W_down が 3072 × 768。 合計 2 · 768 · 3072 ≈ 470 万。
FFN は attention の約 2 倍の規模。「attention こそモデル」と直感的に思いがちな Transformer において、実際にパラメータの大半とトークンレベル計算の大半を 担うのは FFN。研究者は今や、ほとんどの事実的知識 —— モデルの「記憶」 —— は FFN に蓄えられ、attention は各トークンに対しどの記憶を活性化するかを 決めるルーティング層と考えている。
非線形が重要。σ がなければ FFN は潰れる:W_down · (W_up · x) = (W_down · W_up) · x、低ランク表現の単層線形に。 非線形こそが、この上–下構造を 1 線形層より表現力のあるものにする。 SiLU と GELU は小さな正値を残し、大きな負値を圧縮 —— ReLU の滑らかで単調な変種。
完全なブロック、そして N ブロックでモデルに
norm、attention、残差、norm、FFN、残差。N 個積み、埋め込みとヘッドで包む。
ここまで見てきた 4 部品 —— norm、attention、残差、FFN —— は 1 つの固定パターンに 組み上がる。これがpre-norm Transformer ブロック、 現代の全 decoder-only LLM が採用:
# 1 ブロック x ← x + Attention( Norm( x ) ) x ← x + FFN( Norm( x ) )
2 サブ層、2 残差、2 norm。1 行ずつ読む:x を正規化、attention を実行、 結果を元の x に加算;新しい x を正規化、FFN を実行、結果を再び加算。 出力は入力と同形、次のブロックへ直接渡せる。
N 個積む。モデルはこのブロックの多数のコピーを持ち、 各々が完全に独立した重み —— 各ブロックは独自の attention パターン、 独自の FFN 関数、独自の norm パラメータを学習する。情報は入力から出力へ 順次各ブロックを通る、スキップも分岐もない。
- GPT-2 small:N = 12 ブロック、d_model = 768。
- GPT-2 medium / large / XL:N = 24 / 36 / 48。
- GPT-3:N = 96、d_model = 12288。
- Llama 2 / 3 7B:N = 32。Llama 70B:N = 80。
- 現代の最先端モデル:典型的に 40–120 ブロック。
外側の包み。完全な言語モデルは:
x = Embedding( token_ids ) // (seq_len, d_model)
for block in blocks:
x = block(x) # N 回
x = FinalNorm(x)
logits = LM_Head(x) // (seq_len, vocab_size)Embedding はトークン ID を d_model 次元ベクトルに(RoPE の相対位置情報 は attention 内部で加わる、ここではない)。N ブロックがベクトルを変換。 最後の norm でスケールを整える。LM ヘッド —— 形状 (d_model, vocab_size)の単一線形層 —— が各トークンのベクトルを語彙全体のスコアへ射影、その上に softmax で次トークンの確率分布が得られる。この単線形はモデル全体で最大の重みテンソルに なることが多い(GPT-2 small で 50,000 × 768 = 3800 万、GPT-3 で 50,000 × 12288 = 6.14 億)。
因果マスクと KV キャッシュ。言語モデリングでは、各ブロック内の attention は因果マスク(自己注意 primer で扱った)を使う —— 各トークンは自身と 過去のトークンにしか注意できない。推論時、前のトークンの K、V 行列はキャッシュ され、新生成トークン毎に再計算しなくて済む;この KV キャッシュがテキスト生成時の 主たるメモリコスト、MQA / GQA / スライディングウィンドウ等の手法はこれの縮小が目的。
たどり着いた地点。20 個目の primer で、絵が完成した。 トークン → 埋め込み → N 個の pre-norm ブロック(各々 = norm + RoPE 付き多頭 attention + 残差 + norm + FFN + 残差)→ 最終 norm → 線形ヘッド → 次トークン logits。 この一文 —— 道中で展開してきた全ての部品と合わせて —— がまさに現代の decoder-only LLM。訓練にはさらに数点(交差エントロピー損失、AdamW、学習率スケジュール、 混合精度 —— 全て前の primer で扱った)が加わるが、アーキテクチャ自体は、 数百行の PyTorch で書ける、まさにこれ。