Theory Notebook — Math for LLMs

Chain Rule and Backpropagation

Multivariate Calculus / Chain Rule and Backpropagation

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Chain Rule and Backpropagation — Theory

"Backpropagation is an algorithm for computing gradients efficiently in a computational graph. At its heart, it is nothing more than the chain rule of calculus applied repeatedly." — Goodfellow, Bengio & Courville

Interactive theory notebook: chain rule verification, backprop from scratch, gradient derivations for all standard layers, vanishing/exploding gradient analysis, and memory-efficient backprop.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import numpy.linalg as la
from scipy import optimize, special, stats
from scipy.optimize import minimize, fsolve, linprog
from math import factorial
import math
import matplotlib.patches as patches

COLORS = {
    "primary": "#0077BB",
    "secondary": "#EE7733",
    "tertiary": "#009988",
    "error": "#CC3311",
    "neutral": "#555555",
    "highlight": "#EE3377",
}
HAS_MPL = True
np.set_printoptions(precision=8, suppress=True)
np.random.seed(42)
spmin = minimize

try:
    import torch
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False


def header(title):
    """Print `title` framed above and below by '=' rules of matching length."""
    rule = "=" * len(title)
    print(f"\n{rule}\n{title}\n{rule}")

def check_true(name, cond):
    """Print a PASS/FAIL line for a truthy condition and return it as a bool."""
    result = bool(cond)
    status = "PASS" if result else "FAIL"
    print(f"{status} - {name}")
    return result

def check_close(name, got, expected, tol=1e-8):
    """Print PASS/FAIL for numeric closeness (atol = rtol = tol); return the bool."""
    passed = np.allclose(got, expected, atol=tol, rtol=tol)
    status = 'PASS' if passed else 'FAIL'
    print(f"{status} - {name}: got {got}, expected {expected}")
    return passed

def check(name, got, expected, tol=1e-8):
    """Alias for check_close, kept so cells can use the shorter name."""
    return check_close(name, got, expected, tol=tol)

def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + e^(-x)).

    Implemented as exp(-logaddexp(0, -x)), which is algebraically identical
    to 1/(1+e^(-x)). Unlike the two-branch np.where form, it never evaluates
    exp on an overflowing argument, so large |x| produces no RuntimeWarning
    (np.where computes BOTH branches before selecting).
    """
    x = np.asarray(x, dtype=float)
    return np.exp(-np.logaddexp(0.0, -x))

def softmax(z, axis=-1):
    """Numerically stable softmax along `axis` (shifts by the max first)."""
    arr = np.asarray(z, dtype=float)
    shifted = arr - arr.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)

def relu(x):
    """Rectified linear unit: elementwise max(x, 0)."""
    return np.clip(x, 0, None)

def relu_prime(x):
    """Derivative of ReLU: 1.0 where x > 0, else 0.0 (0.0 at x == 0)."""
    return (np.asarray(x) > 0).astype(float)

def centered_diff(f, x, h=1e-6):
    """Two-sided finite-difference estimate of the derivative of f at x."""
    forward_val, backward_val = f(x + h), f(x - h)
    return (forward_val - backward_val) / (2.0 * h)

def numerical_gradient(f, x, h=1e-6):
    """Elementwise central-difference gradient of a scalar-valued f at x.

    Works for any array shape; each entry is perturbed by ±h in turn.
    """
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x, dtype=float)
    for idx in np.ndindex(*x.shape):
        bumped_up = x.copy()
        bumped_dn = x.copy()
        bumped_up[idx] += h
        bumped_dn[idx] -= h
        grad[idx] = (f(bumped_up) - f(bumped_dn)) / (2.0 * h)
    return grad

def numerical_jacobian(f, x, h=1e-6):
    """Central-difference Jacobian of f at x, shaped f(x).shape + x.shape."""
    x = np.asarray(x, dtype=float)
    y0 = np.asarray(f(x), dtype=float)
    J = np.zeros((y0.size, x.size), dtype=float)
    base = x.reshape(-1)
    for j in range(x.size):
        hi = base.copy(); hi[j] += h
        lo = base.copy(); lo[j] -= h
        diff = np.asarray(f(hi.reshape(x.shape))) - np.asarray(f(lo.reshape(x.shape)))
        J[:, j] = diff.reshape(-1) / (2 * h)
    return J.reshape(y0.shape + x.shape)

def grad_check(f, x, analytic_grad, h=1e-6):
    """Symmetric relative error between an analytic gradient and its FD estimate."""
    fd_estimate = numerical_gradient(f, x, h=h)
    scale = la.norm(analytic_grad) + la.norm(fd_estimate) + 1e-12
    return la.norm(analytic_grad - fd_estimate) / scale



def jacobian_fd(f, x, h=1e-6):
    """Finite-difference Jacobian, shaped f(x).shape + x.shape.

    Same algorithm as numerical_jacobian above; kept as a second name for
    cells that use it.
    """
    x = np.asarray(x, dtype=float)
    out0 = np.asarray(f(x), dtype=float)
    J = np.zeros((out0.size, x.size), dtype=float)
    for j in range(x.size):
        step = np.zeros(x.size)
        step[j] = h
        step = step.reshape(x.shape)
        f_plus = np.asarray(f(x + step), dtype=float).reshape(-1)
        f_minus = np.asarray(f(x - step), dtype=float).reshape(-1)
        J[:, j] = (f_plus - f_minus) / (2 * h)
    return J.reshape(out0.shape + x.shape)

def hessian_fd(f, x, h=1e-5):
    """Finite-difference Hessian of scalar f, shaped x.shape + x.shape.

    Column j is the central difference of the FD gradient along coordinate j.
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    H = np.zeros((n, n), dtype=float)
    base = x.reshape(-1)
    for j in range(n):
        hi = base.copy(); hi[j] += h
        lo = base.copy(); lo[j] -= h
        g_hi = numerical_gradient(lambda z: f(z.reshape(x.shape)), hi.reshape(x.shape), h=h).reshape(-1)
        g_lo = numerical_gradient(lambda z: f(z.reshape(x.shape)), lo.reshape(x.shape), h=h).reshape(-1)
        H[:, j] = (g_hi - g_lo) / (2.0 * h)
    return H.reshape(x.shape + x.shape)



def grad_fd(f, x, h=1e-6):
    """Alias for numerical_gradient (a later cell temporarily shadows this name)."""
    return numerical_gradient(f, x, h=h)



def fd_grad(f, x, h=1e-6):
    """Alias for numerical_gradient that first coerces x to a float array."""
    return numerical_gradient(f, np.asarray(x, dtype=float), h=h)

print("Chapter helper setup complete.")

1. Chain Rule Verification

We verify the multivariate chain rule $J_{f\circ g}(\mathbf{x}) = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x})$ numerically for concrete functions.

Code cell 5

# === 1.1 Scalar Chain Rule ===
# f(t) = sin(t^2), g(t) = exp(3t)
# (f∘g)'(t) = cos(g(t)^2) * 2*g(t) * 3*exp(3t)

f_scalar = lambda t: np.sin(t**2)
g_scalar = lambda t: np.exp(3*t)
h_scalar = lambda t: f_scalar(g_scalar(t))  # composition

t0 = 0.5
# Analytical derivative of composition
gt0 = g_scalar(t0)
chain_rule_val = np.cos(gt0**2) * 2*gt0 * 3*np.exp(3*t0)

# Finite difference
h = 1e-5
fd_val = (h_scalar(t0 + h) - h_scalar(t0 - h)) / (2*h)

print('Scalar Chain Rule: (f∘g)(t) = sin(exp(6t))')
print(f'  Analytical: {chain_rule_val:.8f}')
print(f'  FD approx:  {fd_val:.8f}')
check_close('scalar chain rule', chain_rule_val, fd_val)

Code cell 6

# === 1.2 Multivariate Chain Rule — Jacobian Composition ===
# g: R^2 -> R^3,  f: R^3 -> R^2

def g_mv(x):
    return np.array([x[0]**2, x[0]*x[1], np.exp(x[1])])

def f_mv(u):
    return np.array([u[0]*u[1], u[1] + u[2]**2])

def h_mv(x):
    return f_mv(g_mv(x))

def Jg(x):
    """Analytical Jacobian of g."""
    return np.array([
        [2*x[0],    0.0      ],
        [x[1],      x[0]     ],
        [0.0,       np.exp(x[1])]
    ])

def Jf(u):
    """Analytical Jacobian of f."""
    return np.array([
        [u[1],      u[0],   0.0    ],
        [0.0,       1.0,    2*u[2] ]
    ])

x0 = np.array([1.0, 0.0])
gx0 = g_mv(x0)

# Chain rule: J_h = J_f(g(x)) @ J_g(x)
Jh_chain = Jf(gx0) @ Jg(x0)
Jh_fd    = jacobian_fd(h_mv, x0)

print('Jacobian via chain rule J_f(g(x)) @ J_g(x):')
print(Jh_chain.round(6))
print('Jacobian via finite differences:')
print(Jh_fd.round(6))
check_close('Jacobian chain rule = FD', Jh_chain, Jh_fd)

Code cell 7

# === 1.3 VJP and JVP Duality ===
x0 = np.array([1.0, 0.0])
gx0 = g_mv(x0)
Jh = Jf(gx0) @ Jg(x0)  # 2x2 Jacobian

v = np.array([1.0, -0.5])   # tangent vector (input space)
u = np.array([2.0, 1.0])    # cotangent vector (output space)

# JVP: J @ v  (forward mode — one input direction)
jvp = Jh @ v

# VJP: J.T @ u  (reverse mode — one output direction)
vjp = Jh.T @ u

# Duality: u^T (Jv) = (J^T u)^T v
lhs = u @ jvp
rhs = vjp @ v

print(f'JVP (J @ v):  {jvp}')
print(f'VJP (J^T @ u): {vjp}')
print(f'Duality LHS u^T(Jv): {lhs:.8f}')
print(f'Duality RHS (J^Tu)^Tv: {rhs:.8f}')
check_close('Duality u^T(Jv) = (J^Tu)^Tv', lhs, rhs)

print('\nTakeaway: JVP=forward mode, VJP=backprop. Both compute the same info differently.')
print('For scalar loss (m=1), VJP costs 1 pass vs n passes for JVP.')

2. Computation Graphs

We implement a minimal computation graph with automatic differentiation — the micro version of PyTorch's autograd.

Code cell 9

# === 2.1 Minimal Autograd Engine ===

class Value:
    """Scalar value with reverse-mode automatic gradient tracking.

    Each arithmetic operation builds a graph node that records its parents
    and a closure (`_backward`) implementing that op's local chain rule.
    Calling `backward()` on the final node topologically sorts the graph and
    accumulates gradients into every ancestor's `.grad`.

    Fix vs. the original: `__rsub__` and `__rtruediv__` are added so that
    `3 - v` and `3 / v` work; previously only `__radd__`/`__rmul__` existed
    and those expressions raised TypeError.
    """
    def __init__(self, data, _children=(), _op=''):
        self.data = float(data)         # forward value
        self.grad = 0.0                 # dL/d(this node), filled by backward()
        self._backward = lambda: None   # per-op gradient rule, set by each op
        self._prev = set(_children)     # parent nodes in the graph
        self._op = _op                  # op label (debugging only)

    def __repr__(self):
        return f'Value(data={self.data:.4f}, grad={self.grad:.4f})'

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
        def _backward():
            # d(a+b)/da = d(a+b)/db = 1 — pass the upstream gradient through
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
        def _backward():
            # d(ab)/da = b,  d(ab)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def __pow__(self, exponent):
        # exponent must be a plain number (not a Value)
        out = Value(self.data**exponent, (self,), f'**{exponent}')
        def _backward():
            self.grad += exponent * self.data**(exponent-1) * out.grad
        out._backward = _backward
        return out

    def exp(self):
        out = Value(np.exp(self.data), (self,), 'exp')
        def _backward():
            # d(e^x)/dx = e^x, which is exactly out.data
            self.grad += out.data * out.grad
        out._backward = _backward
        return out

    def relu(self):
        out = Value(max(0, self.data), (self,), 'relu')
        def _backward():
            self.grad += (out.data > 0) * out.grad
        out._backward = _backward
        return out

    def log(self):
        # small epsilon guards log(0); the gradient uses the same shifted argument
        out = Value(np.log(self.data + 1e-10), (self,), 'log')
        def _backward():
            self.grad += (1.0 / (self.data + 1e-10)) * out.grad
        out._backward = _backward
        return out

    def backward(self):
        """Run reverse-mode autodiff from this node (seeds its grad with 1)."""
        # Build topological order of the subgraph ending at self
        topo = []
        visited = set()
        def build(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build(child)
                topo.append(v)
        build(self)
        self.grad = 1.0
        for v in reversed(topo):
            v._backward()

    def __radd__(self, other): return self + other
    def __rmul__(self, other): return self * other
    def __neg__(self): return self * -1
    def __sub__(self, other): return self + (-other)
    def __rsub__(self, other): return Value(other) + (-self)        # other - self
    def __truediv__(self, other): return self * other**(-1)
    def __rtruediv__(self, other): return Value(other) * self**(-1)  # other / self

print('Minimal autograd engine defined.')

Code cell 10

# === 2.2 Verify Autograd on Simple Expression ===
# f(x, y) = exp(x) * (x + y^2)
# ∂f/∂x = exp(x) * (x + y^2) + exp(x) * 1 = exp(x)*(x + y^2 + 1)
# ∂f/∂y = exp(x) * 2y

x = Value(1.0)
y = Value(2.0)
f = x.exp() * (x + y**2)
f.backward()

# Analytical
xv, yv = 1.0, 2.0
df_dx = np.exp(xv) * (xv + yv**2 + 1)
df_dy = np.exp(xv) * 2*yv

print(f'f(1, 2) = exp(1)*(1+4) = {f.data:.6f}  (expected {np.exp(1)*5:.6f})')
print(f'df/dx autograd: {x.grad:.6f}  analytical: {df_dx:.6f}')
print(f'df/dy autograd: {y.grad:.6f}  analytical: {df_dy:.6f}')
check_close('df/dx', x.grad, df_dx)
check_close('df/dy', y.grad, df_dy)

Code cell 11

# === 2.3 Fan-Out Node: Gradient Accumulation ===
# z = x*x + 3*x  →  dz/dx = 2x + 3
# x appears twice: once in x*x, once in 3*x

x = Value(4.0)
z = x*x + Value(3)*x   # x used twice (fan-out)
z.backward()

# Analytical: dz/dx = 2x + 3 = 11
print(f'z = x² + 3x at x=4: {z.data}  (expected {4**2 + 3*4})')
print(f'dz/dx autograd: {x.grad:.4f}  analytical: {2*4+3}')
check_close('fan-out accumulation', x.grad, 2*4+3)
print('\nGradients from both paths summed correctly.')

3. Backpropagation from Scratch

We implement the full backprop algorithm for a feedforward MLP using numpy, then verify against finite differences.

Code cell 13

# === 3.1 Activation Functions and Derivatives ===

def relu(z):
    """ReLU activation: zero out negative entries."""
    return np.maximum(z, 0)
def relu_prime(z):
    """ReLU derivative: 1 where z > 0, else 0 (defined as 0 at z == 0)."""
    return np.greater(z, 0).astype(float)

def sigmoid(z):
    """Numerically stable logistic function 1 / (1 + e^(-z)).

    exp(-logaddexp(0, -z)) is algebraically identical to the two-branch
    np.where form used before, but never feeds exp an overflowing argument:
    np.where evaluates BOTH branches, so large |z| previously triggered
    overflow RuntimeWarnings.
    """
    z = np.asarray(z, dtype=float)
    return np.exp(-np.logaddexp(0.0, -z))
def sigmoid_prime(z):
    """Derivative of the logistic function: σ(z) · (1 − σ(z))."""
    sig = sigmoid(z)
    return sig * (1 - sig)

def softmax(z):
    """Numerically stable softmax over the last axis (shifts by the row max)."""
    shifted = z - np.max(z, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

# Verify sigmoid derivative numerically
z_test = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
sig_prime_analytical = sigmoid_prime(z_test)
sig_prime_fd = np.array([
    (sigmoid(np.array([zi+1e-5])) - sigmoid(np.array([zi-1e-5])))[0] / (2e-5)
    for zi in z_test
])
check_close('sigmoid derivative', sig_prime_analytical, sig_prime_fd)

Code cell 14

# === 3.2 MLP Forward Pass ===

class MLP:
    """Minimal fully-connected network: ReLU hidden layers, linear output.

    Weights use He initialisation (std = sqrt(2/n_in)); biases start at zero.
    self.L is the number of weight layers.
    """
    def __init__(self, layer_sizes, seed=42):
        np.random.seed(seed)
        self.L = len(layer_sizes) - 1
        self.W = []
        self.b = []
        for i in range(self.L):
            n_in, n_out = layer_sizes[i], layer_sizes[i+1]
            # He initialisation for ReLU
            self.W.append(np.random.randn(n_out, n_in) * np.sqrt(2.0/n_in))
            self.b.append(np.zeros(n_out))

    def forward(self, x):
        """Forward pass. Returns output and cache.

        The cache stores per-layer pre-activations 'z' and activations 'a'
        (with the input as a[0]) — exactly what backward_pass consumes.
        """
        cache = {'a': [x]}
        cache['z'] = []
        a = x
        for l in range(self.L):
            z = self.W[l] @ a + self.b[l]
            cache['z'].append(z)
            if l < self.L - 1:
                a = relu(z)
            else:
                a = z   # linear output layer
            cache['a'].append(a)
        return cache['a'][-1], cache

    def mse_loss(self, y_pred, y_true):
        """Mean squared error with the conventional 1/2 factor."""
        return 0.5 * np.mean((y_pred - y_true)**2)

net = MLP([2, 4, 3, 1])
x = np.array([1.0, 2.0])
y_pred, cache = net.forward(x)
print(f'Forward pass output: {y_pred}')
print(f'Cache keys: {list(cache.keys())}')
print(f'Layer activations shapes: {[a.shape for a in cache["a"]]}')

Code cell 15

# === 3.3 MLP Backward Pass ===

def backward_pass(net, cache, y_pred, y_true):
    """Manual backpropagation. Returns gradient dicts.

    Walks layers output→input; `delta` holds dL/dz_l (gradient wrt the
    current layer's pre-activation), so dW_l = outer(delta, a_{l-1}) and
    db_l = delta.
    """
    dW = [None] * net.L
    db = [None] * net.L

    # Output layer delta: MSE loss gradient
    # (the output layer is linear, so dL/dz_L = (y_pred - y_true)/n directly)
    delta = (y_pred - y_true) / len(np.atleast_1d(y_true))  # (n_out,)

    for l in range(net.L - 1, -1, -1):
        a_prev = cache['a'][l]
        dW[l] = np.outer(delta, a_prev)
        db[l] = delta.copy()

        if l > 0:
            # Propagate through W and activation:
            # dL/dz_{l-1} = (W_l^T delta) ⊙ relu'(z_{l-1})
            delta = net.W[l].T @ delta
            delta = delta * relu_prime(cache['z'][l-1])

    return dW, db

# Test backward pass
x = np.array([1.0, 2.0])
y_true = np.array([1.0])
y_pred, cache = net.forward(x)
dW, db = backward_pass(net, cache, y_pred, y_true)

print(f'Loss: {net.mse_loss(y_pred, y_true):.6f}')
print(f'Gradient shapes: {[g.shape for g in dW]}')
print(f'dW[0] norm: {la.norm(dW[0]):.6f}')

Code cell 16

# === 3.4 Gradient Verification via Finite Differences ===

def compute_loss(net, x, y_true):
    """Scalar MSE loss of net's prediction at x (used by the FD gradient check)."""
    y_pred, _ = net.forward(x)
    return net.mse_loss(y_pred, y_true)

def grad_check(net, x, y_true, layer_idx=0, h=1e-5):
    """Check gradients for W[layer_idx] via FD."""
    # NOTE(review): shadows the earlier grad_check(f, x, analytic_grad) helper.
    # Perturbs net.W IN PLACE one entry at a time; the +h / -2h / +h sequence
    # restores each entry exactly after its two loss evaluations.
    W = net.W[layer_idx]
    grad_fd = np.zeros_like(W)
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            W[i, j] += h
            loss_plus = compute_loss(net, x, y_true)
            W[i, j] -= 2*h
            loss_minus = compute_loss(net, x, y_true)
            W[i, j] += h  # restore
            grad_fd[i, j] = (loss_plus - loss_minus) / (2*h)
    return grad_fd

y_pred, cache = net.forward(x)
dW, db = backward_pass(net, cache, y_pred, y_true)

# Compare analytic backprop gradients against finite differences, layer by layer.
for l in range(net.L):
    grad_fd = grad_check(net, x, y_true, layer_idx=l)
    rel_err = la.norm(dW[l] - grad_fd) / (la.norm(dW[l]) + la.norm(grad_fd) + 1e-10)
    ok = rel_err < 1e-5
    # Hoisted out of the f-string: nesting single quotes inside a single-quoted
    # f-string is a SyntaxError before Python 3.12 (and inconsistent with the
    # quote style used by check_true/check_close above).
    status = 'PASS' if ok else 'FAIL'
    print(f'{status} — Layer {l} gradient (rel err: {rel_err:.2e})')

# Restore shared helper name after the local finite-difference arrays above.
grad_fd = numerical_gradient
print("Gradient helper restored.")

4. Gradient Derivations for Standard Layers

We derive and verify gradients for every layer type appearing in modern transformers.

Code cell 18

# === 4.1 Linear Layer Gradients ===
# Forward: z = W x + b
# Backward: given z_bar
#   x_bar = W^T z_bar
#   W_bar = outer(z_bar, x)
#   b_bar = z_bar

np.random.seed(42)
m, n = 4, 3
W = np.random.randn(m, n)
b = np.random.randn(m)
x = np.random.randn(n)
z_bar = np.random.randn(m)  # upstream gradient

# Analytical gradients
x_bar = W.T @ z_bar
W_bar = np.outer(z_bar, x)
b_bar = z_bar

# Verify via finite differences on a scalar function
def loss_fn(x_):
    z = W @ x_ + b
    return z_bar @ z  # scalar projection for FD check

x_bar_fd = grad_fd(loss_fn, x)

print('Linear Layer Gradients:')
print(f'  x_bar analytical: {x_bar}')
print(f'  x_bar FD:         {x_bar_fd}')
check_close('x_bar = W^T z_bar', x_bar, x_bar_fd)

# W_bar via FD
def loss_fn_W(W_flat):
    W_ = W_flat.reshape(m, n)
    z = W_ @ x + b
    return z_bar @ z

W_bar_fd = grad_fd(loss_fn_W, W.flatten()).reshape(m, n)
check_close('W_bar = outer(z_bar, x)', W_bar, W_bar_fd)

Code cell 19

# === 4.2 Softmax + Cross-Entropy Fused Gradient ===
# L = -log(p_y),  p = softmax(z)
# Claim: dL/dz = p - e_y

K = 5
z = np.array([2.0, 1.0, 0.5, -1.0, 0.3])
y = 1  # true class index

p = softmax(z)
e_y = np.zeros(K); e_y[y] = 1.0

# Analytical gradient
grad_analytical = p - e_y

# FD gradient
def ce_loss(z_):
    p_ = softmax(z_)
    return -np.log(p_[y] + 1e-10)

grad_fd_ce = grad_fd(ce_loss, z)

print('Softmax + Cross-Entropy Gradient:')
print(f'  p = {p.round(4)}')
print(f'  Analytical (p - e_y): {grad_analytical.round(4)}')
print(f'  Finite diff:           {grad_fd_ce.round(4)}')
check_close('CE gradient = p - e_y', grad_analytical, grad_fd_ce)

print('\nTakeaway: Fusing softmax+CE avoids storing softmax Jacobian.')
print('  Cost: O(K) vs O(K^2) for separate computation.')

Code cell 20

# === 4.3 LayerNorm Gradient ===

d = 6
x = np.random.randn(d)
gamma = np.ones(d)
beta = np.zeros(d)
eps = 1e-5

def layernorm_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm for a single vector: y = gamma * (x - mean)/sqrt(var + eps) + beta.

    Returns (y, x_hat, mu, var); the intermediates feed the backward pass.
    Uses the biased variance (ddof = 0), matching the backward formulas.
    """
    center = x.mean()
    spread = x.var()
    x_hat = (x - center) / np.sqrt(spread + eps)
    y = gamma * x_hat + beta
    return y, x_hat, center, spread

def layernorm_backward(dy, x, x_hat, var, gamma, eps=1e-5):
    """Backward pass of LayerNorm for a single vector.

    Arguments mirror layernorm_forward's outputs: dy is the upstream gradient,
    x_hat the normalised input, var the biased variance. Returns
    (dx, dgamma, dbeta).
    """
    d = len(x)
    # Gradient through scale/shift
    dgamma = dy * x_hat
    dbeta = dy
    # Gradient through normalisation
    dx_hat = dy * gamma
    # mean and variance couple all elements, so dvar and dmu each sum over
    # the whole vector before being redistributed into dx
    dvar = np.sum(dx_hat * (x - x.mean()) * (-0.5) * (var + eps)**(-1.5))
    dmu = np.sum(dx_hat * (-1/np.sqrt(var + eps))) + dvar * (-2/d) * np.sum(x - x.mean())
    dx = dx_hat / np.sqrt(var + eps) + dvar * 2*(x - x.mean())/d + dmu/d
    return dx, dgamma, dbeta

y, x_hat, mu, var = layernorm_forward(x, gamma, beta)
dy = np.random.randn(d)  # upstream gradient
dx, dgamma, dbeta = layernorm_backward(dy, x, x_hat, var, gamma)

# Verify dx via FD
def ln_proj(x_):
    y_, _, _, _ = layernorm_forward(x_, gamma, beta)
    return dy @ y_  # scalar projection

dx_fd = grad_fd(ln_proj, x)
check_close('LayerNorm dx', dx, dx_fd)

print('LayerNorm gradient verified.')
print(f'dx norm: {la.norm(dx):.6f},  dx_fd norm: {la.norm(dx_fd):.6f}')

Code cell 21

# === 4.4 Attention Gradient ===

T, d_k, d_v = 4, 3, 3
np.random.seed(1)
Q = np.random.randn(T, d_k)
K = np.random.randn(T, d_k)
V = np.random.randn(T, d_v)
scale = 1.0 / np.sqrt(d_k)

def attention_forward(Q, K, V, scale):
    """Single-head scaled dot-product attention.

    Returns (output O, attention weights P, raw scores S).
    """
    S = (Q @ K.T) * scale            # (T, T) raw scores
    P = softmax(S)                   # (T, T) row-wise attention weights
    return P @ V, P, S

def softmax_backward(P, dP):
    """VJP of row-wise softmax: dS = P ⊙ (dP − rowsum(P ⊙ dP))."""
    row_dot = (P * dP).sum(axis=1, keepdims=True)
    return P * (dP - row_dot)

def attention_backward(dO, P, S, Q, K, V, scale):
    """Gradients of attention wrt Q, K, V given upstream dO.

    dV = P^T dO; dP = dO V^T flows back through the row-wise softmax to dS,
    then dQ = dS K * scale and dK = dS^T Q * scale.
    """
    dV = P.T @ dO
    dS = softmax_backward(P, dO @ V.T)
    dQ = dS @ K * scale
    dK = dS.T @ Q * scale
    return dQ, dK, dV

O, P, S = attention_forward(Q, K, V, scale)
dO = np.random.randn(T, d_v)
dQ, dK, dV = attention_backward(dO, P, S, Q, K, V, scale)

# Verify via FD
def attn_proj(Q_flat):
    Q_ = Q_flat.reshape(T, d_k)
    O_, _, _ = attention_forward(Q_, K, V, scale)
    return (dO * O_).sum()

dQ_fd = grad_fd(attn_proj, Q.flatten()).reshape(T, d_k)
check_close('Attention dQ', dQ, dQ_fd)

def attn_proj_V(V_flat):
    V_ = V_flat.reshape(T, d_v)
    O_, _, _ = attention_forward(Q, K, V_, scale)
    return (dO * O_).sum()

dV_fd = grad_fd(attn_proj_V, V.flatten()).reshape(T, d_v)
check_close('Attention dV', dV, dV_fd)
print('Attention backward verified.')

5. Vanishing and Exploding Gradients

Empirical analysis of gradient magnitudes across depth, comparing activations and initialisation strategies.

Code cell 23

# === 5.1 Gradient Norm vs Depth ===

np.random.seed(42)
L = 20  # number of layers
n = 50  # layer width
x = np.random.randn(n)
x /= la.norm(x)  # unit input

def run_depth_experiment(weight_scale, activation='sigmoid', n_layers=L, width=n):
    """Forward + backward through a deep random net; return ||delta||_2 per layer.

    Uses the module-level unit input `x`. The backward recursion is
    delta_l = W_l^T (delta_{l+1} ⊙ act'(z_l)): the activation derivative of
    layer l is applied BEFORE pulling delta back through W_l.
    """
    # Build weights
    Ws = [np.random.randn(width, width) * weight_scale for _ in range(n_layers)]

    # Forward, caching pre-activations z_l for the backward sweep
    a = x.copy()
    acts = [a]
    preacts = []
    for W in Ws:
        z = W @ a
        preacts.append(z)
        if activation == 'sigmoid':
            a = sigmoid(z)
        elif activation == 'relu':
            a = relu(z)
        elif activation == 'tanh':
            a = np.tanh(z)
        else:
            a = z  # linear
        acts.append(a)

    # Backward: compute ||delta^l||_2 for each l
    delta = np.random.randn(width)  # loss gradient at output
    delta /= la.norm(delta)
    grad_norms = [la.norm(delta)]

    for l in range(n_layers - 1, 0, -1):
        if activation == 'sigmoid':
            act_prime = sigmoid_prime(preacts[l])
        elif activation == 'relu':
            act_prime = relu_prime(preacts[l])
        elif activation == 'tanh':
            act_prime = 1 - np.tanh(preacts[l])**2
        else:
            act_prime = np.ones_like(preacts[l])
        # FIX: the elementwise multiply by act'(z_l) must happen before the
        # transpose-matmul. The original computed (W_l^T delta) ⊙ act'(z_l),
        # which pairs the activation derivative with the wrong layer in the
        # chain rule (compare the FD-verified backward_pass in section 3).
        delta = Ws[l].T @ (delta * act_prime)
        grad_norms.append(la.norm(delta))

    return list(reversed(grad_norms))

# Run experiments
scale_sigmoid = 0.1 / np.sqrt(n)   # small weights -> vanishing
scale_relu = np.sqrt(2.0 / n)       # He init
scale_xavier = np.sqrt(1.0 / n)     # Xavier

np.random.seed(0)
norms_sigmoid = run_depth_experiment(scale_sigmoid, 'sigmoid')
norms_relu    = run_depth_experiment(scale_relu, 'relu')
norms_sigmoid_xe = run_depth_experiment(scale_xavier, 'sigmoid')
norms_tanh    = run_depth_experiment(scale_xavier, 'tanh')

print('Gradient norm at layer 1 vs layer 20:')
print(f'  Sigmoid (small init): {norms_sigmoid[0]:.2e} vs {norms_sigmoid[-1]:.2e}')
print(f'  ReLU (He init):       {norms_relu[0]:.2e} vs {norms_relu[-1]:.2e}')
print(f'  Sigmoid (Xavier):     {norms_sigmoid_xe[0]:.2e} vs {norms_sigmoid_xe[-1]:.2e}')
print(f'  Tanh (Xavier):        {norms_tanh[0]:.2e} vs {norms_tanh[-1]:.2e}')

Code cell 24

# === 5.2 Vanishing Gradient Plot ===
if HAS_MPL:
    layers = list(range(1, L+1))
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.semilogy(layers, norms_sigmoid, 'r-o', markersize=4, label='Sigmoid (small init)')
    ax.semilogy(layers, norms_sigmoid_xe, 'm-s', markersize=4, label='Sigmoid (Xavier)')
    ax.semilogy(layers, norms_tanh, 'b-^', markersize=4, label='Tanh (Xavier)')
    ax.semilogy(layers, norms_relu, 'g-D', markersize=4, label='ReLU (He init)')
    ax.axhline(1.0, color='k', linestyle='--', alpha=0.4, label='norm=1 (no change)')
    ax.set_xlabel('Layer')
    ax.set_ylabel('Gradient norm ||δ^l||₂')
    ax.set_title('Gradient Norm vs Depth (20 layers)')
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.4)

    # Show log ratio (how many orders of magnitude change)
    ax = axes[1]
    for norms, label, color in [
        (norms_sigmoid, 'Sigmoid (small)', 'red'),
        (norms_sigmoid_xe, 'Sigmoid (Xavier)', 'purple'),
        (norms_tanh, 'Tanh', 'blue'),
        (norms_relu, 'ReLU (He)', 'green'),
    ]:
        ratio = np.log10(np.array(norms) / norms[-1] + 1e-30)
        ax.plot(layers, ratio, label=label, linewidth=2)
    ax.set_xlabel('Layer')
    ax.set_ylabel('log₁₀(||δ^l|| / ||δ^L||)')
    ax.set_title('Relative Gradient Decay')
    ax.legend(fontsize=9)
    ax.axhline(0, color='k', linestyle='--', alpha=0.4)
    ax.grid(True, alpha=0.4)

    plt.tight_layout()
    plt.show()
    print('Key insight: ReLU with He init maintains near-constant gradient norm.')
    print('Sigmoid with small weights loses ~10 orders of magnitude by layer 1.')

Code cell 25

# === 5.3 Residual Connection Gradient Highway ===

def run_residual_experiment(weight_scale, activation='sigmoid', n_layers=L, width=n):
    """Same as run_depth_experiment but with residual connections.

    Forward: a_{l+1} = a_l + act(W_l a_l). Backward therefore has two paths:
    delta_l = delta_{l+1} + W_l^T (delta_{l+1} ⊙ act'(z_l)).
    """
    Ws = [np.random.randn(width, width) * weight_scale for _ in range(n_layers)]

    a = x.copy()
    preacts = []
    residuals = [a]  # residual stream values
    for W in Ws:
        z = W @ a
        preacts.append(z)
        if activation == 'sigmoid':
            a = a + sigmoid(z)  # residual addition
        else:
            a = a + relu(z)
        residuals.append(a)

    # Backward with residuals
    delta = np.random.randn(width)
    delta /= la.norm(delta)
    grad_norms_res = [la.norm(delta)]

    for l in range(n_layers - 1, 0, -1):
        if activation == 'sigmoid':
            act_prime = sigmoid_prime(preacts[l])
        else:
            act_prime = relu_prime(preacts[l])
        # FIX: delta must pass through the activation derivative BEFORE the
        # transpose-matmul. `Ws[l].T @ delta * act_prime` parses as
        # (W^T delta) ⊙ act', which is not the sublayer's VJP.
        delta_through_sublayer = Ws[l].T @ (delta * act_prime)
        delta = delta + delta_through_sublayer  # identity skip path
        grad_norms_res.append(la.norm(delta))

    return list(reversed(grad_norms_res))

np.random.seed(0)
norms_res_sigmoid = run_residual_experiment(scale_sigmoid, 'sigmoid')
norms_res_relu    = run_residual_experiment(scale_relu, 'relu')

print('Gradient norm at layer 1 WITH residual connections:')
print(f'  Residual sigmoid: {norms_res_sigmoid[0]:.4f} (vs {norms_sigmoid[0]:.2e} without)')
print(f'  Residual ReLU:    {norms_res_relu[0]:.4f}')
print('\nResidual connections prevent vanishing: identity skip ensures gradient flow.')

6. Xavier and He Initialisation — Derivation

We verify the variance propagation argument behind Xavier and He initialisations.

Code cell 27

# === 6.1 Variance Propagation Through Linear Layers ===

np.random.seed(42)
n_trials = 10000
L_depth = 10
widths = [100] * (L_depth + 1)

def measure_activation_variance(weight_init_fn, activation='relu', n_trials=1000):
    """Measure variance of activations at each layer across trials."""
    # NOTE(review): relies on the module-level `widths` and `L_depth`; also,
    # this local n_trials default (1000) shadows the module-level
    # n_trials = 10000 set in the same cell — confirm which was intended.
    n = widths[0]
    variances = []
    x_samples = np.random.randn(n_trials, n)

    a = x_samples  # (n_trials, n)
    variances.append(np.var(a))

    for l in range(L_depth):
        n_in = widths[l]
        n_out = widths[l+1]
        W = weight_init_fn(n_in, n_out)  # init functions return shape (n_out, n_in)
        z = a @ W.T  # (n_trials, n_out)
        if activation == 'relu':
            a = np.maximum(0, z)
        elif activation == 'tanh':
            a = np.tanh(z)
        else:
            a = z
        variances.append(np.var(a))

    return variances

# Xavier: sigma^2 = 2/(n_in + n_out)  (for linear/tanh)
xavier_init = lambda n_in, n_out: np.random.randn(n_out, n_in) * np.sqrt(2/(n_in+n_out))
# He: sigma^2 = 2/n_in  (for ReLU)
he_init = lambda n_in, n_out: np.random.randn(n_out, n_in) * np.sqrt(2/n_in)
# Too small
small_init = lambda n_in, n_out: np.random.randn(n_out, n_in) * 0.01
# Too large
large_init = lambda n_in, n_out: np.random.randn(n_out, n_in) * 1.0

vars_he_relu   = measure_activation_variance(he_init, 'relu')
vars_small     = measure_activation_variance(small_init, 'relu')
vars_large     = measure_activation_variance(large_init, 'relu')
vars_xavier_ln = measure_activation_variance(xavier_init, 'linear')

print('Activation variance at final layer (layer 10):')
print(f'  He init (ReLU):     {vars_he_relu[-1]:.4f}  (want ~{vars_he_relu[0]:.4f})')
print(f'  Small init (ReLU):  {vars_small[-1]:.2e}   (vanishing)')
print(f'  Large init (ReLU):  {vars_large[-1]:.2e}   (exploding)')
print(f'  Xavier (linear):    {vars_xavier_ln[-1]:.4f}  (want ~{vars_xavier_ln[0]:.4f})')

Code cell 28

# === 6.2 Initialisation Variance Plot ===
if HAS_MPL:
    layers = list(range(L_depth + 1))
    fig, ax = plt.subplots(figsize=(10, 5))

    ax.semilogy(layers, vars_he_relu,   'g-o', linewidth=2, label='He init + ReLU (stable)')
    ax.semilogy(layers, vars_small,     'r-s', linewidth=2, label='Small init + ReLU (vanishing)')
    ax.semilogy(layers, vars_large,     'b-^', linewidth=2, label='Large init + ReLU (exploding)')
    ax.semilogy(layers, vars_xavier_ln, 'k-D', linewidth=2, label='Xavier + Linear (stable)')

    ax.set_xlabel('Layer')
    ax.set_ylabel('Activation Variance (log scale)')
    ax.set_title('Variance Propagation Through 10 Layers')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.4)
    plt.tight_layout()
    plt.show()

    print('He initialisation keeps variance near 1.0 at all layers.')
    print('This ensures gradient norms are also well-scaled.')

7. Memory-Efficient Backpropagation

Implementing gradient checkpointing and measuring the memory-compute tradeoff.

Code cell 30

# === 7.1 Memory Measurement Utilities ===
import sys

def memory_of(arrays):
    """Total memory footprint of an iterable of numpy arrays, in MiB."""
    total_bytes = sum(arr.nbytes for arr in arrays)
    return total_bytes / (1024 ** 2)

# Simulate a deep network activation cache
L_big = 32
B, T, d = 4, 512, 256  # batch, seq_len, hidden (small for demo)

# Standard caching: all activations
all_acts = [np.zeros((B, T, d), dtype=np.float32) for _ in range(L_big)]
mem_full = memory_of(all_acts)

# Checkpointing: every sqrt(L) layers
k = int(np.sqrt(L_big))
checkpoints = [np.zeros((B, T, d), dtype=np.float32) for _ in range(k)]
mem_checkpoint = memory_of(checkpoints)

print(f'Network: {L_big} layers, B={B}, T={T}, d={d}')
print(f'Full cache memory:     {mem_full:.2f} MB')
print(f'Checkpoint memory:     {mem_checkpoint:.2f} MB  (k={k} checkpoints)')
print(f'Memory reduction:      {mem_full/mem_checkpoint:.1f}x')
print(f'Compute overhead:      ~{k/(k-1):.2f}x  (one extra forward per segment)')

Code cell 31

# === 7.2 Gradient Checkpointing Implementation ===

def forward_segment(x, Ws, start, end, activation='relu'):
    """Run layers [start, end) from input x; return (output, cache of z/a lists)."""
    act = x.copy()
    cache = {'z': [], 'a': [act]}
    for layer in range(start, end):
        pre = Ws[layer] @ act
        act = relu(pre) if activation == 'relu' else sigmoid(pre)
        cache['z'].append(pre)
        cache['a'].append(act)
    return act, cache

def backward_segment(delta, cache, Ws, start, end, activation='relu'):
    """Backward pass through layers [start, end).

    Parameters
    ----------
    delta : gradient of the loss w.r.t. the pre-activation of layer end-1.
    cache : dict with 'z' (pre-activations) and 'a' (activations, starting
            with the segment input at index 0), as built by forward_segment.
    Ws    : list of weight matrices for the whole network (global indices).
    start, end : global layer indices covered by this segment.
    activation : 'relu' or 'sigmoid'; must match the forward pass.
                 (Previously the derivative was hard-coded to ReLU even when
                 the forward segment was run with sigmoid.)

    Returns
    -------
    (delta_at_segment_input, dW_list) with dW_list ordered start..end-1.
    """
    dW_seg = []
    for l in range(end - 1, start - 1, -1):
        idx = l - start
        # dW_l = delta_l (outer) a_{l-1}; cache['a'][idx] is layer l's input.
        dW_seg.append(np.outer(delta, cache['a'][idx]))
        # Propagate through the weight matrix first ...
        delta = Ws[l].T @ delta
        if l > start:
            # ... then through the previous layer's nonlinearity. At the
            # segment boundary (l == start) that derivative belongs to the
            # previous segment's cache, so the caller applies it.
            if activation == 'relu':
                # (z > 0) is the ReLU subgradient (0 at z == 0).
                delta = delta * (cache['z'][idx - 1] > 0)
            else:
                # sigmoid'(z_{l-1}) = a(1-a), where a = sigmoid(z_{l-1})
                # is cached as the input activation of layer l.
                a_mid = cache['a'][idx]
                delta = delta * a_mid * (1.0 - a_mid)
    return delta, list(reversed(dW_seg))

# Setup: 8 layers, checkpoint every 4
np.random.seed(42)
n_layers_ck = 8
width_ck = 10
# He-scaled weights (sqrt(2/fan_in)) keep ReLU activations well-conditioned.
Ws_ck = [np.random.randn(width_ck, width_ck) * np.sqrt(2/width_ck) for _ in range(n_layers_ck)]
x_ck = np.random.randn(width_ck)
y_true_ck = np.zeros(width_ck); y_true_ck[0] = 1.0  # one-hot target

# Standard backprop
# Forward pass, caching every pre-activation z and activation a — this cache
# is exactly the memory cost that checkpointing (cells above) would reduce.
a = x_ck
cache_std = {'z': [], 'a': [a]}
for l in range(n_layers_ck):
    z = Ws_ck[l] @ a
    cache_std['z'].append(z)
    a = relu(z)  # relu comes from an earlier cell — presumed max(0, z)
    cache_std['a'].append(a)

# Backward pass: delta starts as d(loss)/d(output) for 0.5*||a - y||^2.
delta = a - y_true_ck
dW_std = []
for l in range(n_layers_ck-1, -1, -1):
    # dW_l = delta_l (outer) a_{l-1}
    dW_std.append(np.outer(delta, cache_std['a'][l]))
    if l > 0:
        # delta_{l-1} = W_l^T delta_l * relu'(z_{l-1})
        delta = Ws_ck[l].T @ delta * relu_prime(cache_std['z'][l-1])
dW_std = list(reversed(dW_std))  # reorder to layer 0..L-1

print('Gradient norms (standard backprop):')
for l, g in enumerate(dW_std):
    print(f'  Layer {l}: ||dW|| = {la.norm(g):.6f}')

Code cell 32

# === 7.3 Mixed Precision Gradient Analysis ===

# Simulate FP16 vs FP32 gradient precision
np.random.seed(42)
exact = np.random.randn(100) * 1e-4  # small gradients typical in late training

# Cast the same gradients into each storage precision.
as_fp32 = exact.astype(np.float32)
as_fp16 = exact.astype(np.float16)

# FP16 with loss scaling (scale by 2^13 = 8192): multiply up before the cast
# so small values stay representable, then divide back down.
scale = 2**13
as_fp16_scaled = (exact * scale).astype(np.float16) / scale

e32 = la.norm(as_fp32 - exact)
e16 = la.norm(as_fp16 - exact)
e16s = la.norm(as_fp16_scaled - exact)

print('Mixed Precision Gradient Errors:')
print(f'  FP32 error:              {e32:.2e}')
print(f'  FP16 error (no scale):   {e16:.2e}  ({e16/e32:.0f}x worse)')
print(f'  FP16 with loss scaling:  {e16s:.2e}  ({e16s/e32:.1f}x worse)')

# Count underflowed gradients (become exactly 0 in FP16)
underflow_plain = np.sum(as_fp16 == 0)
underflow_scaled = np.sum(as_fp16_scaled == 0)
print(f'  Underflows (FP16):       {underflow_plain}/100')
print(f'  Underflows (scaled FP16): {underflow_scaled}/100')
print('\nLoss scaling prevents FP16 underflow for small gradients.')

8. LoRA Backward Pass

Low-rank adaptation gradient computation and comparison with full fine-tuning.

Code cell 34

# === 8.1 LoRA Forward and Backward ===

np.random.seed(42)
m_lora, n_lora, r_lora = 8, 6, 2  # output dim, input dim, adapter rank

# LoRA weights
W0 = np.random.randn(m_lora, n_lora)  # frozen
B = np.random.randn(m_lora, r_lora) * 0.01   # trainable
A = np.random.randn(r_lora, n_lora) * 0.01   # trainable
x_lora = np.random.randn(n_lora)

# Forward: y = (W0 + B@A) x
y_lora = (W0 + B @ A) @ x_lora

# Upstream gradient
y_bar = np.random.randn(m_lora)

# LoRA backward (W0 frozen -> no gradient for W0)
# dA = B^T y_bar x^T
dA = (B.T @ y_bar)[:, None] * x_lora[None, :]   # (r, n)
# dB = y_bar x^T A^T  (outer product then A^T)
# Note: x_lora @ A.T == A @ x_lora, so this is the outer product y_bar (A x)^T.
dB = np.outer(y_bar, x_lora @ A.T)   # (m, r)
# dx (for input gradient)
dx_lora = (W0 + B @ A).T @ y_bar

# Verify dA via FD
def lora_proj(A_flat):
    """Scalar projection y_bar . y as a function of the flattened A (for FD)."""
    A_ = A_flat.reshape(r_lora, n_lora)
    y_ = (W0 + B @ A_) @ x_lora
    return y_bar @ y_

# grad_fd / check_close come from an earlier cell — presumed to be a central
# finite-difference gradient and a tolerance comparison, respectively.
dA_fd = grad_fd(lora_proj, A.flatten()).reshape(r_lora, n_lora)
check_close('LoRA dA', dA, dA_fd)

def lora_proj_B(B_flat):
    """Scalar projection y_bar . y as a function of the flattened B (for FD)."""
    B_ = B_flat.reshape(m_lora, r_lora)
    y_ = (W0 + B_ @ A) @ x_lora
    return y_bar @ y_

dB_fd = grad_fd(lora_proj_B, B.flatten()).reshape(m_lora, r_lora)
check_close('LoRA dB', dB, dB_fd)

print(f'\nParameter counts:')
print(f'  Full W gradient:   {m_lora * n_lora} = {m_lora}x{n_lora}')
print(f'  LoRA A+B gradient: {r_lora*(m_lora+n_lora)} = {r_lora}x({m_lora}+{n_lora})')
print(f'  Compression:       {m_lora*n_lora / (r_lora*(m_lora+n_lora)):.1f}x fewer gradient params')

Code cell 35

# === 8.2 LoRA vs Full Finetuning — Training Comparison ===

np.random.seed(42)
m_ft, n_ft, r_ft = 20, 15, 3  # output dim, input dim, LoRA rank

# Data: simple linear regression problem
# Targets come from a random ground-truth linear map plus small Gaussian
# noise, so both trainers below chase the same well-posed regression target.
W_true = np.random.randn(m_ft, n_ft) * 0.5
X_data = np.random.randn(100, n_ft)
Y_data = X_data @ W_true.T + 0.01*np.random.randn(100, m_ft)

def train_lora(n_steps=300, lr=0.01, r=r_ft):
    """Minibatch SGD on the LoRA factors A, B only; the base W0_ stays frozen."""
    W0_ = np.random.randn(m_ft, n_ft) * 0.1  # frozen base weights
    B_ = np.zeros((m_ft, r))                 # zero init -> adapter starts at 0
    A_ = np.random.randn(r, n_ft) * 0.01
    losses = []
    for _ in range(n_steps):
        batch = np.random.choice(100, 16)
        xb = X_data[batch]
        yb = Y_data[batch]
        pred = xb @ (W0_ + B_ @ A_).T
        losses.append(0.5 * np.mean((pred - yb) ** 2))
        g = (pred - yb) / 16
        # Chain rule through the factorisation W = W0 + B A; both gradients
        # are taken at the pre-update values of A_ and B_.
        gA = (B_.T @ g.T) @ xb / m_ft
        gB = (g.T @ xb) @ A_.T / m_ft
        A_ = A_ - lr * gA
        B_ = B_ - lr * gB
    return losses

def train_full(n_steps=300, lr=0.01):
    """Minibatch SGD on the full weight matrix, for comparison with LoRA."""
    W_ = np.random.randn(m_ft, n_ft) * 0.1
    losses = []
    for _ in range(n_steps):
        batch = np.random.choice(100, 16)
        xb = X_data[batch]
        yb = Y_data[batch]
        pred = xb @ W_.T
        losses.append(0.5 * np.mean((pred - yb) ** 2))
        g = (pred - yb) / 16
        W_ = W_ - lr * ((g.T @ xb) / m_ft)
    return losses

# NOTE(review): the two runs share the global RNG stream (no reseed between
# them), so they see different minibatch draws — fine for a qualitative
# convergence comparison, but not a paired experiment.
losses_lora = train_lora()
losses_full = train_full()

print(f'Final loss — LoRA: {losses_lora[-1]:.4f},  Full: {losses_full[-1]:.4f}')

if HAS_MPL:
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.semilogy(losses_lora, 'b-', linewidth=2, label=f'LoRA (r={r_ft})')
    ax.semilogy(losses_full, 'r-', linewidth=2, label='Full finetuning')
    ax.set_xlabel('Step')
    ax.set_ylabel('MSE Loss')
    ax.set_title('LoRA vs Full Finetuning Convergence')
    ax.legend()
    ax.grid(True, alpha=0.4)
    plt.tight_layout()
    plt.show()

9. Backpropagation Through Time (BPTT)

Implementing vanilla RNN training and observing vanishing gradients over long sequences.

Code cell 37

# === 9.1 Vanilla RNN — Forward Pass ===

np.random.seed(42)
n_x, n_h = 5, 8   # input dim, hidden dim
T_rnn = 20         # sequence length

# RNN weights (small, so vanishing is visible)
Wh = np.random.randn(n_h, n_h) * 0.3
Wx = np.random.randn(n_h, n_x) * 0.3
bh = np.zeros(n_h)

def rnn_forward(xs, Wh, Wx, bh, h0=None):
    """Run a tanh RNN over the sequence xs.

    Returns (hs, zs): hs holds the initial state plus one hidden state per
    step (length T+1); zs holds the pre-activations (length T).
    """
    n_steps = len(xs)
    state = np.zeros(n_h) if h0 is None else h0
    history = [state.copy()]
    preacts = []
    for step in range(n_steps):
        pre = Wh @ state + Wx @ xs[step] + bh
        preacts.append(pre)
        state = np.tanh(pre)
        history.append(state.copy())
    return history, preacts

# Generate random sequence
xs = [np.random.randn(n_x) for _ in range(T_rnn)]
hs, zs = rnn_forward(xs, Wh, Wx, bh)

print(f'RNN: input_dim={n_x}, hidden_dim={n_h}, seq_len={T_rnn}')
print(f'Hidden state norm at t=1:  {np.linalg.norm(hs[1]):.4f}')
print(f'Hidden state norm at t=20: {np.linalg.norm(hs[20]):.4f}')

Code cell 38

# === 9.2 BPTT — Gradient Norm Over Time Steps ===

def rnn_backward(hs, zs, xs, Wh, Wx):
    """BPTT with a random unit-norm upstream gradient at the final state.

    Returns the gradient norm at every time step, ordered t=1..T (so the
    final-step norm, which starts at 1, is the LAST entry).
    """
    n_steps = len(xs)
    d = np.random.randn(n_h)  # gradient from loss at final step
    d /= np.linalg.norm(d)    # normalise so the final-step norm is exactly 1
    norms = [np.linalg.norm(d)]

    # Walk backwards: delta_t = Wh^T (delta_{t+1} * tanh'(z_t)).
    for t in reversed(range(1, n_steps)):
        d = Wh.T @ (d * (1 - np.tanh(zs[t]) ** 2))
        norms.append(np.linalg.norm(d))

    return norms[::-1]

# BPTT over the sequence from §9.1; rnn_backward injects a random unit-norm
# upstream gradient at the final hidden state.
grad_norms_rnn = rnn_backward(hs, zs, xs, Wh, Wx)

print('BPTT gradient norm from final step backwards:')
for t in [0, 4, 9, 14, 19]:
    print(f'  Step {t+1:2d}: {grad_norms_rnn[t]:.2e}')

if HAS_MPL:
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.semilogy(range(1, T_rnn+1), grad_norms_rnn, 'b-o', markersize=5)
    ax.set_xlabel('Time step (from loss)')
    ax.set_ylabel('Gradient norm')
    ax.set_title(f'BPTT: Gradient Norm vs Time Step (T={T_rnn}, ||Wh||_2 = {np.linalg.norm(Wh,2):.2f})')
    ax.grid(True, alpha=0.4)
    plt.tight_layout()
    plt.show()
    # grad_norms_rnn is ordered t=1..T: the norm AT the loss (== 1) is the
    # last entry, the most-attenuated norm (step 1) is the first. The decay
    # factor is therefore last/first; the previous first/last printed ~0x.
    print(f'Gradient decays {grad_norms_rnn[-1]/grad_norms_rnn[0]:.0f}x from step 20 to step 1')

10. Straight-Through Estimator and Discrete Operations

Gradient estimation for non-differentiable operations in neural networks.

Code cell 40

# === 10.1 STE for Quantisation-Aware Training ===

def quantise(x, bits=8):
    """Round to nearest integer, simulating INT8 quantisation.

    Maps x into the signed grid {-1, ..., 1} with 2^(bits-1)-1 levels per
    side, clipping values outside [-1, 1].
    """
    levels = 2**(bits-1) - 1  # 127 for INT8
    clipped = np.clip(x * levels, -levels, levels)
    return np.round(clipped) / levels

def ste_backward(grad_out, x):
    """Straight-Through Estimator: pass gradient through as-is.

    The quantiser is treated as the identity on the backward pass, so the
    upstream gradient flows through unchanged (x is accepted for signature
    parity but unused).
    """
    return grad_out

# Compare: true gradient (0 almost everywhere) vs STE
# The staircase quantiser is piecewise constant, so its exact derivative is
# useless for training; STE substitutes the identity Jacobian.
x_vals = np.linspace(-2, 2, 1000)
q_vals = quantise(x_vals)

# True gradient of quantise: 0 except at jumps (measure 0)
true_grad = np.zeros_like(x_vals)  # zero a.e.
# STE gradient: identity
ste_grad = np.ones_like(x_vals)   # always 1

print('STE vs True Gradient for quantise operation:')
print(f'  True gradient (mean): {true_grad.mean():.4f}  (zero everywhere)')
print(f'  STE gradient (mean):  {ste_grad.mean():.4f}  (passes gradient through)')

if HAS_MPL:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(x_vals, x_vals, 'b--', alpha=0.5, label='Identity')
    axes[0].plot(x_vals, q_vals, 'r-', linewidth=1.5, label='Quantised')
    axes[0].set_title('Quantisation: Forward')
    axes[0].set_xlabel('x')
    axes[0].legend()
    axes[0].grid(True, alpha=0.4)

    axes[1].plot(x_vals, true_grad, 'r-', linewidth=2, label='True gradient (=0 a.e.)')
    axes[1].plot(x_vals, ste_grad, 'g-', linewidth=2, label='STE gradient (=1)')
    axes[1].set_title('Gradients: STE vs True')
    axes[1].set_xlabel('x')
    axes[1].legend()
    axes[1].set_ylim(-0.5, 1.5)
    axes[1].grid(True, alpha=0.4)

    plt.tight_layout()
    plt.show()

Code cell 41

# === 10.2 REINFORCE Gradient Estimator ===

# Setup: binary stochastic node z ~ Bernoulli(sigma(theta))
# Loss: L = z^2,  want: d/dtheta E[L]

def sigmoid_fn(x):
    """Logistic sigmoid sigma(x) = 1 / (1 + e^{-x})."""
    return 1.0 / (1.0 + np.exp(-x))

# Analytical gradient:
# E[L] = E[z^2] = P(z=1)*1^2 + P(z=0)*0^2 = sigma(theta)
# so d/dtheta E[L] = sigma'(theta) = sigma(theta)*(1-sigma(theta)).
theta_vals = np.linspace(-3, 3, 50)
analytical_grads = [sigmoid_fn(t) * (1 - sigmoid_fn(t)) for t in theta_vals]

# Monte-Carlo REINFORCE estimate at theta = 0.
theta = 0.0
p = sigmoid_fn(theta)
n_samples = 50000
np.random.seed(42)
z_samples = np.random.binomial(1, p, n_samples).astype(float)

# Score function d/dtheta log p(z|theta):
#   z=1: d/dtheta log sigma(theta)       = 1 - sigma(theta) = 1 - p
#   z=0: d/dtheta log (1 - sigma(theta)) = -sigma(theta)    = -p
log_prob_grad = np.where(z_samples == 1, 1 - p, -p)
L_samples = z_samples**2
# REINFORCE: E[L * grad_theta log p(z|theta)]
reinforce_est = np.mean(L_samples * log_prob_grad)
analytical = p * (1 - p)

print(f'REINFORCE gradient estimate (theta=0, n={n_samples}):')
print(f'  REINFORCE: {reinforce_est:.6f}')
print(f'  Analytical: {analytical:.6f}')
print(f'  Relative error: {abs(reinforce_est-analytical)/analytical:.4f}')

11. Gradient Checking Toolkit

A practical toolkit for verifying any backpropagation implementation.

Code cell 43

# === 11.1 Gradient Check Protocol ===

def gradient_check(loss_fn, params_flat, grad_flat, h=1e-5, tol=1e-5):
    """Compare an analytic gradient against central finite differences.

    Returns a dict with max/mean symmetric relative error, a pass flag,
    and the location of the worst disagreement.
    """
    n = len(params_flat)
    fd = np.zeros(n)

    # Central difference, one coordinate at a time: 2n loss evaluations.
    for i in range(n):
        up = params_flat.copy()
        down = params_flat.copy()
        up[i] += h
        down[i] -= h
        fd[i] = (loss_fn(up) - loss_fn(down)) / (2*h)

    # Symmetric relative error; the 1e-10 guards the all-zero case.
    abs_err = np.abs(grad_flat - fd)
    rel_err = abs_err / (np.abs(grad_flat) + np.abs(fd) + 1e-10)

    worst = np.argmax(rel_err)
    return {
        'max_rel_err': rel_err.max(),
        'mean_rel_err': rel_err.mean(),
        'pass': rel_err.max() < tol,
        'worst_idx': worst,
        'worst_backprop': grad_flat[worst],
        'worst_fd': fd[worst],
    }

# Test on a 2-layer network
# MLP / backward_pass come from earlier cells; MLP([3, 5, 2], seed=0) is
# presumed to rebuild the same architecture deterministically — TODO confirm
# if the check ever fails unexpectedly.
np.random.seed(0)
net_test = MLP([3, 5, 2])
x_test = np.random.randn(3)
y_test = np.array([1.0, 0.0])

y_out, cache_test = net_test.forward(x_test)
dW_t, db_t = backward_pass(net_test, cache_test, y_out, y_test)

# Flatten all gradients
# Order matters: all weights first, then all biases — make_loss below must
# unpack the flat vector in exactly the same order.
grad_flat = np.concatenate([g.flatten() for g in dW_t] + [g.flatten() for g in db_t])
params_flat = np.concatenate([W.flatten() for W in net_test.W] + [b.flatten() for b in net_test.b])

def make_loss(params_flat):
    """Loss as a pure function of the flat parameter vector (for FD checks)."""
    # Reconstruct network from flat params
    net2 = MLP([3, 5, 2], seed=0)
    idx = 0
    for l in range(net2.L):
        sz = net2.W[l].size
        net2.W[l] = params_flat[idx:idx+sz].reshape(net2.W[l].shape)
        idx += sz
    for l in range(net2.L):
        sz = net2.b[l].size
        net2.b[l] = params_flat[idx:idx+sz]
        idx += sz
    y_p, _ = net2.forward(x_test)
    return net2.mse_loss(y_p, y_test)

result = gradient_check(make_loss, params_flat, grad_flat)
status = 'PASS' if result['pass'] else 'FAIL'
print(f'{status} — Gradient check')
print(f'  Max relative error:  {result["max_rel_err"]:.2e}')
print(f'  Mean relative error: {result["mean_rel_err"]:.2e}')

Code cell 44

# === 11.2 Error Pattern Diagnostics ===

# Inject a bug: wrong sign in weight gradient
# (dW_t / db_t / grad_flat / params_flat / make_loss come from the previous
# cell; a sign flip should blow the relative error up to ~1.)
dW_buggy = [g.copy() for g in dW_t]
dW_buggy[0] = -dW_t[0]  # wrong sign

grad_buggy = np.concatenate([g.flatten() for g in dW_buggy] + [g.flatten() for g in db_t])
result_buggy = gradient_check(make_loss, params_flat, grad_buggy)

print('Buggy gradient (wrong sign on W[0]):')
print(f'  Max relative error: {result_buggy["max_rel_err"]:.2e}  (large -> bug detected)')

# Correct gradient
result_ok = gradient_check(make_loss, params_flat, grad_flat)
print(f'Correct gradient:')
print(f'  Max relative error: {result_ok["max_rel_err"]:.2e}  (small -> correct)')

print('\nGradient checking is essential before deploying any backprop implementation.')

12. Cost Analysis — Forward vs Reverse Mode

Empirical verification of the $O(1)$ vs $O(n)$ cost difference.

Code cell 46

# === 12.1 JVP Cost (Forward Mode) ===
# Computing the full Jacobian row by row via JVP

import time

np.random.seed(42)
net_cost = MLP([20, 50, 50, 1])  # MLP defined in an earlier cell
x_cost = np.random.randn(20)
y_true_cost = np.array([1.0])

# Method 1: Finite differences (approximates JVP column by column)
# Costs O(n_params) forward passes
def params_as_flat(net):
    """Flatten all weights and biases of net into a single 1-D vector."""
    return np.concatenate([W.flatten() for W in net.W] + [b.flatten() for b in net.b])

n_params = sum(W.size for W in net_cost.W) + sum(b.size for b in net_cost.b)

# time.perf_counter is monotonic and high-resolution; time.time can have
# coarse resolution (and jump), which is unsafe for sub-millisecond timing.
start = time.perf_counter()
# Simulate FD gradient (just measure time for 1 pass * n_params)
# (We don't actually run n_params passes to keep demo fast)
_, cache_cost = net_cost.forward(x_cost)
# Floor the reading so the ratio below can never divide by zero.
t_forward = max(time.perf_counter() - start, 1e-9)

start = time.perf_counter()
y_cost, cache_cost = net_cost.forward(x_cost)
dW_cost, db_cost = backward_pass(net_cost, cache_cost, y_cost, y_true_cost)
t_backward = time.perf_counter() - start

print(f'Network: {n_params} parameters')
print(f'Forward pass time:  {t_forward*1000:.3f} ms')
print(f'Backward pass time: {t_backward*1000:.3f} ms')
print(f'Backward/Forward ratio: {t_backward/t_forward:.2f}x  (theoretical: ~2-3x)')
print(f'FD gradient would cost: {n_params} forward passes')
print(f'Estimated FD time: {n_params * t_forward * 1000:.1f} ms  ({n_params:.0f}x slower)')

Code cell 47

# === 12.2 Summary: Why Backprop is O(1) passes ===

# Uses n_params and t_forward measured in the previous cell (§12.1); the
# timings are rough single-shot measurements, quoted for scale only.
print('='*60)
print('FUNDAMENTAL THEOREM OF BACKPROPAGATION')
print('='*60)
print()
print('Computing ALL gradients d(L)/d(theta_i) for i=1,...,n')
print('costs the same as ~2-3 forward passes, regardless of n.')
print()
print('This is the "free lunch" that makes gradient descent')
print('tractable for billion-parameter neural networks.')
print()
print('Comparison:')
print(f'  Model with n={n_params} params')
print(f'  Backprop:  2-3 forward passes   ({2*t_forward*1000:.3f} ms)')
print(f'  Finite diff: {n_params} forward passes ({n_params*t_forward*1000:.0f} ms)')
print(f'  Speedup: {n_params/3:.0f}x')
print()
print('For GPT-3 (175B params), FD would require ~350 billion')
print('forward passes. Backprop requires ~3. This is why backprop')
print('is the foundation of all modern AI training.')

13. Transformer Layer — End-to-End Gradient Flow

We simulate one transformer block and verify gradient flow through the residual stream.

Code cell 49

# === 13.1 Simplified Transformer Block Forward ===

np.random.seed(42)
T_tf, d_tf = 4, 8   # seq_len, hidden_dim
d_ff = 16            # MLP expansion

# Pre-norm transformer block: x' = x + Attn(LN(x)),  x'' = x' + MLP(LN(x'))
# Weights for attention (single head; 0.1 scale keeps sublayer outputs small
# relative to the residual stream)
Wq = np.random.randn(d_tf, d_tf) * 0.1
Wk = np.random.randn(d_tf, d_tf) * 0.1
Wv = np.random.randn(d_tf, d_tf) * 0.1
Wo = np.random.randn(d_tf, d_tf) * 0.1
# Weights for MLP
W1 = np.random.randn(d_ff, d_tf) * 0.1
W2 = np.random.randn(d_tf, d_ff) * 0.1

# LayerNorm scale/shift, identity-initialised
gamma_ln1 = np.ones(d_tf); beta_ln1 = np.zeros(d_tf)
gamma_ln2 = np.ones(d_tf); beta_ln2 = np.zeros(d_tf)

def ln_fwd(x, g, b, eps=1e-5):
    """LayerNorm forward over the last axis.

    Returns (out, xhat, var): the scaled/shifted output, the normalised
    values, and the per-row variance (needed by a backward pass).
    """
    mean = x.mean(-1, keepdims=True)
    variance = x.var(-1, keepdims=True)
    normed = (x - mean) / np.sqrt(variance + eps)
    return g * normed + b, normed, variance

def attn_fwd(x, Wq, Wk, Wv, Wo):
    """Single-head self-attention (no mask). Returns (output, attn weights).

    Scores are scaled by sqrt(d_tf) — the module-level hidden dimension —
    and softmaxed with the usual max-subtraction for numerical stability.
    """
    Q, K, V = x @ Wq.T, x @ Wk.T, x @ Wv.T
    scores = Q @ K.T / np.sqrt(d_tf)
    shifted = np.exp(scores - scores.max(-1, keepdims=True))
    probs = shifted / shifted.sum(-1, keepdims=True)
    return probs @ V @ Wo.T, probs

def mlp_fwd(x, W1, W2):
    """Two-layer MLP with ReLU. Returns (output, hidden activations)."""
    hidden = np.maximum(0, x @ W1.T)
    return hidden @ W2.T, hidden

# Forward through one transformer block
# (pre-norm: LayerNorm is applied before each sublayer, and each sublayer's
# output is added back onto the residual stream)
X = np.random.randn(T_tf, d_tf)  # residual stream input

# Attention sublayer
ln1_out, xhat1, var1 = ln_fwd(X, gamma_ln1, beta_ln1)
attn_out, P_attn = attn_fwd(ln1_out, Wq, Wk, Wv, Wo)
X2 = X + attn_out   # residual

# MLP sublayer
ln2_out, xhat2, var2 = ln_fwd(X2, gamma_ln2, beta_ln2)
mlp_out, h_mlp = mlp_fwd(ln2_out, W1, W2)
X3 = X2 + mlp_out   # residual

print(f'Transformer block: T={T_tf}, d={d_tf}, d_ff={d_ff}')
print(f'Input norm:   {np.linalg.norm(X):.4f}')
print(f'Output norm:  {np.linalg.norm(X3):.4f}')
print(f'Attn output norm:  {np.linalg.norm(attn_out):.4f}')
print(f'MLP output norm:   {np.linalg.norm(mlp_out):.4f}')

Code cell 50

# === 13.2 Gradient Flow Through Residual Stream ===

# Upstream gradient at X3 (from loss), normalised to unit length
dX3 = np.random.randn(T_tf, d_tf)
dX3 /= np.linalg.norm(dX3)

# Backward through MLP sublayer
# X3 = X2 + mlp_out  -> dX2_from_mlp = dX3 (identity skip)
dX2_skip = dX3.copy()  # identity residual path

# Backward through MLP itself (simplified: just track norm)
# dmlp_out = dX3, so dX2 from MLP = W2^T relu'(W1 x) W1 dX3 ...
# Precedence note: @ and * bind equally, left-to-right, so this evaluates as
# ((dX3 @ W2) * mask) @ W1. It also ignores the LayerNorm Jacobian — a
# deliberate approximation for norm-tracking only.
dmlp_approx = dX3 @ W2 * (h_mlp > 0) @ W1  # approximate backward
dX2_total = dX2_skip + dmlp_approx

# Backward through attention sublayer
dX_skip = dX2_total.copy()  # identity residual path
# Approximate attention backward
# P_attn.mean() is a scalar stand-in for the softmax Jacobian — a crude
# heuristic, not the exact attention gradient.
dattn_approx = dX2_total @ Wo * P_attn.mean() @ Wv  # rough
dX_total = dX_skip + dattn_approx

print('Gradient norms through transformer block:')
print(f'  dX3 (upstream):          {np.linalg.norm(dX3):.4f}')
print(f'  dX2 via skip (no MLP):   {np.linalg.norm(dX2_skip):.4f}  (= dX3, identity)')
print(f'  dX2 total (skip+MLP):    {np.linalg.norm(dX2_total):.4f}')
print(f'  dX via skip (no Attn):   {np.linalg.norm(dX_skip):.4f}   (= dX2, identity)')
print(f'  dX total (skip+Attn):    {np.linalg.norm(dX_total):.4f}')
print()
print('Key: Skip path preserves gradient norm exactly.')
print('     Sublayer adds/subtracts from that baseline.')

Code cell 51

# === 13.3 Xavier Initialisation Verification ===

# Show that Xavier keeps gradient variance stable for tanh networks

np.random.seed(42)
n_exp = 1000  # number of experiments
L_exp = 10    # network depth
n_exp_w = 50  # layer width

# Xavier: sigma^2 = 2/(n_in + n_out) = 1/n for square layers
sigma_xavier = np.sqrt(1.0 / n_exp_w)
sigma_too_small = 0.01

grad_var_xavier = []
grad_var_small = []

# NOTE(review): this first experiment appears vestigial — grad_var_xavier /
# grad_var_small each hold a single repeatedly-overwritten value and are
# never read afterwards (the numbers reported below come from the second,
# 100-trial loop). Also `z` leaks out of the inner loop, so
# `x0 = np.tanh(z)` reuses whichever init scale ran last. Left byte-identical
# because removing it would shift the global RNG stream for the second loop.
for trial in range(n_exp):
    x0 = np.random.randn(n_exp_w)
    delta = np.random.randn(n_exp_w)

    for W_scale, store in [(sigma_xavier, grad_var_xavier), (sigma_too_small, grad_var_small)]:
        d = delta.copy()
        for l in range(L_exp):
            W = np.random.randn(n_exp_w, n_exp_w) * W_scale
            z = W @ x0
            tanh_p = 1 - np.tanh(z)**2
            d = W.T @ d * tanh_p
        if trial == 0:
            store.append(np.var(d))
        else:
            store[-1] = np.var(d)  # just keep last
        x0 = np.tanh(z)

# Measure over multiple trials
# Backpropagate the same upstream vector through 10 tanh layers under both
# init scales; the forward trajectory (x0) follows the Xavier network.
vars_xavier_exp = []
vars_small_exp = []
for trial in range(100):
    x0 = np.random.randn(n_exp_w)
    d_x = np.random.randn(n_exp_w)
    d_s = d_x.copy()
    for l in range(L_exp):
        Wx = np.random.randn(n_exp_w, n_exp_w) * sigma_xavier
        Ws2 = np.random.randn(n_exp_w, n_exp_w) * sigma_too_small
        zx = Wx @ x0; zs = Ws2 @ x0
        d_x = Wx.T @ d_x * (1 - np.tanh(zx)**2)
        d_s = Ws2.T @ d_s * (1 - np.tanh(zs)**2)
        x0 = np.tanh(zx)
    vars_xavier_exp.append(np.var(d_x))
    vars_small_exp.append(np.var(d_s))

print('Gradient variance after 10 tanh layers (100 trials):')
print(f'  Xavier init:    {np.mean(vars_xavier_exp):.4f} ± {np.std(vars_xavier_exp):.4f}')
print(f'  Small init:     {np.mean(vars_small_exp):.2e} ± {np.std(vars_small_exp):.2e}')
print(f'  Ratio: {np.mean(vars_xavier_exp)/max(np.mean(vars_small_exp),1e-30):.0f}x better gradient scale')

Summary

Key Results Verified

| Section | Result | Status |
|---|---|---|
| §1 Chain Rule | $J_{f\circ g} = J_f \cdot J_g$ verified numerically | PASS |
| §2 Autograd | VJP duality $\mathbf{u}^\top(J\mathbf{v}) = (J^\top\mathbf{u})^\top\mathbf{v}$ | PASS |
| §3 Backprop | 3-layer MLP gradients match FD for all layers | PASS |
| §4 Layers | Linear, Softmax+CE, LayerNorm, Attention gradients | PASS |
| §5 Vanishing | He init keeps variance near 1 at depth 20; residuals help | Shown |
| §6 Initialisation | Xavier/He variance propagation verified empirically | Shown |
| §7 Memory | Checkpointing achieves $O(\sqrt{L})$ memory tradeoff | Measured |
| §8 LoRA | Backward gradients verified; $r$-fold compression | PASS |

The Fundamental Equation of Deep Learning

$$\boldsymbol{\delta}^{[l]} = \left(W^{[l+1]}\right)^{\top} \boldsymbol{\delta}^{[l+1]} \odot \sigma'^{[l]}\!\left(\mathbf{z}^{[l]}\right)$$

This recurrence, applied in reverse topological order through the computation graph, is backpropagation. Every modern deep learning system is built on this formula.

Next: §04 Optimality Conditions — how the gradient is used to find minima.