Frameworks & Autograd Primer

The previous primers were about what a neural net does. This one is about the software people actually write to express it. Three short topics: PyTorch, JAX, TensorFlow — the frameworks every modern model is coded in; autograd — the trick that lets you write a forward pass and get every gradient back for free; and the computational graph — the data structure that makes the trick possible. After this primer, the line loss.backward() finally stops being magic.

01

PyTorch, JAX, TensorFlow

Three frameworks, same job: tensors on a GPU plus autograd.

Nobody writes neural net code from scratch in 2026. Every paper, every model release, every fine-tuning recipe is expressed in one of three frameworks. They all give you roughly the same abstractions — NumPy-like tensor operations, GPU/TPU backends, automatic differentiation, a layer library, an optimizer library — but they differ in style, ergonomics, and what they were built to be good at.

  • PyTorch (Meta, 2017). Dynamic graph, imperative Python. You write your forward like normal code; you debug it with print statements; you mutate things in place. Won the research community by 2020, and now eats production share too. The default choice in 2026 unless you have a specific reason for something else.
  • JAX (Google, 2018). Functional and composable. Built around a few program transformations — jax.grad (autodiff), jax.jit (compile to fast static graphs), jax.vmap (auto-vectorize), jax.pmap (multi-device) — that stack on top of each other. Trades some ergonomics for a much cleaner story when you want to do non-standard transformations on your training loop. Popular in research labs.
  • TensorFlow (Google, 2015). The original mass-deployment framework. Started with a static graph (define-then-run) that was hostile to debugging, then made eager execution the default in TF 2.0 to compete with PyTorch, but never fully won back research mindshare. Still entrenched in production stacks at scale and in the Keras ecosystem.
              PyTorch          JAX             TensorFlow
style         imperative       functional      imperative
graph         dynamic          static (jit)    static
feel          very Pythonic    very FP         verbose
shines at     research, prod   ML research     production
Three frameworks, all with the same job: tensors on a GPU plus autograd. The differences are in style and what they were optimized for.

What they all share, and what makes any of them "a deep learning framework":

  • N-dimensional tensors as the universal data type, with a high-performance backend that runs the same ops on CPU, GPU (CUDA / ROCm), and TPU (XLA).
  • Automatic differentiation (§2) — the killer feature. You write the forward, the framework hands you every gradient.
  • A standard library of layers, losses, and optimizers — Linear, Conv2d, MultiHeadAttention, LayerNorm, Adam, AdamW, CrossEntropyLoss, etc. The same building blocks you saw across the earlier primers, all one import away.
  • Mixed-precision and distributed-training APIs — bf16 / fp16 / int8 for the hardware-tensors primer's VRAM tricks; FSDP / DeepSpeed / nn.parallel for multi-GPU sharding.

A peripheral point, but worth knowing: model code in PyTorch and JAX is often itself relatively short. The expensive parts are data pipelines, distributed training orchestration, and inference servers. Each framework has companion libraries — torchvision / torchaudio / 🤗 transformers for PyTorch; Flax / Equinox / Haiku for JAX — that handle the layers above raw tensor ops. Picking a framework is partly picking an ecosystem.

In a Transformer: the reference implementations of most major open LLMs (LLaMA, Mistral, Qwen, DeepSeek) ship in PyTorch first, and JAX re-implementations follow days or weeks later. The training pipelines at the biggest labs (OpenAI, Anthropic, Google, Meta) are mostly JAX (Google) or PyTorch (most everyone else). Knowing one well, the other on demand — that's the working setup in 2026.

02

Autograd — You Write the Forward, You Get the Backward

The trick that lets a 70-billion-parameter model train itself.

The backprop primer (§1) said that "every modern framework computes gradients for you automatically." Here's how. The technique has a slightly clunky name: automatic differentiation, usually shortened to autodiff or, in PyTorch's case, autograd. It is the single most important software engineering achievement in deep learning.

The user-facing surface in PyTorch is tiny: flag a tensor with requires_grad=True, compute with it, and call .backward():

import torch

x = torch.tensor(2.0, requires_grad=True)   # a leaf tensor autograd will track
y = x ** 2 + 3 * x + 1                      # 1 line of math = 1 line of code
y.backward()                                # walks the recorded ops backward, fills x.grad

print(x.grad)                               # tensor(7.)

You wrote down y = x² + 3x + 1, you called .backward(), you got dy/dx = 2x + 3 = 7 at x = 2. You did not differentiate anything. The framework did.

You write the forward expression. The framework remembers every operation, then walks the recording backward to give you ∂y/∂x.

How? Autograd is built on three pieces.

  • Per-op derivative rules. Every primitive op (matmul, add, exp, relu, softmax, attention, layer norm — every single one) has a hand-written rule for its derivative. d(x²)/dx = 2x; d(exp(x))/dx = exp(x); d(matmul(A, B))/dA = ... matrix algebra. These are baked into the framework — you almost never write your own.
  • A recording mechanism (forward). When you compute y = x ** 2 + 3 * x + 1, the framework doesn't just compute the result — it also records the operations: "y was produced by adding x², 3x, and 1." Every tensor secretly knows which op produced it and what its inputs were. This is the famous "tape" or "computational graph" (§3).
  • A graph walker (backward). When you call y.backward(), the framework starts at y with ∂y/∂y = 1, looks up the op that produced y, applies that op's derivative rule, and propagates the gradient to y's inputs. Then it does the same for those inputs, recursing all the way back to leaves like x. Each leaf accumulates the total gradient.
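The three pieces fit in a few dozen lines of plain Python. The sketch below is a hypothetical scalar-only `Value` class (not PyTorch's implementation; real frameworks do this over tensors, in C++): each op records its inputs together with its local derivatives, and `backward()` visits nodes in reverse topological order, applying the chain rule and accumulating into each input's `.grad`.

```python
# Minimal scalar reverse-mode autodiff (a hypothetical teaching sketch, not PyTorch's internals).


class Value:
    """A scalar that remembers which op produced it and from which inputs."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents          # recorded inputs (the graph's back-pointers)
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    @staticmethod
    def _lift(other):
        return other if isinstance(other, Value) else Value(other)

    # Per-op derivative rules, recorded at forward time.
    def __add__(self, other):
        other = Value._lift(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    __radd__ = __add__

    def __mul__(self, other):
        other = Value._lift(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    __rmul__ = __mul__

    def __pow__(self, n):
        # n is a plain Python number: d(x**n)/dx = n * x**(n - 1)
        return Value(self.data ** n, (self,), (n * self.data ** (n - 1),))

    # The graph walker.
    def backward(self):
        order, seen = [], set()

        def visit(node):                 # depth-first topological sort
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0                  # dy/dy = 1
        for node in reversed(order):     # every consumer runs before its inputs
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += local * node.grad   # chain rule, accumulated


x = Value(2.0)
y = x ** 2 + 3 * x + 1
y.backward()
print(x.grad)  # 7.0
```

Reverse topological order is the important design choice: a node's gradient must be fully accumulated from all of its consumers before it is pushed further back. That is how x, used in both x ** 2 and 3 * x, ends up with the sum of both contributions (4 + 3 = 7).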

Two flavors exist, and the deep-learning community uses one of them almost exclusively:

  • Reverse-mode autodiff (= backpropagation) — efficient when you have many inputs and few outputs. A neural net loss has billions of inputs (parameters) and one output (the loss scalar). One reverse-mode pass gives you all billion gradients for a small constant multiple of one forward pass's compute (for a typical network, the backward pass costs roughly twice the forward's FLOPs). This is what every modern framework defaults to.
  • Forward-mode autodiff — efficient when you have few inputs and many outputs. Used for Jacobian-vector products, some physics simulations, and certain meta-learning recipes. Frameworks support it (e.g., jax.jvp) but it's rare in mainstream deep learning.
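Forward mode can be sketched with dual numbers: every value carries its derivative along one chosen input direction (a tangent), and each op updates both. The `Dual` class, the example function `f`, and the helper `gradient_forward_mode` are hypothetical illustrations, not `jax.jvp`. Getting the full gradient takes one pass per input, which is exactly why reverse mode wins when the inputs are a model's parameters.

```python
# Forward-mode autodiff with dual numbers (a hypothetical minimal sketch).


class Dual:
    """A value paired with its derivative along one chosen input direction."""

    def __init__(self, val, tan=0.0):
        self.val = float(val)
        self.tan = float(tan)

    @staticmethod
    def _lift(other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._lift(other)
        # product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def f(x1, x2):
    return x1 * x1 * x2 + 3 * x2      # df/dx1 = 2*x1*x2, df/dx2 = x1**2 + 3


def gradient_forward_mode(fn, point):
    """One forward pass per input: seed tangent 1 on input i and 0 on the rest."""
    grads = []
    for i in range(len(point)):
        args = [Dual(v, 1.0 if j == i else 0.0) for j, v in enumerate(point)]
        grads.append(fn(*args).tan)
    return grads


print(gradient_forward_mode(f, [2.0, 5.0]))  # [20.0, 7.0]
```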

Why "automatic" deserves the word. The alternatives are:

  • Symbolic differentiation (à la Mathematica) — compute a closed-form derivative expression. Works for small expressions; explodes combinatorially for deep networks. Without sharing repeated subexpressions, the derivative of an n-op expression can grow far larger than the expression itself (the classic "expression swell" problem). Useless at the scale of a Transformer.
  • Numerical differentiation (finite differences, (f(x + ε) − f(x)) / ε) — works for any function, but inaccurate (you're subtracting two floats and dividing by a small number) and slow (one extra forward pass per parameter; for a billion parameters, this is completely infeasible).
  • Autodiff — exact up to ordinary floating-point rounding (no truncation error from choosing a step size ε), and total cost is one forward pass of recording plus one walk-back through the graph. The right answer for deep learning, and one of the field's most beautiful tricks.
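The accuracy problem with finite differences shows up even on the function from the PyTorch example, y = x² + 3x + 1 at x = 2, where the exact derivative (what autograd returns) is 7. The step sizes below are arbitrary choices for illustration. A large ε leaves truncation error; a tiny ε turns f(x + ε) − f(x) into a subtraction of two nearly equal floats, so rounding error takes over and the estimate degrades again.

```python
def f(x):
    return x ** 2 + 3 * x + 1              # same function as the autograd example


def forward_difference(fn, x, eps):
    # (f(x + eps) - f(x)) / eps: one extra evaluation of fn per input
    return (fn(x + eps) - fn(x)) / eps


exact = 2 * 2.0 + 3                        # dy/dx = 2x + 3 at x = 2, what autodiff returns

errors = {}
for eps in (1e-1, 1e-4, 1e-8, 1e-14):
    approx = forward_difference(f, 2.0, eps)
    errors[eps] = abs(approx - exact)
    print(f"eps={eps:.0e}  estimate={approx:.12f}  error={errors[eps]:.1e}")
```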

In a Transformer: every loss.backward() call in a training loop is autograd at work, computing gradients for billions of parameters across attention, FFN, layer norm, and embeddings — all from per-op rules the framework already knew. The user writes the forward; the framework writes the backward. That asymmetry, more than anything else, is what let neural networks scale from MNIST in 1998 to GPT-4 in 2023.

03

The Computational Graph

Under the hood, autograd is a DAG plus a walker.

Autograd looks like magic on the outside. Inside, it's a single data structure — a computational graph — and an algorithm that walks it in reverse. A computational graph is a directed acyclic graph (DAG) where each node is an operation and each edge is data flowing from one operation to another.

Take y = (x · 2 + 3)². The graph is a chain: the leaf x flows through three op nodes (mul, add, pow), and the last one produces the output y:

     ┌──────┐    ┌──────┐    ┌──────┐    ┌──────┐
x ──▶│  · 2 │──▶ │ + 3  │──▶ │(...)²│──▶ │  y   │
     └──────┘    └──────┘    └──────┘    └──────┘
       (mul)       (add)       (pow)
        =4          =7          =49

The framework builds this graph at the moment you compute y. Every intermediate tensor (the 4, the 7, the 49) keeps a back-pointer to the op that made it and the input tensor(s) it came from. That's how y "remembers" how to get back to x.

Autograd's data structure is a DAG: nodes are operations, edges are data dependencies. Forward records; backward replays in reverse.

Backward then walks the graph in reverse. Starting at y with ∂y/∂y = 1, the framework asks the pow node what its derivative rule says and applies it: 2 · input, evaluated at input = 7, gives 14, and multiplying by the incoming gradient of 1 makes 14 the gradient at the previous node. Then onto add (derivative is 1, so the gradient passes through unchanged → 14), then onto mul (derivative is 2, so multiply the incoming gradient by 2 → 28). The total gradient at x is 28, which agrees with the closed-form answer d/dx (2x + 3)² = 4(2x + 3) = 28 at x = 2.
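That walk can be sketched as a tape: the forward pass appends one entry per op, storing the local derivative it will need, and the backward pass replays the tape in reverse, multiplying by the incoming gradient. This plain-Python sketch is hard-coded to y = (x · 2 + 3)² at x = 2, and the slot names simply mirror the diagram's op labels; real frameworks typically save the inputs a derivative rule needs rather than a precomputed number.

```python
# Tape-based reverse mode for y = (x * 2 + 3) ** 2 at x = 2 (hypothetical sketch).

x = 2.0
vals = {"x": x}
tape = []   # entries: (output name, input name, local derivative d(output)/d(input))

# Forward: compute each op and record it.
vals["mul"] = vals["x"] * 2
tape.append(("mul", "x", 2.0))                  # d(2u)/du = 2

vals["add"] = vals["mul"] + 3
tape.append(("add", "mul", 1.0))                # d(u + 3)/du = 1

vals["pow"] = vals["add"] ** 2
tape.append(("pow", "add", 2 * vals["add"]))    # d(u**2)/du = 2u, here 2 * 7 = 14

# Backward: seed dy/dy = 1, then replay the tape in reverse.
grads = {"pow": 1.0}
for out, inp, local in reversed(tape):
    grads[inp] = grads.get(inp, 0.0) + local * grads[out]   # chain rule

print(vals["mul"], vals["add"], vals["pow"])   # 4.0 7.0 49.0
print(grads["add"], grads["mul"], grads["x"])  # 14.0 14.0 28.0
```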

Two paradigms for managing the graph have shaped framework design:

  • Dynamic graphs (PyTorch eager, JAX outside jit) — the graph is built fresh every time you call the forward pass. If your forward has a Python if, different branches produce different graphs at different iterations. Easy to debug (you can pdb.set_trace() in the middle of the model), easy to express dynamic control flow. The cost is that the framework can't see ahead and optimize the whole graph as a unit; each op runs more or less independently.
  • Static graphs (TensorFlow 1.x, JAX with jit, PyTorch with torch.compile) — you define the graph once (often by tracing through Python with abstract inputs), the framework analyzes and compiles it, then you can call the compiled version many times at high speed. Much faster per call; the framework can fuse ops, allocate memory ahead of time, and pick the best CUDA kernels. Harder to debug: an error message in a compiled graph looks nothing like the Python you wrote.

The line between dynamic and static is blurring fast. Modern PyTorch (torch.compile) traces your eager-mode model and produces an optimized static graph for the parts it can statically analyze, falling back to eager mode for the parts it can't. Modern JAX traces aggressively and recompiles on shape changes. The end-state most frameworks are converging toward is: write dynamic, the framework figures out which slices to compile.
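Tracing with abstract inputs can be sketched in plain Python: run the function once on a placeholder that records ops instead of computing them, keep the recorded list as the static graph, then replay that list on concrete numbers. The `Tracer` class and `trace` function are hypothetical toys (no optimization, no shapes, no gradients), but they reproduce the key limitation: a Python if that depends on a traced value cannot be decided at trace time.

```python
import operator

# Primitive ops the toy tracer understands.
OPS = {"add": operator.add, "mul": operator.mul, "pow": operator.pow, "gt": operator.gt}


class Tracer:
    """Stands in for a value during tracing; records ops instead of computing them."""

    def __init__(self, graph, slot):
        self.graph = graph   # shared list of recorded ops (the static graph)
        self.slot = slot     # index of the value this tracer stands for

    def _record(self, op, other):
        operand = ("slot", other.slot) if isinstance(other, Tracer) else ("const", other)
        self.graph.append((op, self.slot, operand))
        return Tracer(self.graph, len(self.graph))   # result lives in the next slot

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

    def __pow__(self, n):
        return self._record("pow", n)

    def __gt__(self, other):
        return self._record("gt", other)

    def __bool__(self):
        # A tracer has no concrete value, so it cannot choose an if-branch.
        raise TypeError("data-dependent control flow cannot be traced")


def trace(fn):
    """Run fn once on a Tracer; return a replay function and the recorded graph."""
    graph = []
    out_slot = fn(Tracer(graph, 0)).slot   # slot 0 holds the input

    def compiled(x):
        slots = [x]
        for op, a, (kind, b) in graph:
            rhs = slots[b] if kind == "slot" else b
            slots.append(OPS[op](slots[a], rhs))
        return slots[out_slot]

    return compiled, graph


def model(x):
    return (x * 2 + 3) ** 2


compiled, graph = trace(model)
print(graph)          # [('mul', 0, ('const', 2)), ('add', 1, ('const', 3)), ('pow', 2, ('const', 2))]
print(compiled(2.0))  # 49.0


def branchy(x):
    if x > 0:         # needs a concrete value at trace time
        return x * 2
    return x * 3


try:
    trace(branchy)
except TypeError as err:
    print("trace failed:", err)
```

Real systems handle that last case differently: jax.jit raises an error when a traced value reaches a Python if, while torch.compile typically inserts a graph break and runs the unsupported region eagerly, which is the fallback described above.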

What else the graph buys you, beyond autograd:

  • Kernel fusion. A chain of element-wise ops that follows a matmul, such as the bias add and ReLU in relu(x @ W + b), gets compiled into a single GPU kernel instead of one round-trip through VRAM per op. This is a large part of why torch.compile and JAX jit can be 2–10× faster than eager.
  • Memory planning. Knowing the graph in advance lets the framework allocate the right buffers and reuse them as activations are no longer needed. The "activation checkpointing" optimization from the hardware-tensors primer relies on graph awareness.
  • Device placement. The framework can decide which ops run on which GPU (or TPU pod), insert the right collective communication ops, and orchestrate distributed training — all from the graph.
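Fusion can be illustrated, very loosely, by counting buffer accesses. In this hypothetical toy, Python lists stand in for VRAM buffers: the unfused version writes the bias-add result to an intermediate buffer and reads it back for the ReLU, while the fused version keeps that sum in a local variable. Both return the same values; for four elements the unfused path makes 20 buffer accesses against 12 for the fused one.

```python
# Toy model of kernel fusion: Python lists stand in for GPU buffers in VRAM,
# and `traffic` counts element reads/writes to those buffers (a hypothetical metric).


def unfused_bias_relu(z, b):
    traffic = 0
    t = []
    for zi, bi in zip(z, b):           # "kernel" 1: bias add
        t.append(zi + bi)
        traffic += 3                   # read z[i], read b[i], write t[i]
    out = []
    for ti in t:                       # "kernel" 2: ReLU
        out.append(max(ti, 0.0))
        traffic += 2                   # read t[i] back, write out[i]
    return out, traffic


def fused_bias_relu(z, b):
    traffic = 0
    out = []
    for zi, bi in zip(z, b):           # one fused "kernel": bias add + ReLU
        out.append(max(zi + bi, 0.0))  # the sum stays in a local (a register, on a GPU)
        traffic += 3                   # read z[i], read b[i], write out[i]
    return out, traffic


z = [0.5, -2.0, 1.5, -0.25]   # stand-in for the matmul output x @ W
b = [0.1, 0.5, -2.0, 1.0]     # bias

print(unfused_bias_relu(z, b))
print(fused_bias_relu(z, b))
```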

In a Transformer: the forward pass through a 70B model is a computational graph with thousands of op nodes — Q·K matmul, softmax, attention·V, FFN matmul, residual add, layer norm — one set per block, repeated across dozens of blocks. Autograd builds this graph during the forward pass of every training step, walks it backwards at loss.backward(), and the graph itself is also what gets compiled for high-throughput inference. The graph is, in a real sense, the model — the weights are just numbers stuffed into a graph's nodes. Most production deep-learning work in 2026 is, ultimately, writing code that constructs the right graph.