Transformer Block 入门

attention 是主角,但它不是一个人干活的。每个 Transformer 块都把 attention 包在另外三样东西里 —— 归一化(normalization)残差连接(residual)位置无关的前馈 MLP(position-wise FFN) —— 再把 N 个这样的块堆成一个模型。4 个短主题:LayerNorm 与 RMSNorm;残差连接以及为啥它让深堆栈可训;承担了大半参数的 FFN; 以及这些怎么装成完整的 block、最后凑成完整的模型

01

LayerNorm 与 RMSNorm

把每个 token 的特征向量保持在一个稳定的尺度上,免得训练炸。

深网络里,激活向量的量级可以一层一层飘得很离谱。反向的梯度也一样。 如果不在哪里把尺度按住,深堆栈根本练不起来:要么值爆掉、要么趋零把梯度搞没。 归一化(normalization)就是答案。

LayerNorm。2016 年提出,原始 Transformer 用的就是它。 对每个 token 的特征向量 x ∈ R^d_model,在 d_model这个维度上算它的均值和标准差,然后标准化:

μ = mean(x)            // 每个 token 一个标量
σ = std(x)             // 每个 token 一个标量
LN(x) = γ · (x − μ) / σ  +  β

γ 和 β 是学过的「每个维度一个」的缩放与偏置参数(长度 d_model 的向量)。 关键性质:这件事对每个 token 独立做。token 之间不共享统计量 —— 这跟 BatchNorm 的关键区别就在这,也正是它适合可变长度序列的原因。

LayerNorm 与 RMSNorm 在一个特征向量上的对比原始 x —— 均值 ≈ 0.13、std ≈ 0.69x[d_model]0.60-0.401.20-0.900.30-0.200.80-0.408 维。均值不为 0、各维幅度不齐 —— 还没归一化。
1 / 4
LayerNorm 先去均值,再缩放到单位方差。RMSNorm 省掉去均值这一步 —— 经验上效果差不多,大约快 7%。Llama、Qwen、DeepSeek 都用 RMSNorm。

RMSNorm。2019 年。现代开源 LLM 的默认。观察是: LayerNorm 里去均值这一步,贡献并不大 —— 「重新居中」对模型容量 并没带来有用的提升。于是 RMSNorm 直接把这一步丢掉:

rms = sqrt( mean(x²) )         // 不再去均值
RMSNorm(x) = γ · x / rms       // 也不要 β

就这。不去均值、不要 β。每个向量少几步操作,实际中大约快 7–10%。 经验上模型训练和泛化效果跟 LayerNorm 差不多。这点收益还会叠加: 因为 norm 在每个块都用一次,每个块省一点, 前向加反向整轮下来省得相当可观。

谁用哪个。

  • LayerNorm:原始 Transformer、BERT、GPT-2、GPT-3、T5。
  • RMSNorm:Llama 1、2、3。Qwen。DeepSeek。Mistral。PaLM (T5 变种)。Gemma。基本上每个现代开源 LLM 都用。

pre-norm 还是 post-norm。一个相关但独立的问题: 归一化是放在 attention / FFN 之前(pre-norm)、还是放在残差之后(post-norm)?2017 年原论文用的是 post-norm。 到 2020 年左右,几篇论文证明 pre-norm 稳定得多: 梯度可以顺着残差直接流回去,不会在 backprop 一开始就被 norm 层压扁。 现代每一个 decoder LLM 都用 pre-norm。我们在 §4 看最终的结构图。

02

残差连接

把输入加回到输出上。就这一招,让深堆栈训得起来。

残差连接(也叫「skip connection」)是深度学习里看起来最朴素的一个想法, 也是后果最大的一个。拿一个把向量 x 映成某个 f(x) 的块, 把它的输出改成:

y = f(x) + x

改动就这一行。块做完它的变换,然后再把原来的输入加回去。 功能上,块现在学的是相对于输入的增量,而不是完整的输出。 紧接着的两件事都重要。

残差连接:x → f(x) + x一个块:y = f(x)xfy = f(x)从一个朴素层开始:输入 x、变换成 f(x)、输出 y。
1 / 4
把输入加回到输出,等于给激活和梯度都开了一条直通堆栈的「高速路」。深的 Transformer 没这个根本训不起来。

1. 恒等函数是个好默认。如果 f 输出全零, 块就退化成恒等函数:y = x。一摞零初始化的块堆起来, 就是从输入到输出的恒等映射。优化器然后学的是每一层一小点有用的增量, 不用每一层从头学整个输出变换。这是一个友好得多的起点。

2. 梯度高速路。对残差 y = f(x) + x 用链式法则, 得到 dy/dx = (df/dx) + 1。这「+ 1」就是残差路径的贡献。 即使 df/dx 很小 —— 深堆栈里趋零 —— 残差路径上的梯度永远是 1, 上游梯度可以原封不动传下去。光骨干是导数的乘法链;残差路径是 1 的加法链。

在残差连接(ResNet,2015)之前,视觉网络训到 ~10 层就不太靠谱、 训到 ~20 层基本无望。残差出来之后,ResNet-152 顺顺利利训出来了。 Transformer 继承了同一个性质:GPT-3 有 96 层、Llama 70B 有 80 层, 它们能训出来,全靠每个子层都是残差块。

每个 Transformer 块里两个残差。一个 Transformer 块有个子层 —— attention 和 FFN —— 每个各自有自己的残差:

x  ←  x + Attention( Norm(x) )    // 残差 1
x  ←  x + FFN( Norm(x) )          // 残差 2

x 这个变量是被复用的 —— 块在它上面就地更新两次。 从残差的角度看,attention 和 FFN 各自产出一个小修正。 模型的骨架就是恒等映射;每个块往上面叠一个小的、专门的调整。

为什么 norm 和残差都要?它们解决不同的问题。 norm 控制每一步激活的尺度。残差保住梯度流恒等默认。两个一起搭出可训练性的骨架 —— 缺一个,深堆栈都顶不住。

03

位置无关前馈(FFN)

一个两层 MLP,对每个 token 各自独立跑。模型大部分参数其实都在这。

attention 在 token 之间搬信息。前馈网络 —— FFN(或叫 MLP)子层 —— 是每个 token 拿到该有的信息之后,自己一个人接着做的事。 它是 Transformer 块的后半,position-wise 应用: 同一个 MLP 在每个 token 上跑,位置之间没有信息流动。

FFN(x) = W_down · σ( W_up · x )
  • W_up:形状 (d_model, 4 · d_model)。 把 token 向量投影到一个更宽的「hidden」空间 —— 通常是模型维的 4 倍。 GPT-3 的 d_model = 12288,hidden 就是 49152
  • σ:逐元素的非线性。原始 Transformer 和 GPT-2 用 GELU。 Llama、Qwen、Mistral 用 SiLU(也叫 Swish)。具体选哪个, 没有「必须有一个」来得重要。
  • W_down:形状 (4 · d_model, d_model)。 再投影回去,让输出跟输入同形 —— 这是堆叠块的必要条件。
position-wise FFN —— 两层线性 + 一个非线性输入 x —— d_model = 4xd=4W_upW_up · xd=16σGELUW_downW_down · hd=4每个 token 一个 d_model 维向量。每个位置都跑同一个 FFN。
1 / 4
每个 token 各自独立过一遍,所有位置用同一份权重。hidden 是 d_model 的 4 倍 —— Transformer 大部分参数都在这(每个块大约 2/3)。

为什么 position-wise?因为跨 token 的活,attention 已经干完了。 FFN 只需要把每个 token 的带上下文向量变换成有用的东西。 各位置共享同一个 MLP 意味着 FFN 学一个固定函数、到处用 —— 便宜、泛化好、对任意长度的序列都行。

为什么 4 倍?2017 年的经验最佳值。直觉是: attention 的表达力受 d_model 限制,FFN 需要足够的空间 来做 attention 干不了的 per-token 计算。4 倍对大多数任务够用, 大模型再宽一点更好,再窄就掉点。一些现代变种比如 Llama 用稍微不同的形状 (SwiGLU 变种大约是 2/3 × 4 = 8/3 倍模型维、三个矩阵代替两个), 但结构是一样的。

参数都在哪儿。一个块在 d_model = 7684 · d_model = 3072 时快速一算:

  • Attention:W_Q、W_K、W_V、W_O —— 每个是 768 × 768, 一共 4 · 768² ≈ 240 万
  • FFN:W_up 是 768 × 3072、W_down 是 3072 × 768。 一共 2 · 768 · 3072 ≈ 470 万

FFN 差不多是 attention 的两倍大。在一个你直觉上会以为「attention 就是模型」的 Transformer 里,实际上是 FFN 占据了大部分参数、也干了大部分 token 级的计算。 研究者现在认为大部分事实性知识 —— 模型的「记忆」 —— 都存在 FFN 里, attention 则是决定每个 token 该激活哪些记忆的路由层。

非线性很重要。没有 σ,FFN 就塌了:W_down · (W_up · x) = (W_down · W_up) · x, 变成一个低秩的单层线性。非线性是让这个上–下结构真能算出比一个线性层 更有表达力的东西的关键。SiLU 和 GELU 都把小正值留下、把大负值压扁 —— 一个平滑、单调的 ReLU 变种。

04

完整的块,以及 N 个块构成模型

norm、attention、残差、norm、FFN、残差。N 份复制叠起来,外面包一个 embedding 和一个 head。

我们看过的四块东西 —— norm、attention、残差、FFN —— 按一个固定模式装在一起。 这就是pre-norm Transformer 块,每个现代 decoder-only LLM 都用:

# 一个块
x  ←  x + Attention( Norm( x ) )
x  ←  x + FFN( Norm( x ) )

两个子层、两个残差、两次 norm。一行一行读:先归一化 x、在它上面做 attention、 把结果加回原 x;再归一化新的 x、做 FFN、把结果再加回去。输出跟输入同形, 可以直接喂给下一个块。

一个 Transformer 块,以及 N 个块堆成模型一个块 —— pre-norm + 残差xRMSNormMulti-Head Attention+RMSNormFFN+pre-norm 流程:norm → attn → +x → norm → FFN → +x。
1 / 4
pre-norm 风格(每个现代解码器 LLM 都用):先归一化,再 attention 或 FFN,再加残差。一个块重复 N 次;最后一层 norm + LM head 收尾。

N 份复制。一个模型有许多份这样的块,各自独立的权重 —— 每个块学自己的 attention 模式、自己的 FFN 函数、自己的 norm 参数。 信息从输入到输出顺序穿过每个块,没有跳过、没有分支。

  • GPT-2 small:N = 12 个块、d_model = 768
  • GPT-2 medium / large / XL:N = 24 / 36 / 48
  • GPT-3:N = 96d_model = 12288
  • Llama 2 / 3 7B:N = 32。Llama 70B:N = 80
  • 现代前沿模型:一般 40–120 个块。

外面那一圈。一个完整的语言模型就是:

x = Embedding( token_ids )          // (seq_len, d_model)
for block in blocks:
    x = block(x)                    # 重复 N 次
x = FinalNorm(x)
logits = LM_Head(x)                 // (seq_len, vocab_size)

Embedding 把 token ID 变成 d_model 维向量(RoPE 那种相对位置信息 是在 attention 内部加的,不在这一层)。N 个块把向量变换一遍。 最后一次 norm 把尺度收拾干净。LM head —— 一个形状为 (d_model, vocab_size) 的线性层 —— 把每个 token 的向量 投影到整个词表上的分数;在这些分数上 softmax,就得到下一 token 的概率分布。 这一个线性层往往是整个模型里最大的一个权重张量 (GPT-2 small:50,000 × 768 = 3800 万;GPT-3:50,000 × 12288 = 6.14 亿)。

因果 mask 与 KV 缓存。对语言建模,每个块里的 attention 都带因果 mask(自注意力 primer 讲过)—— 每个 token 只能注意自己和之前的 token。 推理时,前面 token 的 K 和 V 矩阵会被缓存,避免每生成一个新 token 都 重新算一遍;这个 KV 缓存是生成文本时的主要显存代价, MQA / GQA / sliding window 之类的技巧都是为了把它缩小。

走到这里。20 篇 primer 下来,图终于齐了。 token → embedding → N 个 pre-norm 块(每个 = norm + 带 RoPE 的多头 attention + 残差 + norm + FFN + 残差)→ 最后的 norm → 线性 head → 下一 token logits。 这一句话 —— 加上一路上我们拆解过的所有零件 —— 就是一个现代 decoder-only LLM。 训练再加几件东西(交叉熵损失、AdamW、学习率调度、混合精度 —— 都在前面的 primer 里), 但架构本身,就是你用几百行 PyTorch 能写出来的那个。