框架与自动微分 入门
前面的 primer 都在讲神经网络在做什么。这一篇讲人们真正用什么软件去把它写出来。 三个短主题:PyTorch、JAX、TensorFlow —— 每个现代模型都是用其中之一写的;自动微分(autograd) —— 让你写好前向、然后所有梯度免费送你的那个魔法;以及计算图 —— 让这个魔法能成立的数据结构。读完这一篇,loss.backward()这一行终于不再是魔法了。
PyTorch、JAX、TensorFlow
三个框架,做的事一样:GPU 上的张量 + 自动微分。
2026 年没人会从零开始写神经网络代码。每一篇论文、每一次模型发布、每一份微调脚本, 都是用这三个框架之一写的。它们提供的抽象大致一样 —— 类 NumPy 的张量操作、 GPU / TPU 后端、自动微分、层库、优化器库 —— 但在风格、手感、以及「为了擅长什么而被设计」这几件事上差异很大。
- PyTorch(Meta,2017)。动态图、命令式 Python。 前向就当普通代码写;用 print 调试;原地修改东西。2020 年前后赢了研究圈, 现在生产份额也在抢。除非你有特别理由,2026 年的默认选择就是它。
- JAX(Google,2018)。函数式、可组合。整个框架围绕几个程序变换 ——
jax.grad(自动微分)、jax.jit(编译成快速静态图)、jax.vmap(自动向量化)、jax.pmap(多设备)—— 这些变换可以叠加使用。牺牲点手感,换来一份「想对训练循环做花式变换时」干净得多的故事。 在研究实验室很流行。 - TensorFlow(Google,2015)。最早的工业级部署框架。一开始是静态图 (先定义再跑),对调试极不友好,TF 2.0 加了 eager 模式想跟 PyTorch 抢市场, 但研究圈再没回来。还深嵌在大规模生产栈和 Keras 生态里。
它们都有的东西,也是「凡是个深度学习框架」必须有的东西:
- N 维张量作为统一的数据类型,加上能在 CPU、GPU(CUDA / ROCm)、 TPU(XLA)上跑同一套 op 的高性能后端。
- 自动微分(§2)—— 真正的杀手锏。你写好前向,框架把所有梯度还给你。
- 标准层、损失、优化器库 —— Linear、Conv2d、MultiHeadAttention、 LayerNorm、Adam、AdamW、CrossEntropyLoss 等等。前面 primer 里出现过的所有积木, import 一下就在。
- 混合精度和分布式训练 API —— bf16 / fp16 / int8 对应硬件与张量 primer 里那些 VRAM 招;FSDP / DeepSpeed / nn.parallel 用来做多卡分片。
一个边缘但值得知道的事:PyTorch 和 JAX 里的模型代码本身往往挺短。 贵的是数据流水线、分布式训练编排、推理服务这些东西。现代框架都带一票配套库 —— PyTorch 的 torchvision / torchaudio / 🤗 transformers; JAX 的 Flax / Equinox / Haiku —— 处理「张量操作之上」那一层。 选框架很大一部分是在选生态。
在 Transformer 里:每个主流 LLM(LLaMA、Mistral、GPT-2、Gemma、 Qwen、DeepSeek)的参考实现都是先用 PyTorch 出。JAX 再实现往往几天或几周之后跟上。 最大那几个实验室的训练流水线,Google 那家是 JAX,大部分其它家是 PyTorch。 熟一门、能上手另一门 —— 这是 2026 年的工作配置。
自动微分 —— 你写前向、它给反向
让一个 700 亿参数模型能自己训起来的那个魔法。
反向传播 primer §1 里说过「每个现代框架都帮你自动算梯度」。这里讲它是怎么算的。 这项技术名字稍微拗口:自动微分(automatic differentiation), 简称 autodiff,在 PyTorch 里叫 autograd。 它是深度学习领域里最重要的一项软件工程成就。
在 PyTorch 里,用户面对的代码就两行:
x = torch.tensor(2.0, requires_grad=True) y = x ** 2 + 3 * x + 1 y.backward() # 一行数学 = 一行代码 print(x.grad) # tensor(7.)
你写了 y = x² + 3x + 1、你叫了 .backward()、 你拿到了 x = 2 时 dy/dx = 2x + 3 = 7。 你没微分过任何东西。框架做了。
怎么做到的?autograd 由三块东西组成。
- 每个 op 的导数规则。每一个基础操作(matmul、add、exp、relu、 softmax、attention、LayerNorm —— 每一个)都有一条手写的导数规则。
d(x²)/dx = 2x;d(exp(x))/dx = exp(x);d(matmul(A, B))/dA = ...矩阵代数。 这些规则烤进了框架里 —— 你几乎不用自己写。 - 录制机制(前向)。当你算
y = x ** 2 + 3 * x + 1时, 框架不只是算结果 —— 它同时把所有操作录下来: 「y 是由x²、3x和1相加得到的」。 每个张量都偷偷知道是哪个 op 产生它、输入是什么。这就是著名的「磁带」/「计算图」(§3)。 - 图回放器(反向)。当你调
y.backward()时, 框架从y起步、种下∂y/∂y = 1,查到生成y的那个 op,套用它的导数规则,把梯度传播回y的输入。 然后对这些输入再来一遍,一路递归回到x这样的叶子。 每个叶子把累计的总梯度收下。
autodiff 有两种模式,深度学习社区只用其中一种:
- 反向模式 autodiff(= 反向传播)—— 在「很多输入、少量输出」的情形下高效。一个神经网络的 loss 有几十亿个输入 (参数)、一个输出(loss 标量)。一次反向模式给你这几十亿个梯度, 代价大约只是再多一次前向。所有现代框架的默认就是这个。
- 前向模式 autodiff —— 在「少量输入、很多输出」下高效。 常用于雅可比向量积、某些物理仿真、某些元学习方法。框架支持它(比如
jax.jvp),但在主流深度学习里很少见。
「自动」二字配得上「自动」,因为另外的选项都不行:
- 符号微分(像 Mathematica 那样)—— 算出一个解析式。 小表达式行得通;一旦深度网络就组合爆炸。一个有 n 个 op 的表达式, 导数能有 n² 量级的项。在 Transformer 规模下根本没用。
- 数值微分(有限差分,
(f(x + ε) − f(x)) / ε)—— 对任意函数都能用,但精度差(两个浮点数相减再除以一个小数)、慢 (每个参数都要多算一次前向;十亿参数完全不可行)。 - autodiff —— 机器精度上是精确的(没有舍入误差累积), 总代价是「一次前向录制 + 一次反向图回放」。深度学习的正确答案, 也是这个领域里最漂亮的招之一。
在 Transformer 里:训练循环里的每一次 loss.backward()调用,都是 autograd 在干活,在 attention、FFN、LayerNorm、embedding 之间 算几万亿个参数的梯度 —— 全部靠框架早就知道的「每个 op 的导数规则」。 用户写前向、框架写反向。这种不对称,比任何其它一件事都更深地决定了 「神经网络为什么能从 1998 年的 MNIST,做到 2023 年的 GPT-4」。
计算图
掀开盖子,autograd 就是一张 DAG + 一个回放器。
从外面看,autograd 像魔法。掀开盖子,里面是一个数据结构 —— 计算图 —— 加一套反向走它的算法。计算图是一张有向无环图(DAG):每个节点是一个操作, 每条边是数据从一个操作流向另一个操作。
以 y = (x · 2 + 3)² 为例,这张图有 4 个节点、3 条数据边:
┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐
x ──▶│ · 2 │──▶ │ + 3 │──▶ │(...)²│──▶ │ y │
└──────┘ └──────┘ └──────┘ └──────┘
(mul) (add) (pow)
=4 =7 =49框架在你计算 y 的那一刻把这张图建好。每个中间张量 (4、7、49)都带一个反向指针,指向产生它的 op 和它的输入张量。这就是 y 怎么「记得」自己一路怎么回到 x 的。
反向就是把图倒着走。从 y 出发、种下 ∂y/∂y = 1, 框架找 pow 节点的导数规则,套用 (2 · (input), 在 input = 7 处求值 → 14),把这个结果赋给上一节点的梯度。 然后到 add(导数是 1,梯度原样穿过 → 14)、再到 mul(导数是 2,所以入梯度乘 2 → 28)。x 那里的总梯度是 28, 跟解析答案 d/dx (2x + 3)² = 4(2x + 3) = 28(在 x = 2)一致。
「怎么管这张图」有两种范式,塑造了框架的设计:
- 动态图(PyTorch eager、JAX 在
jit之外)—— 每次前向都重新建一张图。前向里有 Python 的if的话, 不同迭代会走出不同的图。容易调试(你可以在模型中间pdb.set_trace()),容易写动态控制流。 代价是框架看不到全局、没法把整张图当一个单位去优化;每个 op 多多少少独立跑。 - 静态图(TensorFlow 1.x、JAX 加
jit、PyTorch 加torch.compile)—— 图先定义好(通常是「拿抽象输入跑一遍 Python 追踪」), 框架分析并编译,然后你可以高速反复调用编译版本。每次调用快得多; 框架可以融合 op、提前分配显存、挑出最佳 CUDA kernel。代价是调试更难: 编译后图里报的错跟你写的 Python 完全是两回事。
动态和静态的界限正在快速模糊。现代 PyTorch(torch.compile)会追踪你的 eager 模型,对能静态分析的部分编出优化后的静态图,对不能的部分回退到 eager。 现代 JAX 激进地追踪、并在形状变化时重新编译。所有主流框架在收敛到的状态是: 写动态、框架自己决定哪些片段去编译。
除了 autograd 之外,这张图还能给你的东西:
- Kernel 融合。一长串逐元素操作(比如
relu(linear(x) + bias))可以被编进一个 GPU kernel, 不用在 VRAM 里跑三趟来回。这是torch.compile和 JAXjit能比 eager 快 2–10 倍的一半原因。 - 内存规划。提前知道整张图,框架就能分配合适的 buffer, 并在激活不再被用时复用空间。硬件与张量 primer 里说的「激活检查点」 优化,就靠对图的理解。
- 设备分配。框架可以决定哪个 op 跑在哪个 GPU / TPU pod 上、 插入合适的集合通信 op、编排分布式训练 —— 全部从图里读出来。
在 Transformer 里:一个 70B 模型的前向, 是一张几百万节点的计算图 —— Q·K 矩阵乘、softmax、attention·V、FFN 矩阵乘、 残差求和、LayerNorm —— 每个块里重复几十次。 每次训练步开头 autograd 建好这张图,loss.backward() 把它反着走一遍; 部署高吞吐推理时,被编译的也是这张图。从一个非常实际的意义上说, 图就是模型 —— 权重不过是塞进图节点里的一堆数。 2026 年大多数生产级深度学习工作,归根结底是「写代码构造合适的图」。