Chain Rule and Backpropagation: Part 11: Exercises to Appendix K: Summary Tables
11. Exercises
Exercise 1 - Scalar Chain Rule Verification
Let and .
(a) Compute using the chain rule analytically.
(b) Evaluate the derivative at and .
(c) Verify numerically using centred finite differences.
(d) Compute - explain why the order of composition matters.
Exercise 2 - Jacobian Composition
Let and be defined by:
(a) Compute and analytically.
(b) Compute the Jacobian of using the chain rule .
(c) Verify using finite differences at .
(d) Compute directly and confirm it equals part (b).
Exercise 3 - Backprop Through a 2-Layer Network
Two-layer network: , , , .
With , , :
(a) Implement forward pass. Compute for given values.
(b) Implement backward pass manually using the backpropagation recurrence.
(c) Verify your gradients using numpy finite differences.
(d) Implement gradient descent for 100 steps with learning rate and verify loss decreases.
Exercise 4 - Vanishing Gradients Analysis
(a) Construct a 20-layer sigmoid network with all weights . Compute the gradient at layer 1 symbolically and numerically.
(b) Repeat with ReLU activation. Compare gradient magnitudes at layers 1, 5, 10, 20.
(c) Apply Xavier initialisation to the sigmoid network and compare gradient flow.
(d) Add residual connections to the 20-layer sigmoid network. Quantify the improvement.
(e) Plot gradient norm vs. layer depth for all four cases.
Exercise 5 - Gradient Checkpointing
(a) Implement a 10-layer feedforward network with explicit intermediate caching. Measure peak memory usage.
(b) Implement the same network with gradient checkpointing at every 3rd layer. Measure memory.
(c) Verify that both implementations produce identical gradients.
(d) Measure the compute overhead of recomputation. How does it compare to the theoretical ?
(e) Find the optimal checkpoint interval that minimises total memory x compute cost.
Exercise 6 - Attention Gradient
Single-head attention: with for , .
(a) Implement forward pass.
(b) Implement backward pass computing given .
(c) Verify all three gradients using finite differences.
(d) For causal masking (set the attention scores to $-\infty$ for positions $j > i$), show that the backward pass is unchanged except at masked positions.
Exercise 7 - LoRA Gradient Analysis
(a) Implement a linear layer with LoRA adaptation. Set .
(b) Compute gradients and analytically and verify numerically.
(c) Confirm that the gradient with respect to the frozen base weight is well-defined but is not used (the base weight stays frozen).
(d) Compare the number of gradient parameters for full fine-tuning vs. LoRA.
(e) Implement LoRA training for 200 steps on a toy task and compare convergence with full fine-tuning.
Exercise 8 - REINFORCE and STE
(a) Implement a stochastic computational graph: , .
(b) Compute the REINFORCE gradient analytically.
(c) Estimate the REINFORCE gradient with 10000 samples. Verify against the analytical value.
(d) Implement STE for the rounding operation (forward $y = \mathrm{round}(x)$, backward treated as $\partial y / \partial x = 1$). Compute the STE gradient and apply one update.
(e) Compare STE-based quantisation-aware training on a toy example: train for 50 steps and measure quantisation error vs. a post-training quantised model.
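For part (d), the standard straight-through trick in PyTorch is the "detach the difference" idiom. The sketch below is a minimal, hypothetical illustration of that idiom only (the specific task and target value are made up), not a full solution to the exercise:

```python
import torch

# Straight-through estimator (STE) for rounding: forward uses round(x),
# backward pretends d round(x)/dx = 1 by detaching the non-differentiable part.
x = torch.tensor(0.7, requires_grad=True)
y = x + (torch.round(x) - x).detach()   # forward value: round(x); gradient path: identity
loss = (y - 2.0) ** 2
loss.backward()
print(y.item(), x.grad.item())          # y = 1.0; dL/dx = 2*(y - 2) = -2.0 via the STE
```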
12. Why This Matters for AI (2026 Perspective)
| Concept | Concrete AI Impact |
|---|---|
| Multivariate chain rule | The mathematical foundation of every gradient-based learning algorithm - without it, backprop cannot be defined |
| VJP as backprop primitive | Modern autodiff systems (JAX, PyTorch) are built around VJP primitives; the cost of reverse mode is what makes training billion-parameter models tractable |
| Computation graphs | torch.compile (PyTorch 2.0), XLA (JAX/TensorFlow), TensorRT all operate by analysing the computation graph to fuse kernels and optimise memory layout |
| Fused softmax + CE gradient | The clean gradient $p - y$ makes language model training numerically stable; FlashAttention's backward pass reuses the same softmax log-sum-exp statistics |
| Xavier/He initialisation | Keeps gradient scale roughly constant whether a layer sits at depth 1 or depth 96 - a critical practical enabler for deep network training |
| Residual connections | The "gradient highway" identity term in ResNets/transformers is why 100-layer networks train at all; this was the key insight enabling GPT-3's 96 layers |
| Gradient checkpointing | Enables training LLMs with 128K context lengths; without it, the activation memory would be prohibitive |
| FlashAttention backward | IO-aware backward pass reduces attention memory from $O(N^2)$ to $O(N)$ while maintaining numerical equivalence; standard in all production LLM training as of 2024 |
| LoRA backward | Only the low-rank adapter matrices accumulate gradients; enables fine-tuning 70B models on a single H100 via the low-rank backward structure |
| STE / REINFORCE | STE enables quantisation-aware training (GPTQ, AWQ, QLoRA); REINFORCE enables RLHF's policy gradient step in PPO-based alignment training |
| BPTT | The failure of vanilla BPTT for long sequences motivated LSTMs, GRUs, and ultimately the attention mechanism which replaces recurrence with direct pairwise interactions |
| ZeRO gradient sharding | Partitions gradient storage across GPUs linearly in GPU count; enables training models that would require more memory per GPU without it |
| Mixed precision backward | BF16 backward passes move half the bytes of FP32, roughly doubling effective memory bandwidth on H100, with dynamic loss scaling preventing underflow; standard in all LLM training since GPT-3 |
| Higher-order gradients | Gradient penalties in GANs, MAML's meta-gradient, and Hessian-vector products for learning rate scheduling all require differentiating through the backward pass |
Conceptual Bridge
Where we came from: 01 (Partial Derivatives) gave us tools to differentiate multivariate functions component by component. 02 (Jacobians and Hessians) assembled those into matrix objects capturing full first- and second-order sensitivity. We now know what a derivative is for a function $f: \mathbb{R}^n \to \mathbb{R}^m$.
What this section added: The chain rule tells us how derivatives compose, allowing us to differentiate functions built from primitives. Backpropagation is the algorithmic instantiation of this composition for computation graphs, and the VJP (reverse mode) makes the cost of differentiating a scalar loss with respect to millions of parameters a small constant multiple of the cost of a single forward pass. This is not an approximation - it is exact.
What this enables: Every gradient-based learning algorithm - SGD, Adam, RMSprop, LARS, Shampoo - requires only the gradient $\nabla_\theta L$, which backprop provides. The advanced sections of this chapter (04 Optimisation, 05 Automatic Differentiation) build directly on the VJP abstraction established here.
Connection to transformer training: Modern LLM training is essentially an exercise in efficient backpropagation at scale. Every engineering decision - Flash Attention's tiled backward, ZeRO's gradient sharding, gradient checkpointing, LoRA's low-rank backward, mixed precision loss scaling - is a response to the memory and compute constraints of the backward pass. Understanding backpropagation is therefore prerequisite to understanding why LLM training systems are designed the way they are.
POSITION IN THE CURRICULUM
PREREQUISITES (must know):
01 Partial Derivatives - ∂f/∂x, gradient, directional derivative
02 Jacobians & Hessians - J_f, Fréchet derivative, VJP/JVP
THIS SECTION (03):
Chain Rule & Backpropagation
- Multivariate chain rule (J_{f∘g} = J_f · J_g)
- Computation graphs (DAG, topological order)
- Backprop recurrence (δ^l = (W^{l+1})ᵀ δ^{l+1} ⊙ σ'(z^l))
- Gradient derivations (linear, softmax+CE, LN, attention)
- Vanishing/exploding gradients + solutions
- Memory-efficient backprop (checkpointing, Flash Attention)
- Advanced: BPTT, STE, REINFORCE, higher-order gradients
WHAT THIS ENABLES:
04 Optimisation - gradient descent, Adam, second-order methods
05 Automatic Differentiation - AD systems, tape, jit compilation
07 Neural Networks - full training loop built on backprop
08 Transformer Architecture - FlashAttention, LoRA, gradient flow
CROSS-CHAPTER CONNECTIONS:
03-Advanced-LA/02-SVD - gradient low-rank structure
04-Calculus/02-Derivatives - scalar chain rule (special case)
06-Probability/03-MLE - loss functions that backprop optimises
For automatic differentiation systems that implement these ideas at scale, see 05 Automatic Differentiation.
For the optimisation algorithms that consume backprop's output, see 04 Multivariate Optimisation.
Appendix A: Worked Backpropagation Example
A.1 Complete Worked Example - 3-Layer Network
To make the backpropagation formulas concrete, we trace through a minimal example end-to-end.
Network architecture:
- Input:
- Layer 1: , , activation: ReLU
- Layer 2: , , activation: none (scalar output)
- Loss: (MSE with target )
Forward pass:
Backward pass:
Output layer gradient:
Layer 2 gradients (scalar output, linear):
Error signal propagated to layer 1:
Through ReLU:
Layer 1 weight gradients:
Verification (finite difference for ): Perturb by :
Numerically: . Difference .
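Because the original numbers in this worked example did not survive, the following is a hedged, concrete re-creation of the same setup with made-up weights and a made-up input (linear → ReLU → linear scalar output → MSE loss), including the finite-difference check:

```python
import numpy as np

# Hypothetical values standing in for the lost A.1 numbers.
x  = np.array([1.0, -2.0])                      # input
W1 = np.array([[0.5, -0.3], [0.8, 0.2]]); b1 = np.array([0.1, -0.1])
W2 = np.array([[1.0, -0.5]]);             b2 = np.array([0.2])
y_target = 1.0

# Forward pass (caching z1 and a1, as backprop requires)
z1 = W1 @ x + b1
a1 = np.maximum(z1, 0.0)                        # ReLU
yhat = (W2 @ a1 + b2)[0]                        # scalar output, no activation
loss = 0.5 * (yhat - y_target) ** 2             # MSE

# Backward pass
dyhat = yhat - y_target                         # dL/dyhat
dW2 = dyhat * a1[None, :]                       # dL/dW2 = delta2 * a1^T
db2 = np.array([dyhat])
delta1 = dyhat * W2[0] * (z1 > 0)               # propagate through W2, then the ReLU mask
dW1 = np.outer(delta1, x)
db1 = delta1

# Finite-difference verification for W1[0, 0]
eps = 1e-6
def loss_at(w00):
    W1p = W1.copy(); W1p[0, 0] = w00
    a = np.maximum(W1p @ x + b1, 0.0)
    return 0.5 * ((W2 @ a + b2)[0] - y_target) ** 2
fd = (loss_at(W1[0, 0] + eps) - loss_at(W1[0, 0] - eps)) / (2 * eps)
print(dW1[0, 0], fd)                            # analytic and numerical values should agree
```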
A.2 Computational Cost Comparison
Forward pass: $O(Ln^2)$ - one GEMM per layer.
Backward pass: Also $O(Ln^2)$ - same asymptotic cost, with a constant factor of roughly 2.
Memory: Cache all $z^l$ and $a^l$: $O(Ln)$ scalars - linear in total neuron count.
The fundamental theorem of backpropagation: Computing $\partial L / \partial \theta_i$ for all parameters costs only a constant factor more than computing $L$ itself. This is the miracle that makes gradient-based learning tractable.
Formal statement: Let $T_f$ be the time to evaluate $L(\theta)$ in the forward pass. Then the time to compute all partial derivatives via backprop is at most $c \cdot T_f$, where $c$ is a small constant (roughly 2-3) in practice.
This contrasts with finite differences: computing $\partial L / \partial \theta_i$ for each of $n$ parameters via finite differences costs $O(n)$ forward passes - for GPT-3 with $n = 1.75 \times 10^{11}$, this would be 175 billion forward passes, or approximately the heat death of the universe in compute time.
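As a back-of-envelope check of that claim (using the constant factor stated in the formal statement above):

$$T_{\text{backprop}} \approx c\, T_f \ \ (c \approx 2\text{-}3), \qquad T_{\text{FD}} \approx 2n\, T_f = 3.5 \times 10^{11}\, T_f \ \ \text{for } n = 1.75 \times 10^{11},$$

so finite differences are on the order of $10^{11}$ times more expensive than backprop for a GPT-3-scale model.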
Appendix B: JVP vs VJP - Mode Selection and Complexity
B.1 Forward Mode vs Reverse Mode
Given $f: \mathbb{R}^n \to \mathbb{R}^m$, both modes compute the same derivative information but with different costs:
| Mode | Computes | Cost per pass | Total cost for full Jacobian |
|---|---|---|---|
| Forward (JVP) | One column of $J_f$ (i.e. $J_f v$) | $O(T_f)$ | $n$ passes |
| Reverse (VJP) | One row of $J_f$ (i.e. $v^\top J_f$) | $O(T_f)$ | $m$ passes |
COST MATRIX: WHICH MODE WINS?
Goal: compute ∂L/∂θ for L: R^n -> R (scalar loss)
  n = |θ| = 175,000,000,000 (GPT-3 parameter count)
  m = 1 (scalar loss)
Forward mode: n passes = 175 BILLION forward passes (one per input direction)
Reverse mode: m passes = ONE backward pass (m = 1 means one row of J_f = the gradient row vector)
RULE: Use reverse mode (backprop) when m ≪ n
RULE: Use forward mode (JVP) when n ≪ m
Most ML: n ≫ m = 1 -> backprop is optimal
When forward mode wins: Computing the sensitivity of all outputs to one input parameter - e.g., computing how the entire model output changes as a single hyperparameter varies. Also: Jacobian-vector products in conjugate gradient (no need for the full Jacobian).
Mixed strategies: For functions with $n \approx m$, the optimal choice is to split the Jacobian into row/column blocks and use each mode for the appropriate blocks - the basis of adjoint methods in numerical PDE solvers.
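A small sketch of the row/column picture using torch.autograd.functional.jvp and vjp; the toy function f and the probe vectors are hypothetical, chosen only to make the column-vs-row distinction concrete:

```python
import torch
from torch.autograd.functional import jvp, vjp

# Toy map f: R^3 -> R^2 (hypothetical example)
def f(x):
    return torch.stack([x[0] * x[1], x[1].sin() + x[2] ** 2])

x = torch.tensor([1.0, 2.0, 3.0])

# JVP: push an input tangent forward -> a column-combination of J_f
v_in = torch.tensor([1.0, 0.0, 0.0])     # direction e_1 picks out column 1
_, col = jvp(f, (x,), (v_in,))

# VJP: pull an output cotangent back -> a row-combination of J_f
u_out = torch.tensor([1.0, 0.0])         # covector e_1 picks out row 1
_, row = vjp(f, (x,), (u_out,))

print(col)   # first column of J_f(x)
print(row)   # first row of J_f(x)
```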
B.2 Tangent Mode for Hessian-Vector Products
As shown in 02, the Hessian-vector product $Hv = \nabla_\theta\!\left(\nabla_\theta L \cdot v\right)$ can be computed by composing forward and reverse modes:
Algorithm (Pearlmutter's R$\{\cdot\}$ trick, 1994):
- Differentiate the backprop computation in the direction $v$ (a JVP through the gradient), which yields $\nabla_\theta L$ and $Hv$ together
- Cost: the same order as backprop ($O(T_f)$) - no Hessian is ever formed
Implementation in PyTorch:
import torch
# loss: scalar tensor; params: list of leaf tensors; v: flat vector of length sum(p.numel())
g = torch.autograd.grad(loss, params, create_graph=True)   # first backward, graph retained
flat_g = torch.cat([gi.view(-1) for gi in g])              # flatten the gradient
hvp = torch.autograd.grad(flat_g @ v, params)              # second backward yields Hv
Cost: 2 backprop passes, no matrix formed. This is the primitive for:
- Conjugate gradient for Newton steps (K-FAC-style)
- Lanczos iteration for the top eigenvalues of the Hessian
- Eigenvalue monitoring during training (Cohen et al., 2022 - edge of stability)
Appendix C: Automatic Differentiation Preview
C.1 The AD Abstraction
Automatic differentiation (AD) is a mechanical procedure for transforming any program that computes $f(x)$ into a program that also computes its derivatives (or JVPs/VJPs). This section previews the idea; the full treatment is in 05.
AD is neither symbolic differentiation (too slow, exponentially large expressions) nor numerical differentiation (finite differences - too imprecise, costs $O(n)$ evaluations). AD exploits the fact that every program is a composition of primitives, and the chain rule tells us exactly how to compose their derivatives.
AD comes in two flavours - forward and reverse mode (see B.1) - and both differ from the alternatives:
| Approach | How it works | Accuracy | Cost |
|---|---|---|---|
| Symbolic differentiation | Manipulate the expression: $f(x) = x^2 + \sin(x) \Rightarrow f'(x) = 2x + \cos(x)$ | Exact | Expression size can explode |
| Numerical differentiation | Evaluate $f(x+h)$ and $f(x-h)$: $\big(f(x+h) - f(x-h)\big)/2h$ | Approximate | $O(n)$ evaluations |
| Automatic differentiation | Track operations on a computation tape | Exact (to floating-point precision) | $O(1)$ extra evaluations |
C.2 The Tape (Wengert List)
The Wengert list (1964) records, during the forward pass, every primitive operation applied and its operands. The backward pass replays this tape in reverse, accumulating adjoints.
FORWARD TAPE EXAMPLE: f(x) = exp(x) * (x + 1)
Tape (built during forward):
v_1 = x (input)
v_2 = exp(v_1) (op: exp, operand: v_1)
v_3 = v_1 + 1.0 (op: add, operands: v_1, 1.0)
v_4 = v_2 * v_3 (op: mul, operands: v_2, v_3)
Backward (replay in reverse; v̄_i denotes the adjoint of v_i):
v̄_4 = 1.0 (seed)
v̄_2 += v̄_4 * v_3 = 1.0 * (x+1) (mul backward)
v̄_3 += v̄_4 * v_2 = 1.0 * exp(x) (mul backward)
v̄_1 += v̄_3 * 1.0 = exp(x) (add backward)
v̄_1 += v̄_2 * exp(v_1) = (x+1)exp(x) (exp backward)
Total: v̄_1 = exp(x) + (x+1)exp(x) = (x+2)exp(x) (by product rule)
PyTorch's Tensor stores a grad_fn attribute at each node - this is the tape in disguise. Calling .backward() replays the tape in reverse.
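A quick sanity check of the tape example above: for $f(x) = e^x (x+1)$, autograd should return $(x+2)e^x$, matching the hand-computed adjoint accumulation:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
f = torch.exp(x) * (x + 1)
f.backward()
print(x.grad.item(), ((x + 2) * torch.exp(x)).item())   # both ≈ 8.1548 at x = 1
```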
For more: See 05 Automatic Differentiation for the complete treatment of forward/reverse mode AD, source transformation, operator overloading, and the design of JAX vs PyTorch autograd.
Appendix D: Numerical Gradient Verification
In practice, every backpropagation implementation should be verified against finite differences. This appendix presents the standard toolkit.
D.1 Centred Finite Differences
For a scalar loss $L(\theta)$ and parameter $\theta_i$:
$$\frac{\partial L}{\partial \theta_i} \approx \frac{L(\theta + h\,e_i) - L(\theta - h\,e_i)}{2h}$$
Error analysis: Centred differences have error $O(h^2)$ (vs $O(h)$ for forward differences). The optimal step size balances truncation error ($O(h^2)$) against floating-point cancellation error ($O(\varepsilon_{\text{mach}}/h)$, where $\varepsilon_{\text{mach}}$ is machine epsilon): $h^* \sim \varepsilon_{\text{mach}}^{1/3}$.
This gives roughly $h \approx 10^{-6}$ for float64 and $h \approx 10^{-3}$ for float32.
Relative error check: Accept the gradient check if
$$\frac{|g_{\text{analytic}} - g_{\text{numeric}}|}{\max\!\big(|g_{\text{analytic}}|,\ |g_{\text{numeric}}|,\ \varepsilon\big)} < \text{tol}$$
with a tolerance appropriate to the precision.
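A minimal numpy sketch of this check (the example loss and tolerance handling are illustrative assumptions, not part of the original text):

```python
import numpy as np

def numerical_grad(loss_fn, theta, h=1e-6):
    """Centred finite differences: one pair of loss evaluations per parameter."""
    grad = np.zeros_like(theta)
    for i in range(theta.size):
        e = np.zeros_like(theta)
        e.flat[i] = h
        grad.flat[i] = (loss_fn(theta + e) - loss_fn(theta - e)) / (2 * h)
    return grad

def relative_error(g_analytic, g_numeric, eps=1e-12):
    num = np.abs(g_analytic - g_numeric)
    den = np.maximum(np.maximum(np.abs(g_analytic), np.abs(g_numeric)), eps)
    return np.max(num / den)

# Example: L(theta) = sum(theta^2) has analytic gradient 2*theta
theta = np.random.randn(5)
err = relative_error(2 * theta, numerical_grad(lambda t: np.sum(t ** 2), theta))
print(err)   # should be ~1e-9 or smaller in float64
```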
D.2 When Gradient Checks Fail
Common failure modes:
| Symptom | Likely cause |
|---|---|
| Relative error throughout | h too large (truncation) or float32 precision |
| Relative error for specific parameters | Bug in backward for that parameter type |
| Relative error for all gradients | Loss is approximately linear in those parameters at the test point |
| Fails at kink (ReLU/max) | Gradient not defined at $z = 0$; test point near a kink; use test points away from kinks |
| Fails only for batch size 1 | BatchNorm statistics degenerate; use batch size $\geq 2$ for BN checks |
D.3 Gradient Check in PyTorch
import torch
from torch.autograd import gradcheck

def f(x):
    return (x ** 2).sum()

x = torch.randn(5, requires_grad=True, dtype=torch.float64)
gradcheck(f, (x,), eps=1e-6, atol=1e-4, rtol=1e-4)
gradcheck automates centred finite differences for all inputs with requires_grad=True. Always use dtype=torch.float64 for gradient checking - float32 precision is insufficient for reliable checks.
Appendix E: Key Formulas Reference
E.1 Chain Rule Summary
| Setting | Formula |
|---|---|
| Scalar composition | $(f \circ g)'(x) = f'(g(x))\, g'(x)$ |
| Vector composition | $J_{f \circ g}(x) = J_f(g(x))\, J_g(x)$ |
| VJP (backprop step) | $\bar{x} = J_f(x)^\top \bar{y}$ |
| JVP (forward step) | $\dot{y} = J_f(x)\, \dot{x}$ |
E.2 Backpropagation Formulas
| Layer | Forward | Backward (upstream $\delta$ given) |
|---|---|---|
| Linear | $z = Wa + b$ | $\nabla_W = \delta a^\top$, $\nabla_b = \delta$, $\delta_{\text{in}} = W^\top \delta$ |
| Elementwise $\sigma$ | $a = \sigma(z)$ | $\delta_{\text{in}} = \sigma'(z) \odot \delta$ |
| Softmax+CE | $p = \mathrm{softmax}(z)$, $L = -\log p_y$ | $\partial L / \partial z = p - y$ |
| Residual | $y = x + F(x)$ | $\delta_{\text{in}} = \delta + J_F^\top \delta$ |
| LayerNorm | $y = \gamma \odot \hat{x} + \beta$ | Complex (see 5.4); passes a mean-centred, rescaled signal |
E.3 Activation Derivatives
| Name | $\sigma(z)$ | $\sigma'(z)$ |
|---|---|---|
| ReLU | $\max(0, z)$ | $\mathbb{1}[z > 0]$ |
| Sigmoid | $1 / (1 + e^{-z})$ | $\sigma(z)\,(1 - \sigma(z))$ |
| Tanh | $\tanh(z)$ | $1 - \tanh^2(z)$ |
| GELU | $z\,\Phi(z)$ | $\Phi(z) + z\,\phi(z)$ |
| SiLU | $z\,\mathrm{sigmoid}(z)$ | $\mathrm{sigmoid}(z)\,\big(1 + z(1 - \mathrm{sigmoid}(z))\big)$ |
| Softplus | $\log(1 + e^z)$ | $\mathrm{sigmoid}(z)$ |
E.4 Initialisation Standards
| Method | Distribution | Variance | When |
|---|---|---|---|
| Xavier uniform | $U\!\left(\pm\sqrt{6/(n_{\text{in}} + n_{\text{out}})}\right)$ | $2/(n_{\text{in}} + n_{\text{out}})$ | Sigmoid, tanh |
| Xavier normal | $\mathcal{N}\!\left(0,\ 2/(n_{\text{in}} + n_{\text{out}})\right)$ | $2/(n_{\text{in}} + n_{\text{out}})$ | Sigmoid, tanh |
| He uniform | $U\!\left(\pm\sqrt{6/n_{\text{in}}}\right)$ | $2/n_{\text{in}}$ | ReLU |
| He normal | $\mathcal{N}(0,\ 2/n_{\text{in}})$ | $2/n_{\text{in}}$ | ReLU |
| GPT-2 residual | Standard init with residual projections scaled by $1/\sqrt{N}$ ($N$ = number of residual layers) | - | Transformer residuals |
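A small sketch of applying these schemes with PyTorch's built-in initialisers (layer sizes are arbitrary placeholders):

```python
import torch
import torch.nn as nn

lin_tanh = nn.Linear(256, 256)
nn.init.xavier_uniform_(lin_tanh.weight)      # Var = 2/(n_in + n_out): for sigmoid/tanh layers
nn.init.zeros_(lin_tanh.bias)

lin_relu = nn.Linear(256, 256)
nn.init.kaiming_normal_(lin_relu.weight, nonlinearity='relu')  # Var = 2/n_in: for ReLU layers
nn.init.zeros_(lin_relu.bias)
```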
Appendix F: Deep Dive - Vanishing Gradients in Transformers
F.1 Why Transformers Don't Vanish
A naive reading of the vanishing gradient analysis (6.1) suggests that 96-layer transformers should suffer catastrophic vanishing. They don't. Here is why.
The residual stream analysis: In a pre-norm transformer, the residual stream after layer $l$ is
$$x_{l+1} = x_l + F_l(x_l),$$
where $F_l$ is the $l$-th sublayer (attention or MLP, wrapped in LayerNorm).
The gradient of the loss with respect to the layer-$l$ input is:
$$\frac{\partial L}{\partial x_l} = \frac{\partial L}{\partial x_L} \prod_{k=l}^{L-1}\left(I + \frac{\partial F_k}{\partial x_k}\right)$$
At initialisation, the transformer weights are small, so $\partial F_k / \partial x_k \approx 0$ and the product $\approx I$. The gradient flows back unchanged through all layers. This is categorically different from a plain deep network, where the product of small Jacobians vanishes.
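The identity-path effect is easy to see on a toy stack. The sketch below (hypothetical depth, width, and sigmoid sublayers - not a transformer) compares the gradient norm reaching the input with and without residual connections at initialisation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 20, 64
layers = [nn.Linear(width, width) for _ in range(depth)]

def input_grad_norm(use_residual):
    x = torch.randn(width, requires_grad=True)
    h = x
    for layer in layers:
        out = torch.sigmoid(layer(h))
        h = h + out if use_residual else out   # residual stream vs plain stack
    h.sum().backward()
    return x.grad.norm().item()

print("plain    :", input_grad_norm(False))   # tiny - a product of small Jacobians
print("residual :", input_grad_norm(True))    # O(1) - the identity path carries the gradient
```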
Gradient norm growth: As training progresses and the weights grow, $\partial F_l / \partial x_l$ becomes nontrivial. The gradient norm may grow with depth, but this is controlled by:
- LayerNorm dampening (see 6.6)
- GPT-2's $1/\sqrt{N}$ scaling of residual projections
- Gradient clipping (typically at global norm 1.0)
The "edge of stability" phenomenon (Cohen et al., 2022): In practice, the maximum Hessian eigenvalue often approaches $2/\eta$ (twice the inverse learning rate) and oscillates there. This is a regime where the training dynamics are neither fully stable nor unstable: gradients are large enough to cause oscillation but not divergence.
F.2 Gradient Norm as Training Signal
Modern LLM training monitors gradient norm at every step. Typical patterns:
[Sketch: gradient norm $\|\nabla_\theta L\|_2$ vs. training steps - mostly flat below the clip threshold during normal training, with occasional spikes coinciding with loss spikes.]
Patterns:
- Steady $\|\nabla_\theta L\| < 1$: healthy training, clipping inactive
- Sudden spike → loss spike → recovery: numerical event (often a "bad" batch; LLM training has roughly 1-3 such events per trillion tokens at scale)
- Slow upward drift: learning rate may be too high
Loss spike mitigation: When the gradient norm exceeds the clip threshold, the entire gradient update is scaled down. If the spike is from a corrupted batch, this prevents permanent damage to the model weights.
Gradient accumulation and norm: When using accumulation steps, each micro-batch contributes of the gradient. The global norm is computed across the accumulated gradient (after summation, before the optimiser step) - not across individual micro-batches.
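A minimal sketch of that ordering in PyTorch - accumulate over micro-batches, clip the global norm of the summed gradient, then step (the toy model, data, and hyperparameters are placeholders):

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(16, 1)                                     # toy stand-in for an LLM
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accum_steps, clip_threshold = 8, 1.0

optimizer.zero_grad()
for _ in range(accum_steps):
    x, y = torch.randn(4, 16), torch.randn(4, 1)             # one micro-batch
    loss = nn.functional.mse_loss(model(x), y) / accum_steps # each micro-batch contributes 1/K
    loss.backward()                                          # gradients sum across micro-batches

# Norm is computed over the accumulated gradient, before the optimiser step
clip_grad_norm_(model.parameters(), clip_threshold)
optimizer.step()
```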
F.3 Per-Layer Gradient Norm Analysis
For diagnostic purposes, logging the gradient norm per layer reveals:
- Embedding gradients: Often the largest, due to sparse updates (5.6)
- Early layers: Smallest (furthest from loss); potential vanishing
- Late layers: Largest; potential exploding
- LayerNorm parameters: Very small - $\gamma$ and $\beta$ converge quickly
This per-layer analysis guided the design of:
- LARS/LAMB optimisers (You et al., 2017): layer-wise adaptive learning rates based on weight-to-gradient ratio
- Muon (2024): orthogonalises the momentum update via a Newton-Schulz iteration; designed for hidden layers while AdamW handles the embedding and output layers
Appendix G: Historical Development
G.1 Timeline of Backpropagation
The development of backpropagation spans three centuries and multiple independent discoveries:
| Year | Event | Significance |
|---|---|---|
| 1676 | Leibniz publishes differential calculus (chain rule for single variable) | Mathematical foundation |
| 1744 | Euler uses variational methods (antecedent of reverse mode) | First "adjoint" idea |
| 1847 | Cauchy introduces gradient descent | The algorithm backprop serves |
| 1960 | Kalman filter (reverse-mode for linear dynamical systems) | AD in engineering |
| 1964 | Wengert introduces forward-mode AD via the "Wengert list" | First explicit AD |
| 1970 | Linnainmaa's thesis: reverse-mode AD (general backpropagation) | Full theoretical framework |
| 1974 | Werbos PhD thesis: backprop for neural networks | Connection to ML |
| 1982 | Hopfield networks (energy-based models with gradient) | Alternative to backprop |
| 1986 | Rumelhart, Hinton & Williams - "Learning representations by back-propagating errors" | Popularised backprop for NNs |
| 1991 | Hochreiter: vanishing gradient problem analysed | Identified depth barrier |
| 1997 | LSTM: gating to address vanishing gradient in RNNs | First scalable deep sequence model |
| 2012 | AlexNet: backprop on GPU at scale | Practical deep learning |
| 2015 | ResNets: residual connections for gradient flow | Enabled 100+ layer networks |
| 2016 | PyTorch / TensorFlow 1.0: autodiff frameworks | Democratised backprop |
| 2017 | Transformers: attention replaces BPTT | Solved long-range vanishing |
| 2018 | JAX: functional autodiff, JIT compilation | Research-grade AD |
| 2022 | FlashAttention: IO-aware backward pass | Efficient attention backward |
| 2022 | PyTorch 2.0 torch.compile | Graph-based kernel fusion |
| 2023 | FlashAttention-2: improved GPU utilisation | Standard for production |
| 2024 | FlashAttention-3: H100-optimised with async | State-of-art attention backward |
G.2 The Independent Discoveries
Backpropagation was independently discovered at least four times before becoming widely known:
- Linnainmaa (1970): In his master's thesis, presented the general algorithm for computing exact partial derivatives of any function composed of elementary operations - precisely what we today call reverse-mode AD.
- Werbos (1974): Applied the same idea to multi-layer neural networks in his PhD thesis, but the work was largely ignored for over a decade.
- Parker (1985): Independently rediscovered backpropagation for neural networks.
- Rumelhart, Hinton & Williams (1986): Published the algorithm in Nature and produced the critical experimental demonstrations that convinced the community it could work. Their paper is the one most often cited today.
This pattern of independent rediscovery is common in mathematics - the ideas are "in the air" once the prerequisites are established. The chain rule (1676) + computation graphs (1960s) + gradient descent (1847) = backpropagation (inevitable).
G.3 The Hardware-Algorithm Co-evolution
The practical impact of backpropagation depends critically on hardware:
- CPU era (1986-2011): Backprop is theoretically valid but computationally slow. Networks with more than 3-4 layers were impractical.
- GPU era (2012-present): NVIDIA's CUDA (2007) enables massively parallel GEMM operations. The bottleneck shifts from FLOPS to memory bandwidth.
- Tensor core era (2017-present): NVIDIA Volta/Ampere/Hopper GPUs have dedicated matrix multiply accelerators. FP16/BF16 tensor cores achieve 10x the throughput of FP32.
- Memory wall: As models scale, the backward pass's memory requirements dominate. FlashAttention, ZeRO, gradient checkpointing all address the memory wall.
The FLOP-to-bandwidth ratio of 2024-era H100 GPUs (on the order of $10^3$ TFLOPS of matrix compute against a few TB/s of HBM bandwidth) means that memory access, not computation, is the primary bottleneck for backprop at scale. This fundamental constraint is why FlashAttention's IO-aware design is so impactful.
Appendix H: Connections to Optimisation and Learning Theory
H.1 What the Gradient Tells Us
The gradient computed by backpropagation is the direction of steepest ascent in parameter space (by the first-order Taylor expansion). Gradient descent moves in the opposite direction: $\theta \leftarrow \theta - \eta\, \nabla_\theta L(\theta)$.
What the gradient does NOT tell us:
- The curvature of the loss landscape (need Hessian for that)
- The optimal step size
- Whether we are near a local minimum, saddle point, or maximum
- Whether the gradient is statistically well-estimated (needs large enough batch)
What the gradient DOES tell us:
- The direction of maximal increase (used negated for descent)
- The sensitivity of the loss to each parameter
- Which parameters are "active" (nonzero gradient) vs. saturated (near-zero gradient)
H.2 Gradient Stochasticity
In practice, the true gradient over the full data distribution is approximated by the stochastic gradient over a mini-batch of size $B$:
$$\hat{g} = \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta \ell(x_i;\, \theta)$$
This is an unbiased estimator: $\mathbb{E}[\hat{g}] = \nabla_\theta L(\theta)$.
Variance: $\mathrm{Var}(\hat{g}) \propto 1/B$. Larger batches have lower gradient variance (a more accurate gradient estimate) but provide diminishing returns beyond the "critical batch size" (McCandlish et al., 2018).
For LLMs: The critical batch size for GPT-3-scale models is on the order of a few million tokens. Training at this batch size achieves the best loss-per-FLOP tradeoff. Using larger batches wastes compute; using smaller batches wastes gradient estimation quality.
H.3 The Gradient as a Sufficient Statistic
For first-order optimisers (SGD, Adam, AdaGrad, RMSprop), the gradient is the only information extracted from the forward-backward pass. Second-order information (Hessian curvature) is either ignored or approximated.
Why not use the full Hessian? For $n$ parameters, the Hessian is an $n \times n$ matrix - $n^2$ entries. Storing it is impossible: for GPT-3's $n \approx 1.75 \times 10^{11}$, that is roughly $3 \times 10^{22}$ FP32 values $\approx 10^{23}$ bytes $\approx 10^{14}$ GB. Inverting it is even more impossible.
Practical second-order methods use approximations:
- Diagonal: AdaGrad/Adam maintain diagonal curvature approximations ($O(n)$ memory)
- Kronecker factored: K-FAC (see 02) uses a Kronecker factorisation per layer ($O(n_{\text{in}}^2 + n_{\text{out}}^2)$ per layer)
- Low-rank: PSGD, Shampoo maintain low-rank or block-diagonal approximations
- Newton-Schulz: Muon (2024) approximates the matrix square root efficiently
H.4 Generalisation and the Implicit Gradient Bias
Gradient descent with a small learning rate and small (noisy) mini-batches does not merely find any minimum - it has an implicit bias toward flat minima (large regions with low loss) over sharp minima (narrow valleys).
Conjecture (Keskar et al., 2017): Flat minima generalise better because small perturbations to the parameters don't change the loss much - robust to noise in the data.
Mathematical foundation: The SGD noise effectively adds a regularisation term proportional to $\mathrm{tr}(H)$ - the trace of the Hessian - biasing toward flat (low-trace-Hessian) minima.
This connects gradient computation (the topic of this section) to generalisation theory (a major open question in deep learning theory) - a reminder that understanding backpropagation fully requires understanding not just the mechanics, but the geometry of the loss landscape it navigates.
Appendix I: Practical Implementation Guide
I.1 Implementing Backprop from Scratch
When building a neural network framework from scratch, implement these components in order:
1. Primitive registry:
primitives = {}
def register_primitive(name, forward_fn, backward_fn):
"""Register a primitive op with its VJP."""
primitives[name] = (forward_fn, backward_fn)
# Example: multiplication primitive
def mul_forward(x, y): return x * y
def mul_backward(x, y, g_out): return g_out * y, g_out * x # (g_x, g_y)
register_primitive('mul', mul_forward, mul_backward)
2. Value class with gradient tracking:
class Value:
def __init__(self, data, parents=(), op=''):
self.data = data
self.grad = 0.0
self._backward = lambda: None # closure capturing parents
self._parents = parents
self._op = op
def __mul__(self, other):
out = Value(self.data * other.data, (self, other), 'mul')
def _backward():
self.grad += other.data * out.grad # VJP for self
other.grad += self.data * out.grad # VJP for other
out._backward = _backward
return out
def backward(self):
# Topological sort, then reverse
topo = []
visited = set()
def build_topo(v):
if v not in visited:
visited.add(v)
for p in v._parents: build_topo(p)
topo.append(v)
build_topo(self)
self.grad = 1.0
for v in reversed(topo): v._backward()
This is essentially the complete autograd engine from Karpathy's micrograd (2020) - approximately 100 lines implement a working backprop engine.
3. Building blocks: Extend Value with __add__, __pow__, exp, log, relu, softmax - each with its VJP closure.
I.2 Common Implementation Bugs
Bug 1: Overwriting instead of accumulating gradients
# Wrong:
self.grad = other.data * out.grad # erases previous contributions!
# Correct:
self.grad += other.data * out.grad # accumulates (fan-out nodes)
Bug 2: Forgetting to zero gradients between batches
# Wrong: gradient accumulates across batches
loss = model(x)
loss.backward()
optimizer.step()
# Correct:
optimizer.zero_grad() # <- must come before backward
loss = model(x)
loss.backward()
optimizer.step()
Bug 3: Not detaching from the graph for inference
# Wrong: builds the autograd graph unnecessarily during inference
prediction = model(x)
# Correct: disable graph construction
with torch.no_grad():
    prediction = model(x)
Bug 4: Shape mismatch in weight gradient
# Wrong: grad_W and W may have different shapes
grad_W = delta @ x # (n_out, 1) @ (1, n_in) only works for batch size 1
# Correct: outer product for single sample
grad_W = np.outer(delta, x) # (n_out, n_in)
# Correct: batched
grad_W = (1/B) * Delta @ X.T # (n_out, B) @ (B, n_in) = (n_out, n_in)
I.3 Testing Checklist
Before deploying any backprop implementation:
- Gradient check passes for all primitive operations (relative error )
- Loss decreases monotonically for small enough learning rate (verify on toy problem)
- Gradients are zero for frozen parameters
- Gradient accumulation at fan-out nodes verified (shared weight receives sum)
- Shape of each gradient matches shape of corresponding parameter
- Memory usage is not
- Higher-order gradients work if needed (use create_graph=True in PyTorch)
- Mixed precision: FP16 forward, FP32 gradient accumulation, loss scaling in place
Appendix J: Connections to Information Theory and Statistics
J.1 Fisher Information and the Natural Gradient
The ordinary gradient measures the steepest direction in parameter space with respect to the Euclidean metric. But parameter space has a natural metric induced by the probability distribution - the Fisher information metric.
Fisher information matrix:
$$F(\theta) = \mathbb{E}_{x \sim p_\theta}\!\left[\nabla_\theta \log p_\theta(x)\, \nabla_\theta \log p_\theta(x)^\top\right]$$
Natural gradient (Amari, 1998):
$$\tilde{\nabla}_\theta L = F(\theta)^{-1} \nabla_\theta L$$
The natural gradient is the steepest direction in the distributional geometry of the model - invariant to reparametrisation. Computing it exactly requires inverting $F$, which costs $O(n^3)$.
K-FAC (02) approximates as a Kronecker product, making the natural gradient step tractable. It remains the most principled second-order optimiser for neural networks.
For LLMs: The approximation used in practice is Adam's diagonal (second moment of gradient as proxy for diagonal Fisher). This is crude but sufficient - Adam is a diagonal natural gradient step.
J.2 Gradient as Score Function
For a probabilistic model $p_\theta(x)$, the gradient of the log-likelihood is the score function:
$$s(\theta) = \nabla_\theta \log p_\theta(x)$$
The score function is the quantity computed by backpropagation during maximum likelihood estimation. Properties:
- $\mathbb{E}_{x \sim p_\theta}[s(\theta)] = 0$ (the score has zero mean)
- $F(\theta) = \mathrm{Var}(s(\theta)) = \mathbb{E}[s\, s^\top]$ (Fisher information = variance of the score)
For language models: The negative log-likelihood has gradient $p - y$ with respect to the logits - the same formula from 5.3, now understood as (minus) the score.
J.3 KL Divergence and the Gradient of ELBO
In variational inference and RL (RLHF), we often need gradients of KL divergences. For discrete distributions:
$$D_{\mathrm{KL}}(p \,\|\, q) = \sum_x p(x) \log \frac{p(x)}{q(x)}$$
This is computed via backprop through the log-probabilities of the policy under KL regularisation - the precise form used in RLHF's PPO loss, which includes a KL penalty between the fine-tuned policy $\pi_\theta$ and the frozen reference model $\pi_{\text{ref}}$.
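A minimal sketch of backprop through such a KL penalty; the logits, distributions, and dimensions here are arbitrary placeholders, not the actual RLHF objective:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)   # trainable policy logits (stand-in for pi_theta)
ref_logits = torch.randn(5)                   # frozen reference logits (stand-in for pi_ref)

log_p = F.log_softmax(logits, dim=-1)
log_q = F.log_softmax(ref_logits, dim=-1)
kl = torch.sum(log_p.exp() * (log_p - log_q)) # KL(p_theta || p_ref)
kl.backward()                                  # backprop gives dKL/dlogits
print(logits.grad)
```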
References
- Rumelhart, D. E., Hinton, G. E. & Williams, R. J. (1986). "Learning representations by back-propagating errors." Nature, 323, 533-536. The canonical backpropagation paper.
- Linnainmaa, S. (1970). "The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors." Master's thesis, University of Helsinki. First general reverse-mode AD.
- Hochreiter, S. (1991). "Untersuchungen zu dynamischen neuronalen Netzen." Diploma thesis, TU Munich. First analysis of vanishing gradients.
- Glorot, X. & Bengio, Y. (2010). "Understanding the difficulty of training deep feedforward neural networks." AISTATS. Xavier initialisation.
- He, K. et al. (2015). "Delving Deep into Rectifiers." ICCV. He initialisation for ReLU networks.
- He, K. et al. (2016). "Deep Residual Learning for Image Recognition." CVPR. ResNets and gradient highways.
- Ba, J. et al. (2016). "Layer Normalization." arXiv:1607.06450. LayerNorm for transformers.
- Vaswani, A. et al. (2017). "Attention Is All You Need." NeurIPS. Transformer architecture with attention backward pass.
- Amari, S. (1998). "Natural Gradient Works Efficiently in Learning." Neural Computation. Natural gradient and Fisher information.
- Martens, J. & Grosse, R. (2015). "Optimizing Neural Networks with Kronecker-factored Approximate Curvature." ICML. K-FAC.
- Hu, E. et al. (2022). "LoRA: Low-Rank Adaptation of Large Language Models." ICLR. LoRA backward pass.
- Dao, T. et al. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS. IO-aware backward for attention.
- Cohen, J. et al. (2022). "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability." ICLR. Edge of stability phenomenon.
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. FlashAttention-2.
- Liu, S. et al. (2024). "DoRA: Weight-Decomposed Low-Rank Adaptation." DoRA backward analysis.
Appendix K: Summary Tables
K.1 Backpropagation Algorithm Summary
COMPLETE BACKPROPAGATION ALGORITHM
INPUT: network weights θ = {W^l, b^l}, training pair (x, y)
PHASE 1 - FORWARD PASS
  a^0 = x
  For l = 1, 2, ..., L:
      z^l = W^l a^{l-1} + b^l          (cache z^l and a^{l-1})
      a^l = σ^l(z^l)                   (cache a^l)
  ŷ = a^L
  L = loss(ŷ, y)
PHASE 2 - BACKWARD PASS
  δ^L = ∂L/∂z^L                        (output layer gradient, layer-specific)
  For l = L-1, L-2, ..., 1:
      δ^l = (W^{l+1})ᵀ δ^{l+1} ⊙ σ'^l(z^l)
PHASE 3 - GRADIENT ASSEMBLY
  For l = 1, 2, ..., L:
      ∇_{W^l} L = δ^l (a^{l-1})ᵀ
      ∇_{b^l} L = δ^l
PHASE 4 - PARAMETER UPDATE
  θ ← θ - η ∇_θ L                      (or Adam/RMSprop update)
K.2 Complexity Summary
| Operation | Time | Memory |
|---|---|---|
| Forward pass (L layers, width n) | $O(Ln^2)$ | $O(Ln)$ cached activations |
| Backward pass | $O(Ln^2)$ | $O(Ln)$ error signals |
| Full Jacobian via finite differences | $O(|\theta|)$ forward passes | $O(Ln)$ |
| Full Jacobian via backprop | $m$ backward passes ($m$ = number of outputs) | $O(Ln)$ |
| Hessian-vector product | $O(Ln^2)$ (two backward-like passes) | $O(Ln)$ |
| Gradient checkpointing | $O(Ln^2)$ plus one extra forward of recomputation | $O(\sqrt{L}\, n)$ |
| FlashAttention forward | $O(N^2 d)$ | $O(N)$ beyond inputs |
| FlashAttention backward | $O(N^2 d)$ (with recomputation) | $O(N)$ beyond inputs |
K.3 Gradient Flow Interventions
| Problem | Diagnosis | Intervention |
|---|---|---|
| Vanishing gradients | $\|\delta^l\|$ shrinks exponentially toward early layers | ReLU/GELU, He init, residual connections |
| Exploding gradients | $\|\delta^l\|$ grows with depth; loss diverges | Gradient clipping, LR warmup |
| Dead neurons | $a^l_i = 0$ for (almost) all inputs in layer $l$ | Leaky ReLU, better init, BN |
| Slow convergence | $\|\nabla_\theta L\| \approx 0$ at a saddle | Momentum, Adam, noise injection |
| Oscillating loss | $\|\nabla_\theta L\|$ spikes | Reduce LR, increase batch |
| NaN gradients | FP16 overflow/underflow, $\log(0)$, or softmax overflow | Loss scaling, check log/softmax |