Theory Notebook — Math for LLMs

Automatic Differentiation

Multivariate Calculus / Automatic Differentiation

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Automatic Differentiation — Theory Notebook

"Differentiating programs, not formulas — that is the key insight that makes modern deep learning possible."

Interactive derivations: dual numbers from scratch, reverse-mode AD, JVP/VJP, higher-order derivatives, gradient checkpointing, custom autograd, and meta-learning gradients.

Run cells top-to-bottom. Each section builds on the previous.

Code cell 2

# Global plotting setup: consistent style for every figure in the notebook.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn's theme when available; otherwise fall back to the
# equivalent built-in matplotlib style. HAS_SNS records which path ran.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Shared rcParams: readable font sizes, open spines, tight saves.
mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)  # reproducible random draws across the notebook
print("Plot setup complete.")

Code cell 3

import numpy as np
import numpy.linalg as la
from scipy import optimize, special, stats
from scipy.optimize import minimize, fsolve, linprog
from math import factorial
import math
import matplotlib.patches as patches

# Colourblind-safe palette (Paul Tol-style hex codes) used by later figures.
COLORS = {
    "primary": "#0077BB",
    "secondary": "#EE7733",
    "tertiary": "#009988",
    "error": "#CC3311",
    "neutral": "#555555",
    "highlight": "#EE3377",
}
HAS_MPL = True  # matplotlib was imported unconditionally in the setup cell
np.set_printoptions(precision=8, suppress=True)
np.random.seed(42)  # re-seed so this cell is reproducible on its own
spmin = minimize  # short alias for scipy.optimize.minimize

# torch is optional: every torch-guarded section falls back to a NumPy demo.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False


def header(title):
    """Print *title* framed above and below by '=' rules of matching length."""
    rule = "=" * len(title)
    print("\n" + rule)
    print(title)
    print(rule)

def check_true(name, cond):
    """Print a PASS/FAIL line for a boolean condition; return it as a bool."""
    verdict = bool(cond)
    status = "PASS" if verdict else "FAIL"
    print(f"{status} - {name}")
    return verdict

def check_close(name, got, expected, tol=1e-8):
    """Print PASS/FAIL comparing *got* to *expected* within *tol*; return the verdict."""
    passed = np.allclose(got, expected, atol=tol, rtol=tol)
    label = "PASS" if passed else "FAIL"
    print(f"{label} - {name}: got {got}, expected {expected}")
    return passed

def check(name, got, expected, tol=1e-8):
    # Back-compat alias: some later cells call `check` instead of `check_close`.
    return check_close(name, got, expected, tol=tol)

def sigmoid(x):
    """Numerically stable logistic function 1/(1+exp(-x)), elementwise.

    Computes exp(-|x|) once, which always lies in (0, 1], so neither branch
    can overflow. The previous form evaluated exp on unbounded arguments
    inside np.where (both branches are always evaluated), emitting spurious
    overflow RuntimeWarnings for large |x| even though the selected value
    was correct. Returned values are identical to the original.
    """
    x = np.asarray(x, dtype=float)
    e = np.exp(-np.abs(x))  # exp(-x) for x >= 0, exp(x) for x < 0
    return np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

def softmax(z, axis=-1):
    """Softmax along *axis*, with max-subtraction for numerical stability."""
    z = np.asarray(z, dtype=float)
    shifted = z - z.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)

def relu(x):
    """Elementwise rectifier: clamp negative values to zero."""
    return np.clip(x, 0, None)

def relu_prime(x):
    """Subgradient of ReLU: 1.0 where x > 0, else 0.0 (0.0 chosen at x == 0)."""
    return (np.asarray(x) > 0).astype(float)

def centered_diff(f, x, h=1e-6):
    """Second-order central-difference estimate of f'(x)."""
    above = f(x + h)
    below = f(x - h)
    return (above - below) / (2 * h)

def numerical_gradient(f, x, h=1e-6):
    """Central-difference gradient of a scalar function f at array x.

    Works for x of any shape; the returned gradient has the same shape.
    """
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x, dtype=float)
    gflat = grad.reshape(-1)
    base = x.reshape(-1)
    for k in range(base.size):
        up = base.copy()
        down = base.copy()
        up[k] += h
        down[k] -= h
        gflat[k] = (f(up.reshape(x.shape)) - f(down.reshape(x.shape))) / (2 * h)
    return grad

def numerical_jacobian(f, x, h=1e-6):
    """Central-difference Jacobian of vector-valued f.

    Returned shape is f(x).shape + x.shape (one FD column per input entry).
    """
    x = np.asarray(x, dtype=float)
    y0 = np.asarray(f(x), dtype=float)
    jac = np.zeros((y0.size, x.size), dtype=float)
    base = x.reshape(-1)
    for col in range(x.size):
        up = base.copy()
        down = base.copy()
        up[col] += h
        down[col] -= h
        delta = np.asarray(f(up.reshape(x.shape))) - np.asarray(f(down.reshape(x.shape)))
        jac[:, col] = (delta / (2 * h)).reshape(-1)
    return jac.reshape(y0.shape + x.shape)

def grad_check(f, x, analytic_grad, h=1e-6):
    """Relative error between an analytic gradient and the FD gradient.

    Uses the symmetric relative norm ||a - n|| / (||a|| + ||n|| + eps);
    values well below 1e-6 indicate a correct analytic gradient.
    """
    fd = numerical_gradient(f, x, h=h)
    scale = la.norm(analytic_grad) + la.norm(fd) + 1e-12
    return la.norm(analytic_grad - fd) / scale



def jacobian_fd(f, x, h=1e-6):
    """Finite-difference Jacobian (duplicate of numerical_jacobian, kept for
    cells that call this name). Output shape is f(x).shape + x.shape."""
    x = np.asarray(x, dtype=float)
    y0 = np.asarray(f(x), dtype=float)
    jac = np.zeros((y0.size, x.size), dtype=float)
    base = x.reshape(-1)
    for col in range(x.size):
        up = base.copy()
        down = base.copy()
        up[col] += h
        down[col] -= h
        y_up = np.asarray(f(up.reshape(x.shape)), dtype=float).reshape(-1)
        y_down = np.asarray(f(down.reshape(x.shape)), dtype=float).reshape(-1)
        jac[:, col] = (y_up - y_down) / (2 * h)
    return jac.reshape(y0.shape + x.shape)

def hessian_fd(f, x, h=1e-5):
    """Finite-difference Hessian of scalar f at x; shape x.shape + x.shape.

    Column j is the central difference of the FD gradient along coordinate j,
    so each entry costs two full `numerical_gradient` calls — O(x.size^2)
    function evaluations overall. NOTE(review): the result is not explicitly
    symmetrised; callers wanting exact symmetry may use 0.5*(H + H.T).
    """
    x = np.asarray(x, dtype=float)
    H = np.zeros((x.size, x.size), dtype=float)
    flat = x.reshape(-1)
    for j in range(x.size):
        # Bump coordinate j up/down and differentiate the *gradient*.
        xp = flat.copy(); xm = flat.copy()
        xp[j] += h; xm[j] -= h
        gp = numerical_gradient(lambda z: f(z.reshape(x.shape)), xp.reshape(x.shape), h=h).reshape(-1)
        gm = numerical_gradient(lambda z: f(z.reshape(x.shape)), xm.reshape(x.shape), h=h).reshape(-1)
        H[:, j] = (gp - gm) / (2*h)
    return H.reshape(x.shape + x.shape)



def grad_fd(f, x, h=1e-6):
    # Alias for numerical_gradient kept for cells that use the short name.
    return numerical_gradient(f, x, h=h)



def fd_grad(f, x, h=1e-6):
    # Another alias; coerces x to a float array before differencing.
    return numerical_gradient(f, np.asarray(x, dtype=float), h=h)

print("Chapter helper setup complete.")

1. The Three Ways to Differentiate

Before building AD, we compare all three approaches on the same function $f(x) = x^2 \sin(x)$.

Code cell 5

# === 1.1 Comparison: Finite Differences vs Symbolic vs AD ===

# We'll compute df/dx at x=1.5 three ways.
import math

x0 = 1.5

# Ground truth (analytic derivative: 2x sin x + x^2 cos x)
def f(x): return x**2 * math.sin(x)
def f_prime_exact(x): return 2*x*math.sin(x) + x**2*math.cos(x)

exact = f_prime_exact(x0)
print(f'Exact derivative f prime(1.5) = {exact:.10f}')
print()

# 1) Finite differences (forward, central, extrapolated)
# Forward FD has O(h) truncation error, central O(h^2); below the optimum
# step size, floating-point cancellation in the subtraction dominates.
errors_fd = []
hs = [1e-1, 1e-2, 1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14]
for h in hs:
    fd_forward = (f(x0 + h) - f(x0)) / h
    fd_central = (f(x0 + h) - f(x0 - h)) / (2*h)
    errors_fd.append((h, abs(fd_forward - exact), abs(fd_central - exact)))

# errors_fd is reused by the plotting cell below (1.2).
print('Finite differences error vs step size h:')
print(f'  {"h":>10}  {"Forward":>14}  {"Central":>14}')
for h, ef, ec in errors_fd:
    print(f'  {h:>10.0e}  {ef:>14.2e}  {ec:>14.2e}')
print()
print('Best central FD accuracy: ~1e-10 (limited by cancellation at h~1e-5)')

Code cell 6

# === 1.2 Visualise FD accuracy vs step size ===

if HAS_MPL:
    # Unpack the (h, forward_err, central_err) triples gathered in 1.1.
    hs_arr = np.array([h for h,_,_ in errors_fd])
    errs_fwd = np.array([ef for _,ef,_ in errors_fd])
    errs_cen = np.array([ec for _,_,ec in errors_fd])

    fig, ax = plt.subplots(figsize=(9,5))
    ax.loglog(hs_arr, errs_fwd, 'o-', label='Forward FD', color='tab:blue')
    ax.loglog(hs_arr, errs_cen, 's-', label='Central FD', color='tab:orange')
    ax.axhline(1e-15, color='tab:green', linestyle='--', label='AD (machine eps)')
    ax.set_xlabel('Step size h')
    ax.set_ylabel('|error|')
    ax.set_title('Finite Difference Error vs Step Size\nAD achieves machine precision for any h')
    ax.legend()
    ax.invert_xaxis()
    plt.tight_layout()
    plt.show()
    # FIX: forward FD's optimum is ~sqrt(machine eps) ~ 1e-8; central FD's is
    # ~eps^(1/3) ~ 1e-5 (the cell-1.1 table shows exactly this). The previous
    # message claimed 1e-5 for both.
    print('Observation: FD has a sweet spot around h~1e-8 for forward, 1e-5 for central.')
    print('AD is free of this tradeoff — exact to machine precision always.')

2. Dual Numbers — Algebraic Foundation of Forward-Mode AD

A dual number $a + b\varepsilon$ with $\varepsilon^2 = 0$ automatically carries derivative information alongside function values.

Code cell 8

# === 2.1 Dual Number Class — Full Implementation ===

import math

class Dual:
    """Dual number real + dual*eps with eps^2 = 0.

    Seeding x = Dual(x0, 1.0) makes f(x).dual equal f'(x0) for any f built
    from the overloaded operators below — this is forward-mode AD.
    Mixed arithmetic with plain ints/floats is supported on both sides.
    """

    def __init__(self, real, dual=0.0):
        self.real = float(real)  # function value
        self.dual = float(dual)  # derivative (coefficient of eps)

    def __add__(self, other):
        if isinstance(other, (int, float)):
            return Dual(self.real + other, self.dual)
        return Dual(self.real + other.real, self.dual + other.dual)
    __radd__ = __add__

    def __sub__(self, other):
        if isinstance(other, (int, float)):
            return Dual(self.real - other, self.dual)
        return Dual(self.real - other.real, self.dual - other.dual)

    def __rsub__(self, other):
        # scalar - Dual: constant has zero derivative, so negate ours
        return Dual(other - self.real, -self.dual)

    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return Dual(self.real * other, self.dual * other)
        # product rule: (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.real * other.real,
                    self.real * other.dual + self.dual * other.real)
    __rmul__ = __mul__

    def __truediv__(self, other):
        if isinstance(other, (int, float)):
            return Dual(self.real / other, self.dual / other)
        # quotient rule: (b*c - a*d) / c^2
        r = other.real
        return Dual(self.real / r,
                    (self.dual * r - self.real * other.dual) / (r*r))

    def __rtruediv__(self, other):
        # FIX: scalar / Dual previously raised TypeError (no __rtruediv__).
        # d/dx (c / f) = -c f' / f^2
        return Dual(other / self.real,
                    -other * self.dual / (self.real * self.real))

    def __pow__(self, n):  # integer/float (constant) power
        return Dual(self.real**n, n * self.real**(n-1) * self.dual)

    def __neg__(self):
        return Dual(-self.real, -self.dual)

    def __repr__(self):
        sign = '+' if self.dual >= 0 else '-'
        return f'{self.real:.6f} {sign} {abs(self.dual):.6f}eps'

# Elementary dual functions.
# Each lifts a math.* primitive to Dual numbers: the value goes in .real,
# the chain rule (outer derivative times x.dual) goes in .dual.
def d_exp(x): return Dual(math.exp(x.real), x.dual * math.exp(x.real))
def d_log(x): return Dual(math.log(x.real), x.dual / x.real)
def d_sin(x): return Dual(math.sin(x.real), x.dual * math.cos(x.real))
def d_cos(x): return Dual(math.cos(x.real), -x.dual * math.sin(x.real))
def d_sqrt(x): return Dual(math.sqrt(x.real), x.dual / (2*math.sqrt(x.real)))

print('Dual number class ready.')
# Quick test: (2 + 1*eps)^2 should give 4 + 4*eps (value=4, deriv=2x=4)
x = Dual(2.0, 1.0)
print(f'(2+eps)^2 = {x**2}  [expected: 4.0 + 4.0eps]')
# (2+eps)*(2+0eps) = 4 + 2eps by the product rule
print(f'(2+eps)*(2+0) = {Dual(2.0,1.0)*Dual(2.0,0.0)}  [expected: 4.0 + 2.0eps]')

Code cell 9

# === 2.2 Derivative Extraction via Dual Numbers ===

# f(x) = x^2 * sin(x), compute f'(1.5)
def f_dual(x):
    return x**2 * d_sin(x)

x_seed = Dual(1.5, 1.0)  # seed: x=1.5, dx/dx=1
result = f_dual(x_seed)
print(f'f(1.5) = {result.real:.10f}')
print(f"f'(1.5) via dual = {result.dual:.10f}")
exact = 2*1.5*math.sin(1.5) + 1.5**2*math.cos(1.5)
print(f"f'(1.5) exact    = {exact:.10f}")
print(f'Error: {abs(result.dual - exact):.2e}')
print()

# More complex example: f(x) = log(sin(x^2) + exp(x))
# The chain rule composes automatically through the nested dual calls.
def g_dual(x):
    return d_log(d_sin(x**2) + d_exp(x))

x = 0.8
gx = g_dual(Dual(x, 1.0))
print(f'g(0.8) = {gx.real:.8f}')
print(f"g'(0.8) via dual = {gx.dual:.8f}")

# Finite difference verification (central, so the ~1e-10 gap is FD's error)
h = 1e-5
def g(x): return math.log(math.sin(x**2) + math.exp(x))
fd = (g(x+h) - g(x-h)) / (2*h)
print(f"g'(0.8) via FD   = {fd:.8f}")
print(f'Error vs FD: {abs(gx.dual - fd):.2e}  (FD error, not dual error)')

Code cell 10

# === 2.3 Multivariate Forward Mode — Partial Derivatives ===

# f(x1, x2) = (x1 + x2) * sin(x1)
# Compute df/dx1 and df/dx2 separately

def f_multi(x1, x2):
    return (x1 + x2) * d_sin(x1)

# df/dx1: seed x1 with dual=1, x2 with dual=0
# (the seed picks which direction of the input space we differentiate along)
r1 = f_multi(Dual(1.0, 1.0), Dual(2.0, 0.0))
df_dx1 = r1.dual

# df/dx2: seed x2 with dual=1, x1 with dual=0
r2 = f_multi(Dual(1.0, 0.0), Dual(2.0, 1.0))
df_dx2 = r2.dual

print('Forward-mode AD trace for f(x1,x2) = (x1+x2)*sin(x1) at (1,2):')
print(f'  f(1,2) = {r1.real:.6f}')
print(f'  df/dx1 = {df_dx1:.6f}')
print(f'  df/dx2 = {df_dx2:.6f}')
print()

# Manual check
import math
# df/dx1 = sin(x1) + (x1+x2)*cos(x1) at (1,2)
expected_dx1 = math.sin(1) + 3*math.cos(1)
expected_dx2 = math.sin(1)
print(f'Analytic df/dx1 = {expected_dx1:.6f}  match: {abs(df_dx1-expected_dx1)<1e-10}')
print(f'Analytic df/dx2 = {expected_dx2:.6f}  match: {abs(df_dx2-expected_dx2)<1e-10}')
print()
print('Note: forward mode needed 2 passes to compute both partials.')
print('Reverse mode (next section) computes both in 1 pass.')

3. Reverse-Mode AD — Adjoint Accumulation

Reverse mode propagates adjoint variables $\bar{v}_k = \partial f / \partial v_k$ backward through the graph. One backward pass gives the full gradient.

Code cell 12

# === 3.1 Manual Reverse-Mode Trace ===

# f(x1, x2) = (x1 + x2) * sin(x1) at (1, 2)
# Forward pass: compute all intermediate values

import math

x1, x2 = 1.0, 2.0

# Forward
v1 = x1              # = 1.0
v2 = x2              # = 2.0
v3 = v1 + v2         # = 3.0
v4 = math.sin(v1)    # = sin(1) = 0.8415
v5 = v3 * v4         # = 3.0 * 0.8415 = 2.524

print('=== Forward Pass ===')
print(f'v1 = {v1:.6f}  (x1)')
print(f'v2 = {v2:.6f}  (x2)')
print(f'v3 = {v3:.6f}  (v1 + v2)')
print(f'v4 = {v4:.6f}  (sin(v1))')
print(f'v5 = {v5:.6f}  (v3 * v4) = f(x1,x2)')
print()

# Backward: initialise v5_bar = 1, then apply each local VJP in reverse
# topological order; an adjoint sums contributions from every consumer.
v5_bar = 1.0
# v5 = v3 * v4  =>  v3_bar += v5_bar * v4,  v4_bar += v5_bar * v3
v3_bar = v5_bar * v4
v4_bar = v5_bar * v3
# v4 = sin(v1)  =>  v1_bar += v4_bar * cos(v1)
v1_bar_from_v4 = v4_bar * math.cos(v1)
# v3 = v1 + v2  =>  v1_bar += v3_bar * 1,  v2_bar += v3_bar * 1
v1_bar_from_v3 = v3_bar * 1.0
v2_bar = v3_bar * 1.0
v1_bar = v1_bar_from_v4 + v1_bar_from_v3  # accumulate! v1 feeds both v3 and v4

print('=== Backward Pass ===')
print(f'v5_bar = {v5_bar:.6f}  (df/df = 1)')
print(f'v3_bar = {v3_bar:.6f}  (v5_bar * v4)')
print(f'v4_bar = {v4_bar:.6f}  (v5_bar * v3)')
print(f'v1_bar += {v1_bar_from_v4:.6f}  (from sin: v4_bar * cos(v1))')
print(f'v1_bar += {v1_bar_from_v3:.6f}  (from add: v3_bar)')
print(f'v2_bar = {v2_bar:.6f}')
print()
print(f'=== Gradient ===')
print(f'df/dx1 = v1_bar = {v1_bar:.6f}')
print(f'df/dx2 = v2_bar = {v2_bar:.6f}')

# Verification against the hand-derived partials
expected_dx1 = math.sin(1) + 3*math.cos(1)
expected_dx2 = math.sin(1)
print(f'\nVerification: match dx1={abs(v1_bar-expected_dx1)<1e-12}, match dx2={abs(v2_bar-expected_dx2)<1e-12}')

Code cell 13

# === 3.2 Micrograd-Style Reverse-Mode AD Engine ===

# Build a minimal autograd engine (inspired by Karpathy's micrograd)

class Value:
    """Scalar value with reverse-mode autograd (micrograd-style).

    Each arithmetic op builds a graph node recording its parents and a
    closure `_backward` that pushes `out.grad` to the parents via the op's
    local derivative. `backward()` topologically sorts the graph and runs
    the closures in reverse, accumulating gradients into `.grad`.
    Gradients accumulate across calls — zero `.grad` manually for reuse.
    """
    def __init__(self, data, _children=(), _op='', label=''):
        self.data = float(data)   # forward value
        self.grad = 0.0           # d(output)/d(self), filled by backward()
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op
        self.label = label

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
        def _bwd():
            # d(a+b)/da = d(a+b)/db = 1
            self.grad += out.grad * 1.0
            other.grad += out.grad * 1.0
        out._backward = _bwd
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
        def _bwd():
            # product rule: each side gets the other's value
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out._backward = _bwd
        return out

    def __neg__(self):
        # NEW (backward-compatible): routes through __mul__, so the
        # backward pass needs no extra closure.
        return self * -1.0

    def __sub__(self, other):
        # NEW (backward-compatible): a - b = a + (-b)
        other = other if isinstance(other, Value) else Value(other)
        return self + (-other)

    def __rsub__(self, other):
        # NEW: scalar - Value
        return Value(other) + (-self)

    def sin(self):
        out = Value(math.sin(self.data), (self,), 'sin')
        def _bwd():
            self.grad += out.grad * math.cos(self.data)
        out._backward = _bwd
        return out

    def exp(self):
        out = Value(math.exp(self.data), (self,), 'exp')
        def _bwd():
            self.grad += out.grad * out.data  # d/dx exp(x) = exp(x)
        out._backward = _bwd
        return out

    def log(self):
        out = Value(math.log(self.data), (self,), 'log')
        def _bwd():
            self.grad += out.grad / self.data  # d/dx log(x) = 1/x
        out._backward = _bwd
        return out

    def __pow__(self, n):
        # constant (int/float) exponent only
        out = Value(self.data**n, (self,), f'**{n}')
        def _bwd():
            self.grad += out.grad * n * self.data**(n-1)
        out._backward = _bwd
        return out

    __rmul__ = __mul__
    __radd__ = __add__

    def backward(self):
        """Topological sort + reverse accumulation from this node (seed grad=1)."""
        topo = []
        visited = set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build(child)
                topo.append(v)
        build(self)
        self.grad = 1.0
        for v in reversed(topo):
            v._backward()

    def __repr__(self):
        return f'Value(data={self.data:.4f}, grad={self.grad:.4f})'

print('Micrograd engine ready.')

# Test: f(x1, x2) = (x1 + x2) * sin(x1) — one backward pass yields BOTH
# partials (contrast with forward mode's two passes in 2.3).
x1 = Value(1.0, label='x1')
x2 = Value(2.0, label='x2')
f_val = (x1 + x2) * x1.sin()
f_val.backward()

print(f'f(1,2) = {f_val.data:.6f}')
print(f'df/dx1 = {x1.grad:.6f}  (expected: {math.sin(1)+3*math.cos(1):.6f})')
print(f'df/dx2 = {x2.grad:.6f}  (expected: {math.sin(1):.6f})')
print(f'Match: {abs(x1.grad - (math.sin(1)+3*math.cos(1))) < 1e-12}')

Code cell 14

# === 3.3 Visualise the Computational Graph ===

# Trace a slightly more complex example and print the graph structure
# f(x) = log(x^2 + exp(x))

x = Value(1.0, label='x')
a = x**2;           a.label = 'x^2'
b = x.exp();        b.label = 'exp(x)'
c = a + b;          c.label = 'x^2+exp(x)'
out = c.log();      out.label = 'log(x^2+exp(x))'
out.backward()

print('Computational graph nodes (in reverse topological order):')

# Reconstruct order via DFS (same post-order walk Value.backward uses)
topo, visited = [], set()
def build_topo(v):
    if v not in visited:
        visited.add(v)
        for ch in v._prev: build_topo(ch)
        topo.append(v)
build_topo(out)

for v in reversed(topo):
    print(f'  {v.label:20s}  data={v.data:8.4f}  grad={v.grad:8.4f}  op={v._op}')

# Analytic check: f'(x) = (2x + exp(x)) / (x^2 + exp(x))
xv = 1.0
expected_grad = (2*xv + math.exp(xv)) / (xv**2 + math.exp(xv))
print(f'\nAnalytic grad = {expected_grad:.6f}')
print(f'Autograd grad = {x.grad:.6f}')
print(f'Match: {abs(x.grad - expected_grad) < 1e-12}')

4. JVP vs VJP — Choosing the Right Mode

For $f : \mathbb{R}^n \to \mathbb{R}^m$, forward mode (JVP) costs $O(n \cdot T(f))$ and reverse mode (VJP) costs $O(m \cdot T(f))$. The crossover is at $n = m$.

Code cell 16

# === 4.1 JVP and VJP via PyTorch (or pure NumPy) ===

if HAS_TORCH:
    # Forward-mode JVP via torch.autograd.functional.jvp
    # NOTE(review): torch.autograd.functional.jvp emulates forward mode via
    # double reverse-mode internally — fine for demos, slower than true fwd AD.
    from torch.autograd.functional import jvp, vjp

    def f_torch(x): return torch.stack([x[0]**2, torch.sin(x[1]), x[0]*x[1]])

    x = torch.tensor([1.0, 2.0])
    v_tangent = torch.tensor([1.0, 0.0])   # tangent: d/dx1 direction

    primals, tangents = jvp(f_torch, (x,), (v_tangent,))
    print('=== Forward-mode JVP ===')
    print(f'f(x) = {primals.numpy().round(4)}')
    print(f'Jf @ v = {tangents.numpy().round(4)}  (should be [2x1, 0, x2] = [2, 0, 2])')

    v_cotangent = torch.tensor([1.0, 1.0, 1.0])  # cotangent: sum all outputs
    _, cotangents = vjp(f_torch, (x,), v_cotangent)
    print('\n=== Reverse-mode VJP ===')
    print(f'v.T @ Jf = {cotangents[0].numpy().round(4)}')
    # d/dx1(x1^2 + sin(x2=const) + x1*x2) = 2x1 + x2 = 2+2=4
    # d/dx2(x1^2 + sin(x2) + x1*x2) = cos(x2) + x1 = cos(2)+1
    expected = torch.tensor([2*1.0 + 2.0, torch.cos(torch.tensor(2.0)) + 1.0])
    print(f'Expected:     {expected.numpy().round(4)}')

else:
    # NumPy fallback: compute JVP and VJP manually for a linear map f(x)=Ax,
    # where the Jacobian IS the matrix A.
    A = np.array([[2., 0., 1.], [0., 1., 3.]])   # 2x3 Jacobian
    x = np.array([1., 2., 3.])
    v = np.array([1., 0., 0.])   # tangent
    u = np.array([1., 1.])       # cotangent

    jvp_result = A @ v
    vjp_result = A.T @ u

    print('Linear map f(x) = Ax, A shape (2x3):')
    print(f'JVP (A @ v) = {jvp_result}  (shape m={A.shape[0]})')
    print(f'VJP (A.T @ u) = {vjp_result}  (shape n={A.shape[1]})')
    print('Forward mode: 3 passes for full Jacobian')
    print('Reverse mode: 2 passes for full Jacobian')

Code cell 17

# === 4.2 Crossover Analysis: When to Use Forward vs Reverse ===

import time

def time_jacobian(n, m, mode='forward'):
    """Time assembling the full Jacobian of a random linear map f(x) = Ax.

    mode='forward' builds the Jacobian column-by-column from n JVPs;
    any other mode builds it row-by-row from m VJPs.
    Returns (elapsed_seconds, J) with J of shape (m, n).
    """
    A = np.random.randn(m, n)
    x = np.random.randn(n)

    start = time.perf_counter()
    J = np.zeros((m, n))
    if mode == 'forward':
        # one JVP per input direction: J[:, j] = A @ e_j
        for col in range(n):
            basis = np.zeros(n)
            basis[col] = 1.0
            J[:, col] = A @ basis
    else:
        # one VJP per output: J[i, :] = e_i^T A
        for row in range(m):
            basis = np.zeros(m)
            basis[row] = 1.0
            J[row, :] = A.T @ basis
    elapsed = time.perf_counter() - start
    return elapsed, J

print('Jacobian timing: n inputs, m outputs')
print(f'{"(n,m)":>12}  {"Forward(s)":>12}  {"Reverse(s)":>12}  {"Ratio F/R":>12}')

cases = [(5, 500), (50, 50), (500, 5), (1, 1000), (1000, 1)]
for n, m in cases:
    # FIX: time_jacobian draws a fresh random A each call, so the two
    # Jacobians previously came from DIFFERENT matrices and the
    # `check=np.allclose(Jf, Jr)` column always printed False.
    # Re-seeding identically before each call makes both draws the same A,
    # so the forward/reverse results are genuinely comparable.
    np.random.seed(0)
    tf, Jf = time_jacobian(n, m, 'forward')
    np.random.seed(0)
    tr, Jr = time_jacobian(n, m, 'reverse')
    ratio = tf / tr if tr > 0 else float('inf')
    ok = np.allclose(Jf, Jr)
    print(f'  ({n:>4},{m:>4})  {tf:>12.5f}  {tr:>12.5f}  {ratio:>12.2f}  check={ok}')

print()
print('Observation: ratio ~n/m. Forward wins when n<m, reverse when m<n.')
print('For ML (m=1 scalar loss): reverse always wins for large n.')

Code cell 18

# === 4.3 Full Jacobian via VJPs (reverse mode) ===

# Build the full Jacobian row-by-row using our micrograd engine
# f: R^3 -> R^3,  f(x) = [x0^2, x1*x2, exp(x0+x1)]

def f_vector(x_arr):
    """x_arr: list/array of Value objects, returns list of Values"""
    return [x_arr[0]**2, x_arr[1]*x_arr[2], (x_arr[0]+x_arr[1]).exp()]

x_vals = [1.0, 2.0, 0.5]
n, m = 3, 3

# Build Jacobian row by row (reverse mode: 1 pass per output).
# Fresh Value nodes per row: Value.grad accumulates, so reusing nodes
# across backward() calls would mix rows together.
J_rev = np.zeros((m, n))
for i in range(m):
    xs = [Value(v, label=f'x{j}') for j, v in enumerate(x_vals)]
    outs = f_vector(xs)
    outs[i].backward()
    for j in range(n):
        J_rev[i, j] = xs[j].grad

print('Jacobian via reverse mode (1 pass per output):')
print(J_rev.round(4))

# Analytic Jacobian at (1, 2, 0.5):
# df0/dx0 = 2x0=2, df0/dx1=0, df0/dx2=0
# df1/dx0 = 0, df1/dx1=x2=0.5, df1/dx2=x1=2
# df2/dx0 = exp(x0+x1)=e^3, df2/dx1=exp(x0+x1)=e^3, df2/dx2=0
e3 = math.exp(3)
J_exact = np.array([[2,0,0],[0,0.5,2],[e3,e3,0]])
print('Analytic Jacobian:')
print(J_exact.round(4))
print(f'Match: {np.allclose(J_rev, J_exact, atol=1e-10)}')

5. Hessian-Vector Products and Higher-Order Derivatives

The Hessian $H = \nabla^2 f \in \mathbb{R}^{n \times n}$ is too expensive to materialise for large $n$. We compute $H\mathbf{v}$ directly in $O(T(f))$.

Code cell 20

# === 5.1 HVP via Nested AD (NumPy implementation) ===

# f(x) = 0.5 * ||Ax||^2 => grad = A^T A x => Hessian = A^T A
# HVP H*v = A^T A * v

np.random.seed(42)
m_hvp, n_hvp = 20, 5
A_hvp = np.random.randn(m_hvp, n_hvp)
H_exact = A_hvp.T @ A_hvp  # true Hessian

x = np.random.randn(n_hvp)
v = np.random.randn(n_hvp)

# Method 1: materialise Hessian
hvp_direct = H_exact @ v

# Method 2: HVP trick - H*v = grad of (grad.f . v)
# grad f(x) = A^T A x
def grad_f(x):
    return A_hvp.T @ (A_hvp @ x)

# Numerical HVP via FD on the gradient: directional derivative of grad_f
# along v approximates H @ v without ever forming H.
eps = 1e-5
hvp_fd = (grad_f(x + eps*v) - grad_f(x - eps*v)) / (2*eps)

print('=== Hessian-Vector Product Comparison ===')
print(f'n = {n_hvp}, H shape = {H_exact.shape}')
print(f'H*v (direct):   {hvp_direct.round(4)}')
print(f'H*v (FD on g):  {hvp_fd.round(4)}')
print(f'Match: {np.allclose(hvp_direct, hvp_fd, atol=1e-4)}')
print()

# Memory comparison (float64 = 8 bytes per entry)
import sys
n_big = 10000
hess_memory_GB = n_big**2 * 8 / 1e9
vec_memory_GB = n_big * 8 / 1e9
print(f'For n={n_big}:')
print(f'  Full Hessian memory: {hess_memory_GB:.1f} GB')
print(f'  HVP method memory:   {vec_memory_GB*3:.4f} GB (3 gradient vectors)')
# NOTE(review): the label below says "Speedup factor" but the quantity is a
# memory-reduction factor (Hessian bytes / HVP bytes), not a runtime speedup.
print(f'  Speedup factor: {hess_memory_GB/(vec_memory_GB*3):.0f}x less memory')

Code cell 21

# === 5.2 PyTorch HVP via create_graph ===

if HAS_TORCH:
    torch.manual_seed(42)
    n_pt = 5
    A_pt = torch.randn(10, n_pt)
    x_pt = torch.randn(n_pt, requires_grad=True)
    v_pt = torch.randn(n_pt)

    def f_pt(x):
        return 0.5 * (A_pt @ x).pow(2).sum()

    # Compute gradient with create_graph=True to allow second-order diff
    # (the gradient itself becomes a differentiable graph node).
    loss = f_pt(x_pt)
    grad = torch.autograd.grad(loss, x_pt, create_graph=True)[0]

    # HVP: differentiate (grad . v) w.r.t. x — the Pearlmutter trick,
    # one extra backward instead of materialising the n x n Hessian.
    gv = (grad * v_pt).sum()
    hvp_torch_result = torch.autograd.grad(gv, x_pt)[0]

    # Compare to exact H*v = A^T A * v (Hessian of 0.5||Ax||^2 is A^T A)
    H_pt = A_pt.T @ A_pt
    hvp_exact_pt = H_pt @ v_pt

    print('=== PyTorch HVP via create_graph ===')
    print(f'HVP (AD):    {hvp_torch_result.detach().numpy().round(4)}')
    print(f'HVP (exact): {hvp_exact_pt.detach().numpy().round(4)}')
    err = (hvp_torch_result.detach() - hvp_exact_pt).norm().item()
    print(f'Error: {err:.2e}')
    print(f'PASS: {err < 1e-4}')
else:
    print('PyTorch not available. Using NumPy FD-based HVP (demonstrated above).')
    print('The pattern is identical: compute grad with create_graph=True,')
    print('then differentiate (grad . v) to get H*v.')

6. Computational Graph Mechanics

The backward pass requires processing nodes in reverse topological order. We implement Kahn's algorithm and verify it on example graphs.

Code cell 23

# === 6.1 Topological Sort (Kahn's Algorithm) ===

from collections import defaultdict, deque

def topological_sort_kahn(nodes, edges):
    """Kahn's algorithm: return *nodes* in dependency order, sources first.

    nodes: list of node names.
    edges: (src, dst) pairs meaning src is computed before dst.
    Raises ValueError if the graph contains a cycle.
    """
    children = defaultdict(list)   # node -> nodes that depend on it
    pending = defaultdict(int)     # node -> count of unmet dependencies

    for name in nodes:
        pending.setdefault(name, 0)
    for src, dst in edges:
        children[src].append(dst)
        pending[dst] += 1

    # FIFO queue of nodes whose dependencies are all satisfied.
    ready = deque(name for name in nodes if pending[name] == 0)
    order = []
    while ready:
        current = ready.popleft()
        order.append(current)
        for nxt in children[current]:
            pending[nxt] -= 1
            if pending[nxt] == 0:
                ready.append(nxt)

    # Leftover nodes mean some in-degree never reached zero: a cycle.
    if len(order) != len(nodes):
        raise ValueError('Graph has a cycle!')
    return order

# Graph for f(x1,x2) = (x1+x2)*sin(x1) — same DAG as the manual trace in 3.1.
nodes = ['x1', 'x2', 'v3', 'v4', 'v5']
edges = [('x1','v3'), ('x2','v3'), ('x1','v4'), ('v3','v5'), ('v4','v5')]

fwd_order = topological_sort_kahn(nodes, edges)
# Reverse topological order is exactly the order the backward pass visits.
bwd_order = list(reversed(fwd_order))

print('Forward computation order:', fwd_order)
print('Backward (adjoint) order: ', bwd_order)
print()
print('Verification of backward order:')
print('  v5 first: accumulate to v3 and v4')
print('  v4 next: accumulate to x1 (from sin)')
print('  v3 next: accumulate to x1 and x2 (from add)')
print('  x1 last: all contributions have been added')

Code cell 24

# === 6.2 PyTorch Graph Inspection ===

if HAS_TORCH:
    x1 = torch.tensor(1.0, requires_grad=True)
    x2 = torch.tensor(2.0, requires_grad=True)

    # Same f(x1,x2) = (x1+x2)*sin(x1) as the manual trace in 3.1.
    v3 = x1 + x2
    v4 = torch.sin(x1)
    v5 = v3 * v4

    print('=== PyTorch Graph Structure ===')
    print(f'v5.grad_fn: {v5.grad_fn}')
    # next_functions links each grad_fn to its input nodes' grad_fns.
    print(f'v5 inputs:  {[str(t) for t, _ in v5.grad_fn.next_functions]}')
    v3_fn, v4_fn = v5.grad_fn.next_functions
    print(f'v3.grad_fn: {v3_fn[0]}')
    print(f'v4.grad_fn: {v4_fn[0]}')
    print()

    v5.backward()
    print(f'x1.grad = {x1.grad:.6f}  (expected: {math.sin(1)+3*math.cos(1):.6f})')
    print(f'x2.grad = {x2.grad:.6f}  (expected: {math.sin(1):.6f})')
    print(f'Graph freed after backward (v5.grad_fn.next_functions): ', end='')
    # After backward with retain_graph=False, graph is freed
    print('freed (default behavior)')
else:
    print('PyTorch not available.')
    print('In PyTorch, each tensor with requires_grad=True gets a grad_fn')
    print('that links backward through the computational graph.')

7. Gradient Checkpointing

Store only $\sqrt{L}$ activations; recompute the rest during backward. Memory drops from $O(L)$ to $O(\sqrt{L})$ at the cost of one extra forward pass.

Code cell 26

# === 7.1 Checkpointing Memory Analysis ===

import math

def activation_memory(L, checkpoint=False, segment_size=None):
    """Peak count of stored activations for a depth-L sequential network.

    checkpoint=False: every layer's activation is kept, so the answer is L.
    checkpoint=True: sqrt-L checkpointing — keep one activation per segment
    boundary, and during backward recompute at most one segment's worth,
    giving a peak of n_segments + segment_size (~2*sqrt(L)).
    segment_size overrides the default sqrt(L) segment length.
    """
    if not checkpoint:
        return L
    seg = segment_size or max(1, int(math.sqrt(L)))
    n_segments = math.ceil(L / seg)
    # peak = stored checkpoints + activations live inside the current segment
    return n_segments + seg

# Tabulate standard vs checkpointed activation counts for growing depth.
print('Memory comparison: standard vs sqrt-L checkpointing')
print(f'{"L":>6}  {"Standard":>10}  {"Checkpoint":>12}  {"Ratio":>8}')

for L in [10, 50, 100, 200, 500, 1000]:
    std = activation_memory(L, False)
    ckpt = activation_memory(L, True)
    ratio = std / ckpt
    print(f'{L:>6}  {std:>10}  {ckpt:>12}  {ratio:>8.2f}x')

print()
print('Compute overhead: ~1 extra forward pass (recomputing all segments once)')
print('Memory savings: O(L) -> O(sqrt(L))')

Code cell 27

# === 7.2 Visualise Memory vs Depth ===

if HAS_MPL:
    Ls = np.arange(1, 201)
    std_mem = Ls
    ckpt_mem = np.array([activation_memory(L, True) for L in Ls])
    sqrt_bound = 2 * np.sqrt(Ls)  # theoretical ~2*sqrt(L) envelope

    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(Ls, std_mem, label='Standard (O(L))', color='tab:red')
    ax.plot(Ls, ckpt_mem, label='Checkpointed (peak)', color='tab:blue')
    ax.plot(Ls, sqrt_bound, '--', label='$2\\sqrt{L}$ bound', color='tab:orange')
    ax.set_xlabel('Network depth L')
    ax.set_ylabel('Activations stored (peak)')
    ax.set_title('Gradient Checkpointing: Memory vs Depth')
    ax.legend()
    plt.tight_layout()
    plt.show()
    print('Checkpoint memory tracks 2*sqrt(L) closely.')

Code cell 28

# === 7.3 PyTorch Gradient Checkpointing ===

if HAS_TORCH:
    from torch.utils.checkpoint import checkpoint as torch_checkpoint
    import torch.nn as nn  # FIX: nn was used below but never imported anywhere (NameError)

    # Simple sequential model: L_test tanh(Linear) blocks of width d.
    class Block(nn.Module):
        def __init__(self, d):
            super().__init__()
            self.fc = nn.Linear(d, d)
        def forward(self, x):
            return torch.tanh(self.fc(x))

    d, L_test = 32, 10
    blocks = nn.ModuleList([Block(d) for _ in range(L_test)])

    x = torch.randn(16, d, requires_grad=True)

    # Standard forward+backward: autograd keeps every activation alive.
    h = x
    for blk in blocks:
        h = blk(h)
    loss_std = h.sum()
    loss_std.backward()
    g_std = x.grad.clone()

    # Checkpointed forward+backward: activations are recomputed in backward.
    x.grad = None  # reset so the second backward starts from zero
    h = x
    for blk in blocks:
        h = torch_checkpoint(blk, h, use_reentrant=False)
    loss_ckpt = h.sum()
    loss_ckpt.backward()
    g_ckpt = x.grad.clone()

    err = (g_std - g_ckpt).abs().max().item()
    print('=== PyTorch Checkpoint Correctness ===')
    print(f'Max gradient difference: {err:.2e}')
    print(f'PASS (same gradients): {err < 1e-5}')
    print('Checkpoint recomputes activations during backward — same math, less memory.')
else:
    print('PyTorch not available — checkpointing shown conceptually above.')

8. Custom Backward Functions

Sometimes we can implement a more numerically stable or efficient VJP than what AD would compute automatically. torch.autograd.Function lets us override the backward.

Code cell 30

# === 8.1 Numerically Stable Softplus with Custom VJP ===

if HAS_TORCH:
    class StableSoftplus(torch.autograd.Function):
        """log(1 + exp(x)) with numerically stable implementation."""

        @staticmethod
        def forward(ctx, x):
            # Stable: x + log(1 + exp(-x)) for x > 0
            #         log(1 + exp(x))       for x <= 0
            # NOTE(review): torch.where evaluates both branches, so the
            # log1p(exp(x)) branch can still overflow to inf (and warn) for
            # large positive x even though that value is discarded.
            out = torch.where(x > 0,
                              x + torch.log1p(torch.exp(-x.abs())),
                              torch.log1p(torch.exp(x)))
            ctx.save_for_backward(x)  # needed by backward for sigmoid(x)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            # d/dx softplus(x) = sigmoid(x) = 1/(1+exp(-x))
            sigmoid_x = torch.sigmoid(x)
            return grad_output * sigmoid_x

    stable_softplus = StableSoftplus.apply

    # Test at extreme values
    test_vals = torch.tensor([-100., -10., 0., 10., 100.], requires_grad=True)
    out = stable_softplus(test_vals)
    out.sum().backward()

    print('Stable softplus values and gradients:')
    for xv, ov, gv in zip(test_vals.tolist(), out.tolist(), test_vals.grad.tolist()):
        print(f'  x={xv:6.0f}  softplus={ov:.4f}  grad(sigmoid)={gv:.4f}')

    # Naive comparison: log(1+exp(x)) overflows at x=100
    x_naive = torch.tensor([100.0], requires_grad=True)
    naive_out = torch.log(1 + torch.exp(x_naive))
    print(f'\nNaive log(1+exp(100)) = {naive_out.item():.4f}  (inf = overflow)')
    print(f'Stable softplus(100)  = {stable_softplus(torch.tensor([100.0])).item():.4f}')

else:
    # NumPy demo of the stable formula
    def softplus_stable(x):
        return np.where(x > 0, x + np.log1p(np.exp(-np.abs(x))), np.log1p(np.exp(x)))

    xs = np.array([-100, -10, 0, 10, 100], dtype=float)
    print('Stable softplus:', softplus_stable(xs))
    print('Naive overflow check (exp(100)):', np.exp(100.0))

9. Meta-Learning — MAML Inner Loop Gradient

MAML differentiates through one step of gradient descent. This requires create_graph=True to compute second-order derivatives (gradients of gradients).

Code cell 32

# === 9.1 MAML: Differentiating Through a Gradient Step ===

if HAS_TORCH:
    torch.manual_seed(0)

    # Toy meta-learning setup:
    # Inner task: fit a linear model y = w*x + b on support set
    # Meta-objective: adapted model performs well on query set

    w = torch.tensor([1.5], requires_grad=True)  # meta-param
    b = torch.tensor([0.0], requires_grad=True)  # meta-param
    inner_lr = 0.1

    # Support set (task-specific data)
    x_sup = torch.tensor([1.0, 2.0, 3.0])
    y_sup = torch.tensor([2.0, 4.0, 6.0])  # y = 2x

    # Inner loss
    pred_sup = w * x_sup + b
    loss_inner = ((pred_sup - y_sup)**2).mean()

    # Inner gradient update — keep graph for meta-gradient!
    # create_graph=True makes `grads` themselves differentiable nodes.
    grads = torch.autograd.grad(loss_inner, [w, b], create_graph=True)
    w_adapted = w - inner_lr * grads[0]
    b_adapted = b - inner_lr * grads[1]

    print('=== MAML Inner Loop ===')
    print(f'w (meta) = {w.item():.4f},  b (meta) = {b.item():.4f}')
    print(f'After 1 inner step: w_adapted = {w_adapted.item():.4f},  b_adapted = {b_adapted.item():.4f}')

    # Query set (test generalization)
    x_qry = torch.tensor([4.0, 5.0])
    y_qry = torch.tensor([8.0, 10.0])

    pred_qry = w_adapted * x_qry + b_adapted
    loss_outer = ((pred_qry - y_qry)**2).mean()

    print(f'Outer (query) loss = {loss_outer.item():.4f}')

    # Meta-gradient: grad of outer loss w.r.t. w, b — flows back THROUGH
    # the inner update, so it contains second-order (Hessian) terms.
    meta_grads = torch.autograd.grad(loss_outer, [w, b])
    print(f'Meta-grad w.r.t. w = {meta_grads[0].item():.4f}')
    print(f'Meta-grad w.r.t. b = {meta_grads[1].item():.4f}')
    print()
    print('Key: create_graph=True keeps the graph of grad_w_inner,')
    print('allowing outer loss to be differentiated through the inner step.')
    print('This gives second-order derivatives: grad(grad(loss_inner)).')

else:
    print('PyTorch not available.')
    print('MAML requires create_graph=True to differentiate through the inner gradient.')
    print('The meta-gradient includes Hessian terms: grad_meta = (I - alpha*H) * grad_outer.')

10. Implicit Differentiation

When the solution $\mathbf{z}^*(\theta)$ is defined by an implicit equation $F(\mathbf{z}, \theta) = 0$, we differentiate through it without unrolling the solver.

Code cell 34

# === 10.1 IFT: Differentiating Through Ridge Regression ===

# z*(theta) = argmin_z 0.5*||Az - b||^2 + 0.5*theta*||z||^2
# Solution: z*(theta) = (A^T A + theta I)^{-1} A^T b
# IFT derivative: dz*/dtheta = -(A^T A + theta I)^{-1} z*(theta)

np.random.seed(42)
n_ift, m_ift = 5, 10
A_ift = np.random.randn(m_ift, n_ift)
b_ift = np.random.randn(m_ift)
theta = 0.5

def z_star(theta):
    # Closed-form ridge solution via a linear solve (no explicit inverse).
    return la.solve(A_ift.T @ A_ift + theta * np.eye(n_ift), A_ift.T @ b_ift)

# IFT formula: dz*/dtheta = -(A^T A + theta I)^{-1} * z*
# (obtained by implicitly differentiating the stationarity condition)
def dz_dtheta_ift(theta):
    z = z_star(theta)
    return -la.solve(A_ift.T @ A_ift + theta * np.eye(n_ift), z)

# Finite difference verification
h = 1e-5
dz_fd = (z_star(theta + h) - z_star(theta - h)) / (2*h)

dz_ift = dz_dtheta_ift(theta)

print('=== Implicit Differentiation vs Finite Differences ===')
print(f'dz*/dtheta (IFT): {dz_ift.round(4)}')
print(f'dz*/dtheta (FD):  {dz_fd.round(4)}')
print(f'Max error: {abs(dz_ift - dz_fd).max():.2e}')
print(f'PASS: {np.allclose(dz_ift, dz_fd, atol=1e-5)}')
print()
print('IFT costs: 1 forward solve (O(n^3)) + 1 matrix solve (O(n^3))')
print('Unrolling K gradient steps costs: K * O(n^2) per step = K * O(n^2)')
print('For K >> n, IFT is vastly cheaper.')

11. Summary

Key takeaways from this notebook.

Code cell 36

# === 11. Summary and Key Results ===
# Print-only recap of the notebook's main takeaways; no state is created.

print('=' * 60)
print('AUTOMATIC DIFFERENTIATION — KEY RESULTS')
print('=' * 60)
print()
print('1. Three ways to differentiate:')
print('   FD: O(n) evals, ~1e-8 accuracy. AD: O(T(f)), machine precision.')
print()
print('2. Dual numbers (forward mode):')
print('   a+b*eps with eps^2=0 carries (value, derivative) simultaneously.')
print('   JVP: 1 pass per input direction. Cost: O(n*T(f)) for full Jacobian.')
print()
print('3. Adjoint accumulation (reverse mode):')
print('   v_bar = df/dv propagated backward through topological order.')
print('   VJP: 1 pass per output. Cost: O(m*T(f)). For m=1: O(T(f)).')
print()
print('4. Choose forward vs reverse by Jacobian shape:')
print('   n << m -> forward (JVP). m << n -> reverse (VJP).')
print('   ML training: always reverse (n=|theta|>>1=m).')
print()
print('5. HVP without materialising H:')
print('   H*v = grad(grad(f).v). Cost O(T(f)). Enables Newton-CG at scale.')
print()
print('6. Gradient checkpointing:')
print('   Store sqrt(L) activations; recompute rest. Memory O(L)->O(sqrt(L)).')
print()
print('7. Custom VJPs for stability/efficiency:')
print('   Override backward for log-sum-exp, softplus, attention, etc.')
print()
print('8. MAML and higher-order AD:')
print('   create_graph=True keeps the backward graph differentiable.')
print()
print('9. Implicit differentiation (IFT):')
# Section 10 defines z*(theta) via the residual equation F(z, theta) = 0,
# so the IFT derivative uses first derivatives of F, not second derivatives.
print('   dz*/dtheta = -(dF/dz)^{-1} (dF/dtheta).')
print('   Used in DEQ, differentiable opt layers, hypergrad optimisation.')
print()
print('Ready for exercises.ipynb!')