フレームワークと自動微分 入門

前の primer はニューラルネットが何をするかを扱った。この primer は、 それを実際に書くために人々が使うソフトウェアを扱う。短い 3 トピック:PyTorch、JAX、TensorFlow ——現代モデルはどれかで書かれている;自動微分(autograd) ——順伝播を書けば全勾配を無料で受け取れる仕掛け; そして計算グラフ ——その仕掛けを成立させるデータ構造。 この primer を読めば、loss.backward() はもう魔法ではなくなる。

01

PyTorch、JAX、TensorFlow

3 つのフレームワーク、同じ仕事:GPU 上のテンソル + 自動微分。

2026 年にニューラルネットコードをゼロから書く人はいない。あらゆる論文、モデル公開、 ファインチューニング手順は、3 つのフレームワークのいずれかで表現される。 提供される抽象はほぼ同じ —— NumPy 風のテンソル演算、GPU / TPU バックエンド、 自動微分、レイヤーライブラリ、オプティマイザライブラリ —— だがスタイル、手触り、「何が得意になるよう作られたか」が異なる。

  • PyTorch(Meta、2017)。動的グラフ、命令型 Python。 順伝播は普通のコードのように書き、print でデバッグし、in-place で書き換える。 2020 年頃に研究コミュニティを取り、今は本番でもシェアを伸ばしている。 2026 年に特別な理由がなければ既定の選択肢。
  • JAX(Google、2018)。関数型で合成可能。フレームワーク全体が 数個のプログラム変換 —— jax.grad(自動微分)、jax.jit(高速な静的グラフへコンパイル)、jax.vmap(自動ベクトル化)、jax.pmap(マルチデバイス)—— を中心に組まれ、 それらを重ねて使える。手触りを多少犠牲にして、「訓練ループに変わった変換を 適用したい」ときの話を遥かにきれいにする。研究所で人気。
  • TensorFlow(Google、2015)。元祖の大規模配備フレームワーク。 当初は静的グラフ(定義してから実行)でデバッグに厳しく、TF 2.0 で eager モードを 追加したが、研究のシェアは戻ってこなかった。大規模生産スタックと Keras エコシステム には今も根強く残る。
PyTorch / JAX / TensorFlowPyTorchPyTorchJAXTensorFlowスタイル命令型関数型命令型グラフ動的静的 (jit)静的手触りPythonic関数型寄り冗長得意分野研究・生産ML 研究生産配備PyTorch(2017)。動的、命令型、Pythonic —— 2026 年の既定。
1 / 3
3 つのフレームワーク、仕事は同じ:GPU 上のテンソル + 自動微分。違いはスタイルと、何に最適化されたか。

共通点、すなわち「深層学習フレームワーク」と呼ばれる条件:

  • N 次元テンソルを統一データ型として持ち、CPU、GPU(CUDA / ROCm)、 TPU(XLA)で同じ演算を回す高性能バックエンド。
  • 自動微分(§2)—— 真の決定打。順伝播を書けば、すべての勾配が 返ってくる。
  • 標準のレイヤー / 損失 / オプティマイザ群 —— Linear、Conv2d、 MultiHeadAttention、LayerNorm、Adam、AdamW、CrossEntropyLoss など。 これまでの primer に登場した部品はすべて import 1 行で使える。
  • 混合精度・分散訓練 API —— bf16 / fp16 / int8 でハードウェアと テンソル primer の VRAM 技、FSDP / DeepSpeed / nn.parallel でマルチ GPU 分割。

周辺だが知っておくべきこと:PyTorch と JAX のモデルコード自体は意外と短い。 重いのはデータパイプライン、分散訓練のオーケストレーション、推論サーバなど。 現代のフレームワークは付属ライブラリを揃える —— PyTorch には torchvision / torchaudio / 🤗 transformers、JAX には Flax / Equinox / Haiku —— 生のテンソル 演算の上の層を担う。フレームワーク選びは、エコシステム選びでもある。

Transformer では:主要 LLM(LLaMA、Mistral、GPT-2、Gemma、Qwen、 DeepSeek)の参考実装はまず PyTorch で出る。JAX への再実装は数日〜数週間後に続く。 最大級の研究所の訓練パイプラインは Google が JAX、それ以外の多くは PyTorch。 1 つに精通し、もう 1 つに必要に応じて手を出せる —— それが 2026 年の標準装備。

02

自動微分 —— 順伝播を書けば、逆伝播は返ってくる

700 億パラメータのモデルが自分で訓練できる仕掛け。

誤差逆伝播 primer §1 で「現代のフレームワークは自動で勾配を計算してくれる」と書いた。 ここではその仕組みを扱う。やや堅い名前を持つ:自動微分(automatic differentiation)、略して autodiff、PyTorch ではautograd。深層学習における最も重要なソフトウェア工学的成果だ。

PyTorch でユーザーが触る面はたった 2 行:

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x + 1
y.backward()                  # 数式 1 行 = コード 1 行

print(x.grad)                 # tensor(7.)

y = x² + 3x + 1 と書き、.backward() を呼んだら、x = 2 での dy/dx = 2x + 3 = 7 が手に入った。 微分は何もしていない。フレームワークがやった。

y = x² + 3x + 1 x = 2 にて式だけ3x+ 1yx の式。順伝播を走らせると autograd が各演算を記録し始める。
1 / 4
あなたは順伝播の式を書く。フレームワークは各演算を記録し、その記録を逆向きに辿って ∂y/∂x を返す。

どうやって? autograd は 3 つの部品でできている。

  • 各 op の微分規則。基本演算(matmul、add、exp、relu、softmax、 attention、LayerNorm —— どれも)には、手で書かれた微分規則がある。d(x²)/dx = 2x;d(exp(x))/dx = exp(x);d(matmul(A, B))/dA = ... の行列代数。すべてフレームワークに 焼き込まれている —— 自分で書くことはほぼない。
  • 記録機構(順伝播)。y = x ** 2 + 3 * x + 1 を計算する とき、フレームワークは結果だけを返さない —— 演算も記録する: 「y は 3x1 を足して作られた」。 すべてのテンソルは、自分を作った op と入力をひそかに知っている。これが有名な 「テープ」/「計算グラフ」(§3)。
  • グラフ巻き戻し器(逆伝播)。y.backward() を呼ぶと、 フレームワークは y から ∂y/∂y = 1 を種にして始まり、y を作った op を引き当て、その微分規則を適用し、勾配を入力に伝える。 入力たちに対しても同じことをし、再帰的に x のような葉まで戻る。 各葉は受け取った勾配を累算する。

autodiff には 2 つのモードがあるが、深層学習コミュニティはほぼ一方しか使わない:

  • 逆向きモード autodiff(= 誤差逆伝播)—— 入力が多く出力が少ない場合に効率的。ニューラルネットの損失は入力が数十億(パラメータ)、出力が 1 個 (スカラ損失)。1 回の逆向きパスで数十億個の勾配が、追加の順伝播 1 回分のコストで 返ってくる。あらゆる現代フレームワークの既定がこれ。
  • 前向きモード autodiff —— 入力が少なく出力が多い場合に効率的。 ヤコビアン・ベクトル積、一部の物理シミュレーション、一部のメタ学習で使う。 フレームワークは対応するが(例:jax.jvp)、主流深層学習では珍しい。

「自動」と呼ぶに値する理由。代替案を見ると分かる:

  • 記号微分(Mathematica 流)—— 閉形式の導関数を計算する。小さな式 では機能するが、深層ネットでは組合せ的に爆発する。n 演算の式の導関数は n² 規模の 項を持ち得る。Transformer 規模では役立たず。
  • 数値微分(有限差分、(f(x + ε) − f(x)) / ε)—— 任意の関数で使えるが、不正確(2 浮動小数の差を小さな数で割る)、そして遅い (パラメータごとに余分な順伝播 1 回;10 億パラメータでは完全に非現実的)。
  • autodiff —— 機械精度で厳密(丸め誤差の蓄積なし)、総コストは 「順伝播 1 回の記録 + 1 回の逆向き巻き戻し」。深層学習における正解、 そしてこの分野で最も美しい仕掛けの 1 つ。

Transformer では:訓練ループ内のすべての loss.backward()が autograd の仕事だ。attention、FFN、LayerNorm、埋め込みの各所で数兆の勾配を 計算する —— すべてフレームワークが既に知っている「各 op の微分規則」から。 ユーザーは順伝播を書き、フレームワークは逆伝播を書く。この非対称こそが、 ニューラルネットを 1998 年の MNIST から 2023 年の GPT-4 へとスケールさせた 最大の要因だ。

03

計算グラフ

蓋を開ければ、autograd は DAG と巻き戻し器に過ぎない。

外側から見ると autograd は魔法だ。中身は単一のデータ構造 —— 計算グラフ—— と、それを逆向きに辿るアルゴリズム。計算グラフは有向非巡回グラフ(DAG)で、 各ノードは演算、各辺は演算間を流れるデータ。

y = (x · 2 + 3)² なら、グラフはノード 4 個、データ辺 3 本になる:

     ┌──────┐    ┌──────┐    ┌──────┐    ┌──────┐
x ──▶│  · 2 │──▶ │ + 3  │──▶ │(...)²│──▶ │  y   │
     └──────┘    └──────┘    └──────┘    └──────┘
       (mul)       (add)       (pow)
        =4          =7          =49

フレームワークは y を計算した瞬間にこのグラフを構築する。各中間テンソル (4749)は、自分を作った op と入力テンソル への逆ポインタを持つ。これが y が「自分が x から来た」 ことを覚えている仕組みだ。

y = (x · 2 + 3)² x = 2 にて葉だけx= 2· 2+ 3(...)²グラフは葉 x = 2 から始まる。演算はまだ記録されていない。
1 / 3
autograd のデータ構造は DAG:ノードは演算、辺はデータ依存。順伝播で記録、逆伝播で逆再生。

逆伝播ではグラフを逆に辿る。y から ∂y/∂y = 1 を種に始め、pow ノードに微分規則を尋ね(2 · (input)input = 7で評価して 14)、それを 1 つ前のノードの勾配とする。次に add(微分は 1、勾配はそのまま通る → 14)、mul(微分は 2、入力勾配を 2 倍 → 28) へと進む。x での総勾配は 28、これはd/dx (2x + 3)² = 4(2x + 3) = 28(x = 2)という解析解に 一致する。

グラフ管理の 2 つのパラダイムが、フレームワーク設計を方向づけてきた:

  • 動的グラフ(PyTorch eager、jit 外の JAX)—— 順伝播のたびに新しくグラフが構築される。順伝播内に Python の ifがあれば、反復ごとに違うグラフになる。デバッグが楽 (モデルの真ん中で pdb.set_trace() できる)、動的制御フローも書きやすい。 代償:フレームワークは先読みしてグラフ全体をまとめて最適化できない; 各 op がほぼ独立に動く。
  • 静的グラフ(TensorFlow 1.x、jit 付きの JAX、torch.compile 付きの PyTorch)—— グラフを一度定義し (多くは抽象入力で Python をトレース)、フレームワークが解析・コンパイルし、 以降はコンパイル版を高速に何度も呼ぶ。1 回あたりはずっと速い:op の融合、 メモリの事前確保、最適な CUDA カーネル選択ができる。デバッグは難しい —— コンパイル後のエラーは、書いた Python とは別物に見える。

動的と静的の境界は急速にぼやけてきている。現代の PyTorch(torch.compile) は eager モデルをトレースし、静的解析できる部分は最適化した静的グラフへコンパイル、 できない部分は eager にフォールバックする。現代 JAX は積極的にトレースし、形状変化で 再コンパイルする。主流フレームワークがおおむね収束しつつある状態は「動的に書く、 コンパイルする範囲はフレームワークが決める」。

autograd 以外にグラフが与えてくれるもの:

  • カーネル融合。長い要素演算の連鎖(例:relu(linear(x) + bias))が 1 個の GPU カーネルにコンパイルされ、 VRAM への 3 往復が消える。torch.compile や JAX jit が eager より 2–10 倍速い理由の半分はこれ。
  • メモリ計画。事前にグラフが分かれば、フレームワークは適切な バッファを確保し、活性が不要になり次第再利用できる。ハードウェアとテンソル primer の「活性チェックポイント」最適化はグラフの理解に依存する。
  • デバイス配置。どの op がどの GPU(または TPU pod)で走るかを 決定し、適切な集団通信 op を挿入し、分散訓練を編成する —— すべてグラフから読める。

Transformer では:70B モデルの順伝播は、数百万ノードの計算グラフ —— Q·K 行列積、softmax、attention·V、FFN 行列積、残差加算、LayerNorm —— がブロック ごとに数十回繰り返される。各訓練ステップの先頭で autograd がこのグラフを構築し、loss.backward() でそれを逆向きに辿る;高スループット推論のために コンパイルされるのも同じグラフ。実用的な意味で、グラフ自体がモデルだ —— 重みはグラフのノードに詰め込まれた数値に過ぎない。2026 年の生産深層学習の仕事の ほとんどは、結局のところ「正しいグラフを構築するコードを書くこと」だ。