Chain Rule and Backpropagation: Part 1: Intuition to 10. Common Mistakes
1. Intuition
1.1 From Single-Variable to Multivariate Chain Rule
The single-variable chain rule says: if $y = f(u)$ and $u = g(x)$, then
$$\frac{dy}{dx} = \frac{dy}{du}\,\frac{du}{dx}.$$
The intuition is that rates of change compose multiplicatively. If $f$ triples its input and $g$ doubles its input, then $f \circ g$ multiplies by six.
The multivariate generalisation replaces scalars with vectors and scalar derivatives with Jacobian matrices. If $g: \mathbb{R}^p \to \mathbb{R}^n$ and $f: \mathbb{R}^n \to \mathbb{R}^m$, then
$$J_{f \circ g}(x) = J_f(g(x))\, J_g(x).$$
The product is now matrix multiplication. This is not a different rule - it is the same rule, stated in the correct language for vector-valued functions. The single-variable rule is the special case where Jacobians degenerate to scalars.
SCALAR CHAIN RULE vs JACOBIAN CHAIN RULE

Scalar:  x --g--> u --f--> y        (R -> R -> R)
         dy/dx = (dy/du)(du/dx)     [scalar multiplication]

Vector:  x --g--> u --f--> y        (R^p -> R^n -> R^m)
         J_{f∘g} = J_f * J_g        [matrix multiplication]
         (mxp)   = (mxn) * (nxp)

The dimensions work out exactly like matrix multiplication.
The chain rule IS matrix multiplication for Jacobians.
What makes the multivariate version non-trivial is that $J_f$ must be evaluated at $u = g(x)$ - the output of the inner function - not at $x$ itself. This point-dependence is where the local linear approximation lives: $J_g(x)$ is the best linear approximation to $g$ at the specific point $x$, and $J_f(g(x))$ is the best linear approximation to $f$ at $g(x)$.
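This composition-and-evaluation-point structure can be checked numerically. The sketch below (NumPy, with arbitrary illustrative choices of $f$ and $g$) builds both Jacobians by finite differences and confirms that their matrix product, with $J_f$ evaluated at $g(x)$, matches the Jacobian of the composition:

```python
import numpy as np

# Illustrative maps: g: R^2 -> R^3, f: R^3 -> R^2 (chosen only for this demo)
def g(x):
    return np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2])

def f(u):
    return np.array([u[0] + u[1], u[1] * u[2]])

def jacobian_fd(func, x, eps=1e-6):
    """Finite-difference Jacobian: one column per input dimension."""
    y0 = func(x)
    J = np.zeros((y0.size, x.size))
    for j in range(x.size):
        xp = x.copy()
        xp[j] += eps
        J[:, j] = (func(xp) - y0) / eps
    return J

x = np.array([0.7, -1.2])
J_g = jacobian_fd(g, x)          # (3, 2), evaluated at x
J_f = jacobian_fd(f, g(x))       # (2, 3), evaluated at g(x) -- the inner output!
J_chain = J_f @ J_g              # (2, 2): chain rule as matrix multiplication
J_direct = jacobian_fd(lambda t: f(g(t)), x)
print(np.max(np.abs(J_chain - J_direct)))  # tiny: the two agree
```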
1.2 Backpropagation as Iterated Chain Rule
A deep neural network is a long composition of functions:
$$L = \ell\big(f_L(f_{L-1}(\cdots f_1(x)\cdots))\big),$$
where each $f_l$ is a layer (linear + activation), $\ell$ is the loss function, and $L$ is a scalar. Computing $\nabla_{\theta_l} L$ - the gradient of the loss with respect to layer $l$'s parameters - requires applying the chain rule through every layer from $L$ down to $l$.
The chain rule gives:
$$\nabla_{W^{(l)}} L = \delta^{(l)}\,\big(a^{(l-1)}\big)^\top,$$
where $\delta^{(l)} = \partial L/\partial z^{(l)}$ is the error signal at layer $l$, and it satisfies the backpropagation recurrence:
$$\delta^{(l)} = \big(W^{(l+1)}\big)^\top \delta^{(l+1)} \odot \sigma'(z^{(l)}).$$
This recurrence propagates the error signal backward from layer $L$ to layer 1 - hence "backpropagation." At each step, we multiply by the transposed Jacobian of the next layer. The entire algorithm is:
- Forward pass: compute and store $z^{(l)}$, $a^{(l)}$ for $l = 1, \dots, L$
- Initialise: $\delta^{(L)} = \nabla_{z^{(L)}} L$
- Backward pass: compute $\delta^{(l)}$ for $l = L-1, \dots, 1$ using the recurrence
- Gradients: extract $\nabla_{W^{(l)}} L = \delta^{(l)} (a^{(l-1)})^\top$ and $\nabla_{b^{(l)}} L = \delta^{(l)}$
Backpropagation is not a fundamentally different concept from the chain rule. It is the chain rule, applied efficiently by sharing intermediate computations (the error signals $\delta^{(l)}$) across all parameters in a layer.
1.3 Historical Context
| Year | Contributor | Development |
|---|---|---|
| 1676 | Leibniz | Differential calculus; first statement of the single-variable chain rule |
| 1755 | Euler | Extended to multiple variables |
| 1960 | Kelley | Gradient computation for optimal control (independent discovery of backprop concept) |
| 1970 | Linnainmaa | First complete description of reverse-mode automatic differentiation for computing gradients |
| 1974 | Werbos | First application to neural networks in his PhD thesis |
| 1986 | Rumelhart, Hinton, Williams | Popularised backpropagation in "Learning representations by back-propagating errors" - the paper that launched the neural network revolution |
| 1989 | LeCun | Applied backprop to convolutional networks for handwritten digit recognition |
| 2012 | Krizhevsky, Sutskever, Hinton | AlexNet demonstrated GPU-accelerated backprop at scale - kicked off the deep learning era |
| 2015-2016 | Google Brain, Facebook AI | TensorFlow (2015) and PyTorch (2016): automatic differentiation engines that compute backprop automatically |
| 2017 | Vaswani et al. | Transformer: backprop through multi-head attention; the architecture underlying GPT, BERT, LLaMA |
| 2021 | Hu et al. (LoRA) | Parameter-efficient fine-tuning by limiting gradient flow to low-rank subspaces |
| 2022 | Dao et al. (FlashAttention) | Recompute activations during backward to avoid materialising the attention matrix |
1.4 Why Backprop Defines Modern AI
Every large language model, image classifier, and diffusion model trained today relies on backpropagation for every gradient update. The scale is staggering: training GPT-4 reportedly required on the order of $10^{25}$ floating-point operations, the vast majority of which are forward and backward passes through the transformer network.
Backprop enables gradient-based learning at scale because its cost is proportional to the cost of the forward pass - typically $O(P)$ where $P$ is the number of parameters. Alternative approaches (finite differences, evolution strategies, zeroth-order methods) are orders of magnitude more expensive.
Three properties make backprop indispensable:
- Efficiency: One backward pass computes $\partial L/\partial \theta_i$ for all $P$ parameters simultaneously. Finite differences would need $P + 1$ forward passes.
- Exactness: Unlike finite differences, backprop computes the exact gradient (up to floating-point precision), not an approximation.
- Composability: Any differentiable function composed of differentiable primitives has an automatically computable gradient. This is why PyTorch/JAX can differentiate arbitrary Python code that uses differentiable operations.
For AI in 2026: The gradient is the workhorse of every training algorithm: SGD, Adam, AdaGrad, Muon, SOAP - all are gradient-based. Fine-tuning (LoRA, QLoRA, DoRA), RLHF (PPO, DPO, GRPO), distillation, and continual learning all depend on backprop. Even methods that appear gradient-free (evolutionary strategies, black-box optimisation) are often used because they approximate the gradient in settings where backprop is unavailable (non-differentiable objectives, external APIs).
2. The Multivariate Chain Rule - Full Theory
2.1 The General Chain Rule - Proof
We prove the chain rule using the Fréchet derivative from 02. Recall:
Definition. $f: \mathbb{R}^n \to \mathbb{R}^m$ is Fréchet differentiable at $x$ if there exists a linear map $Df(x): \mathbb{R}^n \to \mathbb{R}^m$ such that
$$\lim_{h \to 0} \frac{\|f(x+h) - f(x) - Df(x)\,h\|}{\|h\|} = 0.$$
The matrix of $Df(x)$ is the Jacobian $J_f(x)$.
Theorem (Chain Rule). Let $g: \mathbb{R}^p \to \mathbb{R}^n$ be Fréchet differentiable at $x$, and $f: \mathbb{R}^n \to \mathbb{R}^m$ be Fréchet differentiable at $u = g(x)$. Then $f \circ g$ is Fréchet differentiable at $x$ and
$$J_{f \circ g}(x) = J_f(g(x))\, J_g(x).$$
Proof. Let $A = Df(u)$ and $B = Dg(x)$. We need to show that $AB$ is the Fréchet derivative of $f \circ g$ at $x$. Write:
$$(f \circ g)(x+h) - (f \circ g)(x) = f(g(x+h)) - f(u).$$
Let $k(h) = g(x+h) - g(x)$. Since $g$ is Fréchet differentiable:
$$k(h) = Bh + r_g(h), \qquad \|r_g(h)\| = o(\|h\|).$$
Now apply Fréchet differentiability of $f$ at $u$:
$$f(u + k) - f(u) = Ak + r_f(k), \qquad \|r_f(k)\| = o(\|k\|).$$
Substituting $k = k(h)$:
$$f(g(x+h)) - f(u) = ABh + A\,r_g(h) + r_f(k(h)).$$
We show the remainder is $o(\|h\|)$:
- First term: $\|A\,r_g(h)\| \le \|A\|\,\|r_g(h)\| = o(\|h\|)$.
- Second term: Since $\|k(h)\| \le (\|B\| + 1)\|h\|$ for small $h$, we have $\|r_f(k(h))\| = o(\|k(h)\|) = o(\|h\|)$.
Therefore $D(f \circ g)(x) = AB$, i.e. $J_{f \circ g}(x) = J_f(g(x))\, J_g(x)$. $\blacksquare$
When the chain rule fails. The chain rule requires $g$ at $x$ and $f$ at $g(x)$ to both be Fréchet differentiable. If either fails - for example at a ReLU kink, where $\max(0, x)$ is not differentiable at $x = 0$ - the classical chain rule does not apply. In practice, these measure-zero sets are handled by choosing a subgradient (any element of the Clarke subdifferential), which is what deep learning frameworks do automatically.
2.2 Three Cases in Increasing Generality
Case 1: Scalar composition. $h = f \circ g$ where $f, g: \mathbb{R} \to \mathbb{R}$. Jacobians are $1 \times 1$ = scalars, so $h'(x) = f'(g(x))\,g'(x)$.
Case 2: Scalar loss of a vector function. $L = f(g(x))$, where $g: \mathbb{R}^n \to \mathbb{R}^m$ and $f: \mathbb{R}^m \to \mathbb{R}$. Jacobians: $J_g \in \mathbb{R}^{m \times n}$ and $J_f = (\nabla f)^\top \in \mathbb{R}^{1 \times m}$ (a row vector). So:
$$J_L(x) = \big(\nabla f(g(x))\big)^\top J_g(x).$$
Taking the transpose: $\nabla_x L = J_g(x)^\top\, \nabla f(g(x))$ - the gradient of $L$ with respect to $x$ is the transposed Jacobian of $g$ times the gradient of $f$. This is the VJP equation, the core of backprop.
Case 3: Vector composition. $f \circ g$ with $g: \mathbb{R}^p \to \mathbb{R}^n$ and $f: \mathbb{R}^n \to \mathbb{R}^m$. The most general case; Jacobians are full matrices and the chain rule is full matrix multiplication:
$$J_{f \circ g}(x) = J_f(g(x))\, J_g(x).$$
The dimensions verify: $(m \times p) = (m \times n)(n \times p)$. The "inner dimension" $n$ (the dimension of the intermediate space $\mathbb{R}^n$) cancels in the product, exactly as in matrix multiplication.
2.3 The VJP Form - Foundation of Backprop
Definition (VJP). For $f: \mathbb{R}^n \to \mathbb{R}^m$ and a "cotangent" vector $v \in \mathbb{R}^m$, the vector-Jacobian product is:
$$\mathrm{vjp}_f(x, v) = J_f(x)^\top v \;\in\; \mathbb{R}^n.$$
Why this is the right primitive for backprop. For a scalar loss $\ell$ composed with $f$:
$$\nabla_x\, \ell(f(x)) = J_f(x)^\top\, \nabla \ell(f(x)).$$
The gradient of the composed function with respect to the input is the VJP of the inner function, with the cotangent being the gradient of the outer function.
The backprop recursion is a chain of VJPs. For $L = \ell(f_L(\cdots f_1(x)\cdots))$:
$$\nabla_x L = J_{f_1}^\top\, J_{f_2}^\top \cdots J_{f_L}^\top\, \nabla \ell.$$
Starting from $v = \nabla \ell$ and applying VJPs from right to left computes all intermediate gradients.
Cost comparison. Computing the gradient of $f: \mathbb{R}^n \to \mathbb{R}$:
- JVP (forward mode): requires $n$ passes (one per input dimension). Cost: $O(n \cdot \mathrm{cost}(f))$.
- VJP (reverse mode): requires 1 pass. Cost: $O(\mathrm{cost}(f))$.
For $n$ parameters, reverse mode is a factor of $n$ cheaper. This asymmetry is why all gradient-based deep learning uses reverse mode (backprop).
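The asymmetry is easy to see concretely: reverse mode carries a single vector through the chain (one matrix-vector product per layer, $O(d^2)$ each), while materialising the full Jacobian product costs a matrix-matrix product per layer, $O(d^3)$ each. A NumPy sketch with random stand-in Jacobians:

```python
import numpy as np

rng = np.random.default_rng(0)
d, L = 64, 10
Js = [rng.standard_normal((d, d)) for _ in range(L)]  # stand-in Jacobians J_1..J_L
v = rng.standard_normal(d)                            # cotangent: grad of scalar loss

# Reverse mode: carry a VECTOR right-to-left; one (d,d)@(d,) product per layer
u = v
for J in reversed(Js):        # applies J_L^T first, J_1^T last
    u = J.T @ u

# Naive: materialise the full Jacobian product first (matrix-matrix per layer)
P = np.eye(d)
for J in Js:                  # after the loop: P = J_L @ ... @ J_1
    P = J @ P
grad_full = P.T @ v

print(np.allclose(u, grad_full, rtol=1e-6))  # same gradient, very different cost
```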
2.4 Long Chains and Telescoping Products
For a depth-$L$ network $f_L \circ \cdots \circ f_1$, the Jacobian of the output with respect to the input is:
$$J = J_{f_L}\, J_{f_{L-1}} \cdots J_{f_1}.$$
This is a product of $L$ matrices. The spectral norm of the product satisfies:
$$\|J\|_2 \le \prod_{l=1}^{L} \|J_{f_l}\|_2.$$
If each $\|J_{f_l}\|_2 \approx \rho$, then $\|J\|_2 \lesssim \rho^L$. For $\rho < 1$, the gradient vanishes exponentially; for $\rho > 1$, it explodes. This is the mathematical source of the vanishing/exploding gradient problem (6).
Efficient computation: reverse order. In the forward direction, we compute $f_1, f_2, \dots, f_L$ left to right. In the backward direction, we compute the error signals $\delta^{(L)}, \delta^{(L-1)}, \dots$ right to left, reusing stored activations. The key observation: at step $l$, we only need $\delta^{(l+1)}$ and the stored activation $a^{(l)}$ (or $z^{(l)}$) - we do not need to recompute anything from scratch.
2.5 Differentiating Through Discrete Operations
Some operations in neural networks are discontinuous or discrete: argmax (in beam search), rounding/quantisation (in QAT), sampling (in VAEs and RL). The chain rule does not directly apply.
Straight-Through Estimator (STE). For a quantisation function $q(x) = \mathrm{round}(x)$ (round to nearest integer), the derivative is $q'(x) = 0$ almost everywhere, giving zero gradient. The STE replaces the "true" zero gradient with 1 during the backward pass:
$$\frac{\partial q}{\partial x} := 1.$$
In code: y = round(x).detach() + x - x.detach() - the bare x term cancels in the forward pass (its value is removed by x.detach()) but contributes its identity gradient in the backward pass. STE is used in VQ-VAE, binary neural networks, and quantisation-aware training.
REINFORCE (score function estimator). For a stochastic node $z \sim p_\theta(z)$ and loss $L(z)$, the gradient of $\mathbb{E}_{z \sim p_\theta}[L(z)]$ with respect to $\theta$ is:
$$\nabla_\theta\, \mathbb{E}_{z \sim p_\theta}[L(z)] = \mathbb{E}_{z \sim p_\theta}\big[L(z)\, \nabla_\theta \log p_\theta(z)\big].$$
This allows gradient estimation without differentiating through the sampling step. Used in RLHF (PPO, GRPO) and variational inference. High variance; mitigated by baselines.
3. Computation Graphs
3.1 Formal DAG Definition
A computation graph $G = (V, E)$ is a directed acyclic graph encoding how scalar or tensor quantities depend on one another.
Nodes partition into three types:
| Type | Symbol | Role |
|---|---|---|
| Input nodes | $x_i$, $\theta_j$ | Hold model inputs and parameters; no incoming edges |
| Intermediate nodes | $v_k$ | Hold computed activations; receive edges from their operands |
| Output node | $v_N$ | Holds the scalar loss $L$; required to be scalar for standard backprop |
Edges encode data dependency: $v = \phi_v(u_1, \dots, u_k)$ for some primitive $\phi_v$. Each edge $(u, v)$ carries an implicit local Jacobian $\partial v / \partial u$.
Primitive operations are the atomic building blocks with known local gradients:
PRIMITIVE OPERATIONS AND THEIR LOCAL GRADIENTS

Operation        Forward                      Local gradient (wrt input)
z = x + y        z = x + y                    ∂z/∂x = 1,  ∂z/∂y = 1
z = x * y        z = xy                       ∂z/∂x = y,  ∂z/∂y = x
z = exp(x)       z = e^x                      ∂z/∂x = e^x
z = log(x)       z = ln x                     ∂z/∂x = 1/x
z = relu(x)      z = max(0, x)                ∂z/∂x = [x > 0]   (a.e.)
z = W x + b      z = Wx + b                   ∂z/∂W: outer product with x,  ∂z/∂x = W
z = softmax(x)   z_i = e^{x_i} / Σ_j e^{x_j}  J = diag(p) - p pᵀ   (see 02)

Every deep learning framework maintains a lookup table of these
primitives together with their VJP implementations.
Topological ordering - a linear ordering of $V$ such that for every edge $(u, v)$, $u$ appears before $v$ in the ordering. A topological order exists iff $G$ is acyclic (Kahn's algorithm, 1962). Both the forward pass and the backward pass respect topological order (the latter in reverse).
For AI: Every modern deep learning framework (PyTorch, JAX, TensorFlow) represents a neural network as a computation graph. PyTorch builds the graph dynamically during the forward pass via the autograd tape; JAX traces the graph statically via XLA compilation.
3.2 Forward Pass - Value Propagation
The forward pass evaluates all node values in topological order, caching intermediates required by the backward pass.
Algorithm (Forward Pass):
Input:  graph G = (V, E), input values {x_1, ..., x_n}
Output: loss value v_N, cache of intermediates

For v in topological_order(G):
    if v is an input node:
        cache[v] = x_v                                  (given)
    else:
        cache[v] = phi_v(cache[u_1], ..., cache[u_k])
        where u_1, ..., u_k = parents(v)
return cache[v_N]
Memory cost of a naive forward pass: Caching all intermediates for backprop costs $O(N)$ memory, where $N$ is the number of nodes. For a transformer with $L$ layers, batch size $B$, sequence length $T$, and hidden dimension $d$, this is approximately:
$$O(L \cdot B \cdot T \cdot d) \text{ activation values.}$$
This is why gradient checkpointing (7) is essential for large models.
What gets cached? A memory-optimal forward pass only caches values that appear in at least one local gradient formula. For a linear layer $z = Wx + b$, the backward needs $x$ (to compute $\nabla_W L$) but not $b$ (already accumulated into the output; its gradient is just the upstream adjoint).
3.3 Backward Pass - Gradient Accumulation
The backward pass evaluates adjoint values for every node, in reverse topological order.
Define the adjoint of node $v$ as:
$$\bar v = \frac{\partial L}{\partial v},$$
where we treat $v$ as a scalar intermediate (extending to tensors componentwise).
Initialisation: $\bar v_N = 1$ (the loss node).
Backward recurrence: For a node $u$ with children (successors) $v_1, \dots, v_k$ - nodes that depend on $u$:
$$\bar u = \sum_{i=1}^{k} \bar v_i\, \frac{\partial v_i}{\partial u}.$$
This is exactly the chain rule applied in reverse.
Algorithm (Backward Pass):
Input:  graph G, cache from forward pass
Output: adjoint[v] = ∂L/∂v for all v in V

adjoint[v_N] <- 1
For v in reverse_topological_order(G):
    For each parent u of v:
        adjoint[u] += adjoint[v] * ∂v/∂u(cache)    # local VJP
return {adjoint[u] : u is a parameter node}
The key observation: each edge $(u, v)$ requires only:
- The cached forward value at $u$ (for the local gradient formula)
- The downstream adjoint $\bar v$ (for the VJP multiplication)
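The two passes can be sketched end to end on a toy graph. The example below (plain Python, hypothetical node names) computes $L = xy + e^x$, caches forward values in topological order, then accumulates adjoints in reverse - note how the fan-out node x receives two summed contributions:

```python
import math

# Toy computation graph (hypothetical layout): a = x*y ; b = exp(x) ; L = a + b
x, y = 0.5, 3.0

# Forward pass in topological order, caching every node value
cache = {"x": x, "y": y}
cache["a"] = cache["x"] * cache["y"]
cache["b"] = math.exp(cache["x"])
cache["L"] = cache["a"] + cache["b"]

# Backward pass in reverse topological order: adj[v] = dL/dv
adj = {v: 0.0 for v in cache}
adj["L"] = 1.0
adj["a"] += adj["L"] * 1.0          # L = a + b  -> dL/da = 1
adj["b"] += adj["L"] * 1.0          # L = a + b  -> dL/db = 1
adj["x"] += adj["a"] * cache["y"]   # a = x*y    -> da/dx = y
adj["y"] += adj["a"] * cache["x"]   # a = x*y    -> da/dy = x
adj["x"] += adj["b"] * cache["b"]   # b = exp(x) -> db/dx = exp(x)

# x fans out into a and b, so its adjoint is the SUM of two contributions
print(adj["x"], adj["y"])  # dL/dx = y + exp(x), dL/dy = x
```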
3.4 Gradient Accumulation at Branching Nodes
A fan-out node $u$ has multiple children $v_1, \dots, v_k$. The correct gradient is the sum of contributions:
$$\bar u = \sum_{i=1}^{k} \bar v_i\, \frac{\partial v_i}{\partial u}.$$
Proof: By the total derivative,
$$\frac{\partial L}{\partial u} = \sum_{i=1}^{k} \frac{\partial L}{\partial v_i}\, \frac{\partial v_i}{\partial u}.$$
Example - residual connection:

RESIDUAL BRANCH: u feeds into both F(u) and the skip path

          u
         / \
        /   \
     F(u)    \    <- identity skip
        \    /
         \  /
     z = F(u) + u

Forward:  z = F(u) + u
Backward: u_bar = z_bar * ∂F(u)/∂u + z_bar * 1
                = z_bar * J_F(u) + z_bar

The identity skip guarantees a gradient highway:
even if J_F(u) ~= 0 (saturated layer), z_bar flows back unchanged.

This is the deep reason residual networks (He et al., 2016) solved the vanishing gradient problem: the skip connection creates a constant identity term in the backward accumulation, preventing exponential decay of the gradient magnitude.
3.5 Dynamic vs Static Graphs
Two design philosophies produce different tradeoffs:
DYNAMIC GRAPHS (PyTorch eager mode) STATIC GRAPHS (JAX jit / TF graph)
Graph built anew each forward pass Graph compiled once, reused
Natural Python control flow XLA/CUDA fusion, kernel merging
Easy debugging (print anywhere) Memory-optimal buffer allocation
Variable-length sequences trivial Can export/serve without Python
Graph construction overhead Tracing must handle all branches
Less compiler optimisation Python side-effects invisible
Examples: PyTorch, early Chainer Examples: JAX jit, TF2 tf.function,
ONNX Runtime, TensorRT
For transformers: Most production LLM training uses torch.compile (PyTorch 2.0+) which bridges the two: eager-mode graph construction with TorchDynamo tracing and inductor backend compilation, recovering ~30-50% throughput from kernel fusion.
4. Backpropagation
4.1 Network Notation
Consider a feedforward neural network with $L$ layers. Define:
| Symbol | Meaning |
|---|---|
| $x = a^{(0)}$ | Input vector, dimension $d_0$ |
| $W^{(l)} \in \mathbb{R}^{d_l \times d_{l-1}}$ | Weight matrix for layer $l$ |
| $b^{(l)} \in \mathbb{R}^{d_l}$ | Bias vector for layer $l$ |
| $z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}$ | Pre-activation (linear combination) |
| $a^{(l)} = \sigma(z^{(l)})$ | Post-activation (elementwise) |
| $\hat y = a^{(L)}$ | Network output |
| $L(\hat y, y)$ | Scalar loss |
The forward pass computes $z^{(l)}$ and $a^{(l)}$ for $l = 1, \dots, L$.
4.2 Forward Equations
$$z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}, \qquad a^{(l)} = \sigma(z^{(l)}), \qquad a^{(0)} = x.$$
Cache for backward: $\{z^{(l)}, a^{(l)}\}_{l=1}^{L}$.
4.3 Output Layer Gradient
For cross-entropy loss with softmax output, the output gradient has the celebrated clean form (derived in 5.3):
$$\delta^{(L)} = \nabla_{z^{(L)}} L = p - y,$$
where $p = \mathrm{softmax}(z^{(L)})$ and $y$ is the one-hot label. This combines the softmax Jacobian with the cross-entropy gradient into a single elegant expression.
For MSE loss ($L = \tfrac{1}{2}\|\hat y - y\|^2$) with linear output:
$$\delta^{(L)} = \hat y - y.$$
(Same form, different derivation - a useful coincidence that makes implementation uniform.)
4.4 Backpropagation Recurrence - Proof
Define the error signal:
$$\delta^{(l)} = \frac{\partial L}{\partial z^{(l)}} \in \mathbb{R}^{d_l}.$$
Theorem (Backpropagation Recurrence):
$$\delta^{(l)} = \big(W^{(l+1)}\big)^\top \delta^{(l+1)} \odot \sigma'(z^{(l)}).$$
Proof: Apply the chain rule from $z^{(l)}$ to $L$ via $a^{(l)}$ and $z^{(l+1)}$:
$$\frac{\partial L}{\partial z^{(l)}_j} = \sum_i \frac{\partial L}{\partial z^{(l+1)}_i}\, \frac{\partial z^{(l+1)}_i}{\partial a^{(l)}_j}\, \frac{\partial a^{(l)}_j}{\partial z^{(l)}_j}.$$
Step 1: $z^{(l+1)} = W^{(l+1)} a^{(l)} + b^{(l+1)}$, so $\partial z^{(l+1)}_i / \partial a^{(l)}_j = W^{(l+1)}_{ij}$.
Step 2: $a^{(l)} = \sigma(z^{(l)})$ elementwise, so $\partial a^{(l)}_j / \partial z^{(l)}_j = \sigma'(z^{(l)}_j)$. In matrix form: $\partial a^{(l)}/\partial z^{(l)} = \mathrm{diag}(\sigma'(z^{(l)}))$.
Step 3: Multiply the factors and transpose to get the column vector $\delta^{(l)}$:
$$\delta^{(l)} = \mathrm{diag}\big(\sigma'(z^{(l)})\big)\, \big(W^{(l+1)}\big)^\top \delta^{(l+1)} = \big(W^{(l+1)}\big)^\top \delta^{(l+1)} \odot \sigma'(z^{(l)}).$$
The $\odot$ (elementwise) product arises because $\sigma$ is applied elementwise - its Jacobian is diagonal.
4.5 Weight and Bias Gradients
Once error signals are computed, parameter gradients follow immediately:
$$\nabla_{W^{(l)}} L = \delta^{(l)}\, \big(a^{(l-1)}\big)^\top, \qquad \nabla_{b^{(l)}} L = \delta^{(l)}.$$
Derivation of the weight gradient: $z^{(l)}_i = \sum_j W^{(l)}_{ij} a^{(l-1)}_j + b^{(l)}_i$, so
$$\frac{\partial L}{\partial W^{(l)}_{ij}} = \frac{\partial L}{\partial z^{(l)}_i}\, \frac{\partial z^{(l)}_i}{\partial W^{(l)}_{ij}} = \delta^{(l)}_i\, a^{(l-1)}_j.$$
Collecting over all $(i, j)$: $\nabla_{W^{(l)}} L = \delta^{(l)} (a^{(l-1)})^\top$.
This is an outer product - the gradient is rank-1 for a single sample. For a batch of $B$ samples it averages to higher rank.
4.6 Batched Backpropagation
With a mini-batch of $B$ samples, stack inputs as columns of a matrix $X = A^{(0)} \in \mathbb{R}^{d_0 \times B}$.
The forward pass becomes:
$$Z^{(l)} = W^{(l)} A^{(l-1)} + b^{(l)} \mathbf{1}^\top, \qquad A^{(l)} = \sigma(Z^{(l)}),$$
where $\mathbf{1} \in \mathbb{R}^B$ broadcasts the bias across the batch.
The backward pass produces $\Delta^{(l)} \in \mathbb{R}^{d_l \times B}$ (error signals for all samples simultaneously).
Weight gradient for the batch:
$$\nabla_{W^{(l)}} L = \frac{1}{B}\, \Delta^{(l)} \big(A^{(l-1)}\big)^\top.$$
This is a single matrix multiplication, making batched backprop efficient on GPUs, which excel at large GEMM (general matrix multiplication) operations.
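Sections 4.2-4.6 condense into a few lines of NumPy. The sketch below (illustrative sizes, MSE loss, ReLU) runs a batched forward pass, applies the recurrence to get the error signals, forms the weight gradients as single GEMMs, and spot-checks one entry against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
B, d0, d1, d2 = 4, 5, 8, 3                   # batch and layer sizes (illustrative)
X = rng.standard_normal((d0, B))             # inputs as columns
Y = rng.standard_normal((d2, B))             # regression targets
W1, b1 = rng.standard_normal((d1, d0)) * 0.1, np.zeros((d1, 1))
W2, b2 = rng.standard_normal((d2, d1)) * 0.1, np.zeros((d2, 1))
relu = lambda Z: np.maximum(Z, 0.0)

def forward(W1, b1, W2, b2):
    Z1 = W1 @ X + b1
    A1 = relu(Z1)
    Z2 = W2 @ A1 + b2                        # linear output
    L = 0.5 * np.mean(np.sum((Z2 - Y) ** 2, axis=0))   # MSE averaged over batch
    return L, (Z1, A1, Z2)

L, (Z1, A1, Z2) = forward(W1, b1, W2, b2)

# Backward: error signals Delta for all samples at once
D2 = (Z2 - Y) / B                            # dL/dZ2 (mean over batch)
gW2 = D2 @ A1.T                              # batched outer products = one GEMM
gb2 = D2.sum(axis=1, keepdims=True)
D1 = (W2.T @ D2) * (Z1 > 0)                  # recurrence with ReLU'
gW1 = D1 @ X.T
gb1 = D1.sum(axis=1, keepdims=True)

# Finite-difference spot check of one weight entry
eps = 1e-6
W1p = W1.copy(); W1p[2, 3] += eps
num = (forward(W1p, b1, W2, b2)[0] - L) / eps
print(abs(num - gW1[2, 3]))                  # tiny
```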
5. Gradient Derivations for Standard Layers
5.1 Linear Layer
Forward: $z = Wx + b$, where $W \in \mathbb{R}^{m \times n}$, $x \in \mathbb{R}^n$, $b \in \mathbb{R}^m$.
Upstream gradient: $\bar z = \partial L/\partial z \in \mathbb{R}^m$.
VJP (backward):
$$\bar x = W^\top \bar z, \qquad \bar W = \bar z\, x^\top, \qquad \bar b = \bar z.$$
Derivation of $\bar x$: Each $z_i = \sum_j W_{ij} x_j + b_i$, so $\partial z_i/\partial x_j = W_{ij}$. By VJP: $\bar x_j = \sum_i \bar z_i W_{ij} = (W^\top \bar z)_j$.
For AI: In a transformer with hidden dim $d$ and MLP expansion $4d$: the two linear layers in the FFN pass gradients back with $O(B T d^2)$ operations - the same cost as the forward GEMM. Gradient computation for $W$ is also a GEMM.
5.2 Activation Functions
For elementwise $a = \sigma(z)$:
$$\bar z = \bar a \odot \sigma'(z).$$
Gradient formulas for common activations:
| Activation | $\sigma(x)$ | $\sigma'(x)$ | Notes |
|---|---|---|---|
| ReLU | $\max(0, x)$ | $[x > 0]$ | Sparse gradient; "dead neurons" if $x < 0$ always |
| Sigmoid | $1/(1 + e^{-x})$ | $\sigma(x)(1 - \sigma(x))$ | Saturates; max gradient 0.25 at $x = 0$ |
| Tanh | $\tanh x$ | $1 - \tanh^2 x$ | Saturates; max gradient 1 at $x = 0$ |
| GELU | $x\,\Phi(x)$ | $\Phi(x) + x\,\phi(x)$ | $\Phi$ = Gaussian CDF; smooth at 0 |
| SiLU/Swish | $x\,\mathrm{sig}(x)$ | $\mathrm{sig}(x)\big(1 + x(1 - \mathrm{sig}(x))\big)$ | Used in LLaMA, Mistral |
| SoftPlus | $\log(1 + e^x)$ | $\mathrm{sig}(x)$ | Smooth ReLU; gradient never zero |
GELU (Hendrycks & Gimpel, 2016) is the standard activation in GPT-2/3, BERT, and most modern LLMs. It gates the input by its own probability under a Gaussian, producing richer gradient structure than ReLU.
5.3 Fused Softmax + Cross-Entropy Gradient
Setup: Output logits $z \in \mathbb{R}^K$, softmax probabilities $p = \mathrm{softmax}(z)$, true label $y$, loss $L = -\log p_y$.
Claim:
$$\nabla_z L = p - e_y,$$
where $e_y$ is the $y$-th standard basis vector.
Proof: Write $L = -z_y + \log \sum_j e^{z_j}$. Then
$$\frac{\partial L}{\partial z_i} = -[i = y] + \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i - [i = y].$$
This direct derivation bypasses the softmax Jacobian computation entirely, which is why modern frameworks implement cross-entropy as a fused operation. For numerical stability, the $\log \sum_j e^{z_j}$ is computed with the log-sum-exp trick: $\log \sum_j e^{z_j} = m + \log \sum_j e^{z_j - m}$, where $m = \max_j z_j$.
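A quick numerical check of the fused gradient, with the max-subtraction trick included (NumPy, illustrative sizes):

```python
import numpy as np

def softmax(z):
    z = z - z.max()               # max-subtraction for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
K, y = 6, 2                       # number of classes, true label (illustrative)
z = rng.standard_normal(K)
p = softmax(z)
grad = p.copy()
grad[y] -= 1.0                    # fused gradient: p - e_y

# Finite-difference check of L(z) = -log p_y
def loss(z):
    return -np.log(softmax(z)[y])

eps = 1e-6
num = np.array([(loss(z + eps * np.eye(K)[i]) - loss(z)) / eps for i in range(K)])
print(np.max(np.abs(num - grad)))  # tiny
```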
5.4 LayerNorm Gradient
Forward: LayerNorm normalises each token independently:
$$y_i = \gamma_i\, \hat x_i + \beta_i, \qquad \hat x_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}},$$
where $\mu = \frac{1}{d} \sum_i x_i$ and $\sigma^2 = \frac{1}{d} \sum_i (x_i - \mu)^2$.
Backward: Let $\bar y$ be the upstream gradient. The parameter gradients are $\bar\gamma = \bar y \odot \hat x$ and $\bar\beta = \bar y$ (summed over tokens).
For the input gradient, define $g = \bar y \odot \gamma$. The full gradient through the normalisation is:
$$\bar x = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( g - \frac{1}{d} \sum_j g_j \;-\; \hat x \cdot \frac{1}{d} \sum_j g_j\, \hat x_j \right).$$
This expression subtracts a mean term and a mean-of-Hadamard term, reflecting that LayerNorm's Jacobian projects out two degrees of freedom (02 exercises).
For AI: LayerNorm appears in every transformer layer (pre-norm placement in modern architectures like GPT-NeoX, LLaMA). The gradient through LayerNorm is never zero - it always passes signal, unlike BatchNorm which can become degenerate at small batch sizes.
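The input-gradient formula is easy to get wrong, so it is worth verifying against finite differences. The NumPy sketch below (hypothetical helper names) implements the VJP above, with the cotangent $\bar y$ folded into a scalar test loss $L(x) = \bar y \cdot \mathrm{LN}(x)$:

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    mu = x.mean()
    std = np.sqrt(x.var() + eps)          # np.var is the biased (1/d) variance
    xhat = (x - mu) / std
    return gamma * xhat + beta, xhat, std

def layernorm_vjp_x(dy, xhat, std, gamma):
    """Input gradient: subtract mean and mean-of-Hadamard terms, then rescale."""
    g = dy * gamma
    return (g - g.mean() - xhat * (g * xhat).mean()) / std

rng = np.random.default_rng(2)
d = 16
x, gamma, beta = (rng.standard_normal(d) for _ in range(3))
dy = rng.standard_normal(d)               # upstream cotangent

y, xhat, std = layernorm(x, gamma, beta)
dx = layernorm_vjp_x(dy, xhat, std, gamma)

# Check against finite differences of the scalar L(x) = dy . LN(x)
eps = 1e-6
num = np.zeros(d)
for i in range(d):
    xp = x.copy(); xp[i] += eps
    num[i] = (dy @ layernorm(xp, gamma, beta)[0] - dy @ y) / eps
print(np.max(np.abs(num - dx)))           # tiny
```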
5.5 Dot-Product Attention Gradient
Forward (simplified single-head):
$$S = \frac{Q K^\top}{\sqrt{d_k}}, \qquad P = \mathrm{softmax}_{\text{row}}(S), \qquad O = P V.$$
Backward: Given upstream $\bar O$:
$$\bar V = P^\top \bar O, \qquad \bar P = \bar O\, V^\top, \qquad \bar S = P \odot \big(\bar P - \mathrm{rowsum}(\bar P \odot P)\big).$$
Then $\bar Q = \bar S\, K / \sqrt{d_k}$ and $\bar K = \bar S^\top Q / \sqrt{d_k}$, and the gradients propagate similarly into the $W_Q$, $W_K$, $W_V$ projections.
Critical memory issue: Storing $P \in \mathbb{R}^{T \times T}$ for the backward costs $O(T^2)$ per head - this is what FlashAttention avoids by recomputing $P$ from $Q, K$ during the backward pass (see 7.3).
5.6 Embedding Layer Gradient
Forward: $e_s = E[t_s]$ for each position $s$ with token id $t_s$, where $E \in \mathbb{R}^{V \times d}$ is the embedding table.
Backward: Given upstream $\bar e_s$ for all positions $s$:
$$\bar E[t] = \sum_{s\,:\,t_s = t} \bar e_s.$$
This is a sparse gradient - only rows corresponding to tokens in the sequence receive nonzero updates. For vocabulary size $V \approx 128\mathrm{K}$ (LLaMA-3), the embedding matrix is $128\mathrm{K} \times d$, but only a tiny fraction of rows are updated per batch. Distributed training with embedding sharding exploits this sparsity.
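The scatter-add structure of the embedding gradient is one line in NumPy - the key detail is that repeated tokens must accumulate, which the unbuffered `np.add.at` guarantees and plain fancy-indexing assignment does not:

```python
import numpy as np

V, d = 10, 4                          # tiny vocab and dim (illustrative)
E = np.random.default_rng(3).standard_normal((V, d))   # embedding table
tokens = np.array([2, 7, 2, 5])       # token 2 appears twice
d_out = np.ones((len(tokens), d))     # upstream gradients, one row per position

dE = np.zeros_like(E)
np.add.at(dE, tokens, d_out)          # scatter-ADD: repeated tokens accumulate

print(dE[2])                          # row 2 received TWO summed contributions
```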
6. Vanishing and Exploding Gradients
6.1 Magnitude Analysis - The Core Problem
Consider an $L$-layer network with no activation functions (to isolate the linear case). The gradient of the loss with respect to layer $l$'s parameters involves the product:
$$\prod_{k=l+1}^{L} W^{(k)}.$$
This is a product of $L - l$ matrices. By the submultiplicativity of the spectral norm:
$$\left\| \prod_{k=l+1}^{L} W^{(k)} \right\|_2 \le \prod_{k=l+1}^{L} \big\| W^{(k)} \big\|_2.$$
If $\|W^{(k)}\|_2 \le \rho < 1$ for all layers: the product norm decays as $\rho^{\,L-l} \to 0$, exponentially in depth.
If $\rho > 1$: it can grow as $\rho^{\,L-l}$ - exponential explosion.
GRADIENT MAGNITUDE ACROSS LAYERS  (backprop runs from layer L toward layer 0)

 gradient norm
   |     exploding  (rho > 1): grows as rho^(L-l)
   |     ideal      (rho = 1): roughly constant
   |     vanishing  (rho < 1): decays as rho^(L-l)
   +------------------------------------------> layer l
   L                                           0

With activations, the product also includes sigma'(z) terms (< 1 for sigmoid),
compounding the vanishing problem.
This was identified by Hochreiter (1991) as the fundamental obstacle to training deep networks with gradient descent.
6.2 Activations and Saturation
For sigmoid $\sigma(x) = 1/(1 + e^{-x})$: $\sigma'(x) = \sigma(x)(1 - \sigma(x)) \le \tfrac{1}{4}$ for all $x$, with equality only at $x = 0$. In the tails ($|x| > 5$), $\sigma'(x) < 0.007$.
For tanh: $\tanh'(x) = 1 - \tanh^2(x) \le 1$, saturating similarly.
In a network with $L$ sigmoid layers and all activations near saturation, the gradient at layer 1 is suppressed by approximately $(1/4)^L$. For $L = 20$: $(1/4)^{20} \approx 10^{-12}$ - numerically zero.
ReLU resolves saturation: $\mathrm{relu}'(x) = [x > 0]$, which is either 0 or 1. For active neurons, it passes gradients unchanged. However, "dying ReLU" (neurons with $z < 0$ always) creates a different problem - those neurons receive zero gradient and never recover.
GELU and SiLU (used in LLaMA) are smooth approximations that avoid hard zeros, maintaining nonzero gradients everywhere.
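The exponential decay or growth of Jacobian products is easy to reproduce. The sketch below (NumPy, arbitrary width and depth) multiplies 50 random stand-in layer Jacobians at two different weight scales and measures the spectral norm of the product:

```python
import numpy as np

rng = np.random.default_rng(4)
d, depth = 32, 50

def product_norm(scale):
    """Spectral norm of a product of `depth` random d x d Jacobians.

    Each factor is scale * G / sqrt(d) with G standard Gaussian, so its
    typical amplification of a vector is about `scale`.
    """
    P = np.eye(d)
    for _ in range(depth):
        P = (scale * rng.standard_normal((d, d)) / np.sqrt(d)) @ P
    return np.linalg.norm(P, 2)

small = product_norm(0.3)   # rho < 1: vanishes exponentially
large = product_norm(1.5)   # rho > 1: explodes exponentially
print(small, large)
```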
6.3 Xavier and He Initialisation
Goal: Choose initial weights so that gradient (and activation) variance is preserved across layers - avoiding exponential growth or decay from the start of training.
Xavier Initialisation (Glorot & Bengio, 2010) - for symmetric activations (tanh, linear):
$$\mathrm{Var}(W_{ij}) = \frac{2}{n_{\text{in}} + n_{\text{out}}}.$$
Assumption: Weights i.i.d. with mean zero, inputs with unit variance.
Forward variance preservation requires $\mathrm{Var}(W_{ij}) = 1/n_{\text{in}}$.
Backward variance preservation requires $\mathrm{Var}(W_{ij}) = 1/n_{\text{out}}$.
Compromise: the harmonic mean of the two, $\mathrm{Var}(W_{ij}) = 2/(n_{\text{in}} + n_{\text{out}})$.
He Initialisation (He et al., 2015) - for ReLU activations:
$$\mathrm{Var}(W_{ij}) = \frac{2}{n_{\text{in}}}.$$
ReLU zeroes half the distribution, so the effective second moment is halved: $\mathbb{E}[\mathrm{relu}(z)^2] = \tfrac{1}{2}\,\mathrm{Var}(z)$. Doubling the weight variance compensates.
For AI: GPT-2 uses a scaled version: $\mathcal{N}(0, 0.02^2)$ weight initialisation with the residual projection layers further scaled by $1/\sqrt{N}$, where $N$ is the number of transformer layers, to control the variance accumulation in the residual stream.
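The He-initialisation argument - ReLU halves the second moment, so double the weight variance - can be confirmed empirically. A NumPy sketch with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(5)
n_in, n_out, batch = 512, 512, 1000
x = rng.standard_normal((n_in, batch))          # unit-variance inputs
relu = lambda z: np.maximum(z, 0.0)

# Linear-preserving init: Var(W) = 1/n_in keeps Var(Wx) = 1, but ReLU halves E[a^2]
W_lin = rng.standard_normal((n_out, n_in)) / np.sqrt(n_in)
# He init: Var(W) = 2/n_in compensates for the halving
W_he = rng.standard_normal((n_out, n_in)) * np.sqrt(2.0 / n_in)

m_lin = np.mean(relu(W_lin @ x) ** 2)   # ~0.5: second moment shrinks each layer
m_he = np.mean(relu(W_he @ x) ** 2)     # ~1.0: second moment preserved
print(m_lin, m_he)
```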
6.4 Residual Connections as Gradient Highways
Theorem: In a residual network $x^{(l+1)} = x^{(l)} + F_l(x^{(l)})$, the gradient satisfies:
$$\frac{\partial x^{(L)}}{\partial x^{(l)}} = \prod_{k=l}^{L-1} \big( I + J_{F_k} \big).$$
Key insight: Expanding the product, we get:
$$\frac{\partial x^{(L)}}{\partial x^{(l)}} = I + \sum_k J_{F_k} + (\text{higher-order terms}).$$
The identity term guarantees that even if all $J_{F_k} \approx 0$ (at initialisation), the gradient receives the full upstream signal unchanged. This is the theoretical explanation for why ResNets (He et al., 2016) can be trained with hundreds of layers.
In modern transformers, the pre-norm architecture (LayerNorm before the sublayer, not after) further improves gradient flow by ensuring that the residual path carries a pure copy of the signal.
6.5 Gradient Clipping
Gradient explosion is addressed pragmatically by global gradient norm clipping:
$$g \leftarrow g \cdot \min\!\left(1,\; \frac{c}{\|g\|_2}\right),$$
where $g$ is the concatenated parameter gradient vector and $c$ is the clip threshold.
Typical values: $c = 1.0$ for transformers (used in GPT-3, PaLM, LLaMA).
Why global (not per-layer)? Clipping each layer's gradient independently destroys the relative proportions of updates across layers, disrupting the Adam momentum states. Global clipping preserves direction, only reducing magnitude.
Relationship to RNNs: Gradient clipping was originally introduced for RNNs (Mikolov, 2012; Pascanu et al., 2013), where the vanishing/exploding problem is especially severe due to the long chain of time steps.
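A minimal implementation of global-norm clipping (NumPy sketch; `clip_global_norm` is a hypothetical helper mirroring what `torch.nn.utils.clip_grad_norm_` does):

```python
import numpy as np

def clip_global_norm(grads, c):
    """Scale ALL gradient tensors by one common factor so the global norm <= c.

    A single shared factor preserves the gradient's direction and the relative
    proportions across layers, unlike per-layer clipping.
    """
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, c / (total + 1e-12))
    return [g * scale for g in grads], total

# Two gradient tensors with global norm sqrt(4*9 + 9*16) = sqrt(180) ~ 13.4
grads = [np.full(4, 3.0), np.full(9, 4.0)]
clipped, norm = clip_global_norm(grads, 1.0)
new_norm = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm, new_norm)   # ~13.42 -> ~1.0
```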
6.6 Batch Normalisation and Layer Normalisation
BatchNorm (Ioffe & Szegedy, 2015) normalises each feature across the batch, stabilising the distribution of pre-activations. Its gradient has a complex form involving batch statistics, but crucially it prevents activations from saturating on average.
LayerNorm (Ba et al., 2016) normalises each sample across features - preferred in transformers because:
- Behaviour is independent of batch size (critical for small-batch inference)
- Gradient analysis shows it damps large pre-activation magnitudes
- Pre-norm placement ensures the residual stream grows in a controlled manner
Empirical gradient norm tracking is standard practice in LLM training: the gradient norm is logged at every step, and sudden spikes indicate loss spikes or numerical issues. The Chinchilla and GPT-4 training runs used gradient norm monitoring as a primary signal for training health.
7. Memory-Efficient Backpropagation
7.1 Memory Cost of Standard Backprop
Standard backpropagation caches all intermediate activations for use in the backward pass. For a transformer with layers, batch size , sequence length , and hidden dimension :
| Component cached | Size (elements) | At FP16 |
|---|---|---|
| Attention QKV projections | $3 \times B \times T \times d$ | $6\,BTd$ bytes |
| Attention scores (pre-softmax) | $B \times h \times T \times T$ | $2\,BhT^2$ bytes |
| MLP intermediate | $B \times T \times 4d$ | $8\,BTd$ bytes |
| LayerNorm stats | $B \times T \times 2$ | negligible |
For GPT-3-scale settings ($L = 96$, $d = 12288$, $h = 96$, $T = 2048$, FP16): the attention scores alone require roughly $96 \times 96 \times 2048^2 \times 2 \text{ bytes} \approx 77\,\mathrm{GB}$ per sample across all layers - clearly infeasible without optimisation.
7.2 Gradient Checkpointing
Idea: Trade compute for memory. Instead of caching all activations, cache only a subset of "checkpoint" activations and recompute the rest during the backward pass.
Algorithm (checkpointing at every -th layer):
GRADIENT CHECKPOINTING
Forward pass:
Compute all layers normally
Save activations only at layers 0, k, 2k, 3k, ...
Discard all other intermediate activations
Backward pass:
For each segment [lk, (l+1)k]:
Re-run the forward pass from checkpoint lk
Now have all intermediates for this segment
Compute gradients for layers lk+1 to (l+1)k-1
Discard intermediates (no longer needed)
Memory-compute tradeoff:
- Memory: $O(L/k)$ checkpoints plus $O(k)$ recomputed intermediates per segment - optimal with $k = \sqrt{L}$, giving $O(\sqrt{L})$ instead of $O(L)$
- Compute: Each layer's forward pass is run twice (once in the original forward, once in recomputation) -> approximately one extra forward pass, a ~33% compute overhead
For AI: torch.utils.checkpoint.checkpoint() implements this in PyTorch with a single function call. LLaMA, Mistral, and most OSS LLM trainers enable activation checkpointing by default for sequences longer than ~2048 tokens.
Selective recomputation: FlashAttention (see 7.3) takes a more targeted approach - instead of checkpointing by layer, it recomputes only the attention scores (the $QK^\top$ term) during the backward pass, since those are the dominant memory consumer.
7.3 FlashAttention: Fused Backward Pass
The problem: Standard attention stores $P = \mathrm{softmax}(QK^\top/\sqrt{d_k}) \in \mathbb{R}^{T \times T}$ for the backward pass. For $T = 128\mathrm{K}$ (long-context models), the $T^2$ FP16 entries amount to roughly $34\,\mathrm{GB}$ per head per layer per batch element.
FlashAttention solution (Dao et al., 2022): Compute attention in tiles that fit in SRAM (GPU on-chip cache), using the online softmax algorithm (Milakov & Gimelshein, 2018) to avoid materialising the full matrix.
Backward pass in FlashAttention: The backward pass needs $P$ but doesn't store it. Instead:
- Store only the softmax normalisation statistics ($O(T)$ scalars per head: the row max and row sum) - $O(T)$ memory
- During the backward pass, recompute $P$ tile by tile from $Q, K$ and the stored statistics
- Accumulate gradients tile by tile without ever forming the full $T \times T$ matrix
Complexity:
- Memory: $O(T)$ instead of $O(T^2)$
- FLOPs: a small constant factor over the forward FLOPs (the recomputation of $QK^\top$)
- Wall-clock speedup: 2-4x over standard PyTorch attention on A100
For AI: FlashAttention is the default attention implementation in modern LLM training (vLLM, HuggingFace Transformers, NanoGPT). FlashAttention-3 (2024) further optimises for H100 tensor core and async operations.
7.4 Mixed Precision Training
Observation: FP32 (32-bit float) is unnecessarily precise for gradients. FP16 (16-bit float) has higher memory bandwidth on modern GPUs, but overflow/underflow is common for small/large gradient values.
AMP (Automatic Mixed Precision) strategy:
| Component | Precision | Reason |
|---|---|---|
| Forward activations | FP16 | Fast compute, lower memory |
| Backward gradients | FP16 | Fast compute |
| Weight updates | FP32 | Avoid precision loss |
| Master weights | FP32 | Preserve small updates that FP16 would round away |
| Loss scaling | Dynamic | Prevent FP16 underflow for small gradients |
Loss scaling: Multiply the loss by a large scale factor $s$ (often on the order of $2^{16}$) before backward, then divide gradients by $s$ before the weight update. This shifts gradient values into the representable FP16 range. The scale factor is increased or decreased dynamically based on whether overflow (inf/nan) occurred.
BF16 (Brain Float 16, used in TPUs and H100): same 16-bit width but with 8 exponent bits (same as FP32) and only 7 mantissa bits. Eliminates overflow issues while retaining dynamic range - now the preferred format for LLM training.
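The underflow problem and the loss-scaling fix can be demonstrated directly with NumPy's FP16 type (the gradient value 1e-8 and the scale $2^{16}$ are illustrative):

```python
import numpy as np

g = 1e-8                          # a small FP32 gradient value (illustrative)
print(np.float16(g))              # 0.0 -- underflows: smallest FP16 subnormal ~6e-8

s = 2.0 ** 16                     # loss scale (chosen dynamically in practice)
g_scaled = np.float16(g * s)      # ~6.55e-4: comfortably representable in FP16
g_unscaled = float(g_scaled) / s  # unscale in FP32 before the weight update
print(g_unscaled)                 # ~1e-8 recovered instead of silently lost
```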
8. Advanced Differentiation Topics
8.1 Backpropagation Through Time (BPTT)
A recurrent neural network (RNN) with hidden state $h_t = f(h_{t-1}, x_t; W)$ can be viewed as a feedforward network unrolled through time:

UNROLLED RNN - BPTT VIEW

x_1 -> [cell] -> h_1 -> [cell] -> h_2 -> [cell] -> h_3 -> ... -> h_T -> loss
          W               W               W          (shared weights)

BPTT = backprop through the unrolled graph.
Gradient of loss w.r.t. W = sum of gradients from all time steps.

The gradient with respect to the hidden state at time step $t$ involves the product:
$$\prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}.$$
Each factor involves the recurrent weight matrix and an activation derivative. When its norm is below 1, the product of $T - t$ such factors vanishes exponentially. This is the core failure mode of vanilla RNNs on long sequences (Hochreiter, 1991; Bengio et al., 1994).
Truncated BPTT: In practice, gradients are truncated to a window of $k$ steps to reduce memory and compute costs, at the cost of ignoring long-range dependencies beyond the window.
LSTM/GRU solution: Long Short-Term Memory networks (Hochreiter & Schmidhuber, 1997) use gating mechanisms to maintain a cell state with additive updates - replacing multiplicative products of weight matrices with additive accumulation, similar to residual connections.
8.2 Implicit Differentiation Preview
For optimisation problems or fixed-point iterations, we sometimes need gradients of implicit functions.
Example: Consider $\theta^*(\lambda) = \arg\min_\theta g(\theta, \lambda)$, where the optimum satisfies $\nabla_\theta\, g(\theta^*, \lambda) = 0$.
By the implicit function theorem:
$$\frac{d\theta^*}{d\lambda} = -\big( \nabla^2_{\theta\theta}\, g \big)^{-1} \nabla^2_{\theta\lambda}\, g \;\Big|_{\theta = \theta^*}.$$
This allows differentiating through optimisation steps without unrolling them - the basis of MAML (Model-Agnostic Meta-Learning, Finn et al., 2017) and DEQs (Deep Equilibrium Models, Bai et al., 2019).
Full treatment: Implicit differentiation and differentiable optimisation are covered in depth in 05/05-Automatic-Differentiation.
8.3 Straight-Through Estimator and REINFORCE
The discrete problem: When a node in the computation graph applies a discrete operation (argmax, sampling, rounding), the gradient is zero almost everywhere. The chain rule breaks - the graph is not differentiable at these nodes.
Straight-Through Estimator (STE) (Hinton, 2012; Bengio et al., 2013): forward $y = q(x)$ with the discrete $q$; backward $\bar x := \bar y$, i.e. the Jacobian of $q$ is replaced by the identity.
Applications:
- Quantisation-aware training (QAT): Simulate INT8 forward, use STE backward. Used in GPTQ, AWQ, and quantised LLM training.
- VQ-VAE: Vector quantisation in the encoder uses STE so gradients flow from decoder back to encoder.
- Binary neural networks: Forward uses sign(x), backward uses STE with gradient identity.
REINFORCE (Williams, 1992): For stochastic nodes, use the log-derivative trick:
$$\nabla_\theta\, \mathbb{E}_{z \sim p_\theta}[L(z)] = \mathbb{E}_{z \sim p_\theta}\big[ L(z)\, \nabla_\theta \log p_\theta(z) \big].$$
This produces an unbiased gradient estimate but with high variance (addressed by baseline subtraction: $L(z) \to L(z) - b$). REINFORCE is the foundation of policy gradient methods in RL and is used in RLHF's PPO step.
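A minimal sanity check of the estimator, and of the variance reduction from a baseline, for a case with a known answer: with $z \sim \mathcal{N}(\theta, 1)$ and $L(z) = z^2$, the true gradient is $\frac{d}{d\theta}\mathbb{E}[z^2] = 2\theta$ (NumPy sketch, illustrative sample size):

```python
import numpy as np

rng = np.random.default_rng(6)
theta, N = 1.0, 200_000

# z ~ N(theta, 1), L(z) = z^2.  True gradient: d/dtheta E[z^2] = 2*theta = 2.
z = theta + rng.standard_normal(N)
score = z - theta                        # d/dtheta log N(z; theta, 1) = z - theta

est_plain = np.mean(z ** 2 * score)      # unbiased but noisy
baseline = np.mean(z ** 2)               # constant baseline: keeps unbiasedness
est_base = np.mean((z ** 2 - baseline) * score)

print(est_plain, est_base)               # both near 2.0; baseline version less noisy
```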
8.4 Higher-Order Gradients
Second-order gradients arise in:
- Newton's method: requires the Hessian $H = \nabla^2 L$ (see 02-Jacobians-and-Hessians)
- Meta-learning (MAML): gradient of a gradient w.r.t. outer parameters
- Gradient penalty in GAN training: $\lambda\,\big( \|\nabla_x D(x)\|_2 - 1 \big)^2$
In PyTorch: Higher-order gradients are computed by running autograd through itself:
```python
# Second derivative of loss w.r.t. input (x must have requires_grad=True)
loss = model(x).sum()
g = torch.autograd.grad(loss, x, create_graph=True)[0]   # first derivative
g2 = torch.autograd.grad(g.sum(), x)[0]                  # second derivative
```
create_graph=True tells autograd to build a graph for the gradient computation itself, enabling differentiation through it.
Hessian-vector products (HVPs): As shown in 02, the HVP can be computed in $O(P)$ time without forming $H$:
$$Hv = \nabla_x\big( (\nabla_x L)^\top v \big).$$
This is the primitive operation behind conjugate gradient and Lanczos methods for curvature estimation.
9. Transformer Backpropagation
9.1 Full Transformer Layer Gradient Flow
A pre-norm transformer layer processes the residual stream as:
$$x' = x + \mathrm{Attn}(\mathrm{LN}_1(x)), \qquad x'' = x' + \mathrm{MLP}(\mathrm{LN}_2(x')).$$
Backward through one transformer layer (given the upstream gradient $\bar x''$):

GRADIENT FLOW - ONE TRANSFORMER LAYER

FORWARD                            BACKWARD
x                                  x''_bar flows in
LN_1(x)                            x'_bar = x''_bar + MLP_backward(x''_bar)
Attn(.)                            x_bar  = x'_bar  + Attn_backward(x'_bar)
x'  = x + Attn(LN_1(x))
LN_2(x')                           The two residual additions split the
MLP(.)                             gradient stream into parallel paths -
x'' = x' + MLP(LN_2(x'))           the identity skip carries the full
                                   upstream signal unchanged.

The critical observation: both residual additions in the transformer layer act as gradient splitters. The skip path carries a copy of $\bar x''$ directly back to $x$ without passing through the MLP Jacobian. This gives transformers well-behaved gradients even at $L = 96$ layers (GPT-3) or $L = 64$ layers (Grok-1).
9.2 LoRA Backward Pass
Low-Rank Adaptation (Hu et al., 2022) reparametrises a weight matrix:
$$W = W_0 + \frac{\alpha}{r}\, BA, \qquad B \in \mathbb{R}^{m \times r},\; A \in \mathbb{R}^{r \times n},\; r \ll \min(m, n).$$
Forward: $z = W_0 x + \frac{\alpha}{r}\, B(Ax)$.
Backward (given $\bar z$):
$$\bar B = \frac{\alpha}{r}\, \bar z\, (Ax)^\top, \qquad \bar A = \frac{\alpha}{r}\, (B^\top \bar z)\, x^\top, \qquad \bar x = \Big( W_0 + \frac{\alpha}{r} BA \Big)^{\!\top} \bar z.$$
Note: $W_0$ is frozen, so $\bar W_0$ is never formed - no gradient is computed or stored for $W_0$. The backward pass only updates $A$ and $B$.
Memory savings: For $W_0 \in \mathbb{R}^{4096 \times 4096}$ with $r = 16$: gradient storage reduces from $16.8\mathrm{M}$ to $2 \times 4096 \times 16 \approx 131\mathrm{K}$ parameters - a $128\times$ reduction in gradient memory for that layer.
DoRA (Liu et al., 2024) further decomposes LoRA into magnitude + direction components, improving fine-tuning quality while preserving the low-rank backward structure.
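The backward equations above can be verified at small scale. The NumPy sketch below uses illustrative dimensions and a nonzero $B$ so the $A$-gradient is non-trivial (LoRA actually initialises $B = 0$); it checks one entry of $\bar A$ against finite differences and counts stored gradients:

```python
import numpy as np

rng = np.random.default_rng(7)
m = n = 256                               # smaller than 4096, same structure
r, alpha = 8, 16.0
x = rng.standard_normal(n)
W0 = rng.standard_normal((m, n))          # frozen base weight: no gradient stored
A = rng.standard_normal((r, n)) * 0.01
B = rng.standard_normal((m, r)) * 0.01    # nonzero for illustration only
dz = rng.standard_normal(m)               # upstream gradient

z = W0 @ x + (alpha / r) * (B @ (A @ x))
dB = (alpha / r) * np.outer(dz, A @ x)          # shape (m, r)
dA = (alpha / r) * np.outer(B.T @ dz, x)        # shape (r, n)
dx = (W0 + (alpha / r) * B @ A).T @ dz          # shape (n,)

# Finite-difference check of one entry of dA (z is linear in A, so this is exact)
eps = 1e-6
Ap = A.copy(); Ap[1, 2] += eps
num = (dz @ (W0 @ x + (alpha / r) * (B @ (Ap @ x))) - dz @ z) / eps
print(abs(num - dA[1, 2]))                # tiny

# Gradient storage: only A and B, never W0
trainable, frozen = A.size + B.size, W0.size
print(frozen // trainable)                # 16x fewer stored gradients at this size
```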
9.3 Gradient Accumulation
Problem: Large effective batch sizes (several million tokens, as in GPT-4-scale training) don't fit in GPU memory for a single forward-backward pass.
Solution - gradient accumulation:

```python
optimizer.zero_grad()
for b in range(G):                        # micro-batches b = 1, ..., G
    loss_b = forward(micro_batch[b]) / G  # scaled loss
    loss_b.backward()                     # accumulates gradients in .grad
    # gradients are NOT zeroed between micro-batches
optimizer.step()                          # update once after G micro-batches
optimizer.zero_grad()
```
The division by $G$ ensures the accumulated gradient is mathematically identical to what a single pass with the full batch would produce.
For AI: GPT-3 used gradient accumulation to reach an effective batch of 3.2M tokens with hardware that could not process that many tokens in a single forward-backward step.
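The equivalence claim is easy to verify numerically. A self-contained sketch with a toy linear model and MSE loss (not the GPT-3 setup): two micro-batches with losses scaled by $1/G$ accumulate to exactly the full-batch gradient.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
X, y = torch.randn(8, 3), torch.randn(8)
mse = lambda Xb, yb: ((Xb @ w - yb) ** 2).mean()

# One full-batch pass
mse(X, y).backward()
g_full = w.grad.clone()
w.grad = None

# Two micro-batches (G = 2), each loss scaled by 1/G, gradients accumulated
for Xb, yb in ((X[:4], y[:4]), (X[4:], y[4:])):
    (mse(Xb, yb) / 2).backward()

print(torch.allclose(w.grad, g_full, atol=1e-6))  # True
```

The key is that `.backward()` adds into `.grad` rather than overwriting it - zeroing between micro-batches would discard everything but the last one.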
9.4 Distributed Gradient Synchronisation
In data parallelism, each GPU processes a different micro-batch but shares the same model weights. After the backward pass, gradients must be synchronised:
All-Reduce: Sum gradients across all $N$ GPUs and divide by $N$. Implemented via ring all-reduce (NCCL), which is bandwidth-optimal: each GPU communicates roughly $2\times$ the gradient size regardless of $N$.
Gradient sharding (ZeRO): DeepSpeed's ZeRO (Zero Redundancy Optimizer) partitions gradient storage across GPUs:
- ZeRO Stage 1: Shard optimiser states -> up to $4\times$ memory reduction
- ZeRO Stage 2: Shard gradients additionally -> up to $8\times$ reduction
- ZeRO Stage 3: Shard parameters too -> reduction linear in GPU count
For LLaMA-3 70B training: ZeRO Stage 3 across 1024 H100 GPUs stores only $1/1024$ of the 70B parameters per GPU (roughly 68M) - fitting the model in memory.
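The stage-by-stage savings follow directly from per-parameter byte accounting. A rough calculator following the ZeRO paper's mixed-precision Adam example (2 B fp16 parameters + 2 B fp16 gradients + 12 B optimiser state per parameter; a sketch, not DeepSpeed's actual accounting):

```python
def zero_memory_gb(n_params, n_gpus, stage):
    # Per-parameter bytes: fp16 weights, fp16 grads, fp32 master + Adam moments
    p, g, o = 2.0 * n_params, 2.0 * n_params, 12.0 * n_params
    if stage >= 1: o /= n_gpus   # Stage 1: shard optimiser states
    if stage >= 2: g /= n_gpus   # Stage 2: also shard gradients
    if stage >= 3: p /= n_gpus   # Stage 3: also shard parameters
    return (p + g + o) / 1e9     # GB per GPU

# 70B parameters across 1024 GPUs:
print(round(zero_memory_gb(70e9, 1024, 0)))  # 1120 (no sharding: 16 B/param)
print(round(zero_memory_gb(70e9, 1024, 3)))  # 1    (everything sharded)
```

At stage 0 the 16 bytes per parameter dominate; stage 1 removes most of the 12 optimiser-state bytes (hence ~4x), stage 2 the gradient bytes too (~8x), and stage 3 divides everything by the GPU count.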
10. Common Mistakes
| # | Mistake | Why It's Wrong | Fix |
|---|---|---|---|
| 1 | Applying the scalar chain rule to vector functions | The scalar chain rule multiplies numbers; the multivariate chain rule composes Jacobians, and order matters: $J_{f \circ g} = J_f J_g$, not $J_g J_f$ | Write Jacobians explicitly and multiply left-to-right in the order of composition |
| 2 | Forgetting to sum gradients at fan-out (shared weight) nodes | Each use of a weight contributes a gradient; missing uses means undercounting | Accumulate gradients with += in the backward loop over all uses |
| 3 | Treating $\nabla_W L = \bar{y}\,x^\top$ as shape-correct without checking | The outer product $\bar{y}\,x^\top$ has shape matching $W$; transposing either vector gives the wrong shape | Always verify gradient shapes match parameter shapes before implementation |
| 4 | Using sigmoid/tanh in deep networks expecting no vanishing gradients | Their derivatives are bounded by $1/4$ (sigmoid) and $1$ (tanh) - products over many layers vanish exponentially | Use ReLU, GELU, or SiLU with proper initialisation; add residual connections |
| 5 | Initialising all weights to zero (or same value) | Symmetry breaking fails: every neuron in a layer computes the same gradient, so they all update identically and remain identical forever | Use Xavier or He initialisation with random values |
| 6 | Skipping the fused softmax + cross-entropy optimisation and computing them separately | Intermediate probabilities overflow/underflow for large logits | Always use the log-sum-exp trick or a library's CrossEntropyLoss (which applies it internally) |
| 7 | Confusing JVP and VJP - using JVP for all gradient computations | A JVP gives one directional derivative, so a scalar loss needs $n$ forward passes for the full gradient; a VJP costs one pass per output dimension, so a scalar loss needs just one backward pass | Use VJP (backward) for scalar losses; reserve JVP for directional derivatives or Jacobian columns |
| 8 | Clipping per-layer gradients independently instead of by global norm | Destroys the relative scale of gradients across layers; disrupts Adam's per-parameter adaptive scaling | Clip the global gradient norm: compute $\|g\|_2$ across all parameters, scale down if above threshold |
| 9 | Using STE incorrectly in quantisation-aware training - applying STE to continuous weights | STE should only be applied at the discrete rounding step, not to subsequent continuous operations | Apply STE only at the round() or sign() node; propagate real gradients elsewhere |
| 10 | Misunderstanding gradient accumulation - forgetting to scale the loss | Accumulating micro-batch gradients without dividing by $G$ produces a gradient $G\times$ too large | Divide the loss by $G$ before backward, or divide accumulated gradients by $G$ before the optimiser step |
| 11 | Not using create_graph=True when computing higher-order gradients in PyTorch | Without create_graph=True, the gradient computation is not tracked, so differentiating through it returns None or wrong values | Use create_graph=True in the first torch.autograd.grad() call when second derivatives are needed |
| 12 | Confusing BPTT truncation with sequence truncation | Truncated BPTT still runs the full forward sequence; it only truncates the backward window. Sequence truncation shortens both | These are different operations - read the framework docs to confirm which is applied |
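As a companion to mistake 8, a minimal pure-Python sketch of global-norm clipping (a hypothetical `clip_global_norm` helper for illustration; real PyTorch code would call `torch.nn.utils.clip_grad_norm_`):

```python
import math

def clip_global_norm(grads, max_norm):
    # One norm over ALL parameters, one shared scale factor -
    # relative gradient scales across layers are preserved.
    total = math.sqrt(sum(x * x for g in grads for x in g))
    scale = min(1.0, max_norm / max(total, 1e-12))
    return [[x * scale for x in g] for g in grads]

grads = [[3.0, 4.0], [0.0, 12.0]]    # global norm = sqrt(9+16+144) = 13
print(clip_global_norm(grads, 1.0))  # every entry scaled by the same 1/13
```

Per-layer clipping would instead rescale `[3, 4]` and `[0, 12]` by different factors, changing the direction of the overall update - exactly the failure mode described in the table.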