Number Systems, Part 8: Numerical Precision in Neural Networks
8. Numerical Precision in Neural Networks
This section connects number system theory to the practical precision requirements of each component in a transformer-based LLM.
8.1 Forward Pass Precision Requirements
Each operation in the forward pass has different sensitivity to numerical precision:
Input embeddings (BF16 sufficient):
- Embedding lookup is a simple table read - no arithmetic precision concern for the lookup itself
- Small errors in embedding vectors don't compound catastrophically because they pass through layer normalisation early
- BF16 with its 7-bit mantissa provides relative precision of about 2^-8 ≈ 0.4% - adequate
Attention scores (BF16 matmul, FP32 softmax):
- The matmul is computed in BF16 with FP32 accumulation (7.6) - no problem
- The division by sqrt(d_k) is a simple scale; BF16 sufficient
- However: the subsequent softmax requires FP32 for the exponentiation (see 8.4)
Softmax (must use FP32 or numerical tricks):
- e^x overflows BF16 for x ≳ 89 (because BF16 max ≈ 3.4×10^38 ≈ e^88.7)
- Attention logits can reach 50-100 in late training - the typical range stays below the ~88 overflow threshold, but the top of that range, early layers, or unstable training can exceed it
- Solution: compute softmax intermediate values in FP32; or use the numerically stable version (8.4)
Layer normalisation (FP32 for mean/variance):
- Mean mu = (1/d) sum_i x_i and variance sigma^2 = (1/d) sum_i (x_i - mu)^2: sums over d values must be accurate
- With BF16 inputs: accumulate sum(x_i) and sum(x_i^2) in FP32 before applying normalisation
- The normalised output can be stored in BF16
FFN activations (BF16 sufficient):
- SwiGLU gate and up projections: standard matmul; BF16 with FP32 accumulation
- Swish activation function: smooth, bounded gradient; no overflow/underflow issues in BF16
- Down projection: same as any matmul; BF16 sufficient
8.2 Backward Pass Precision Requirements
The backward pass is more precision-sensitive than the forward pass because errors in gradients compound through the optimizer over millions of steps:
Gradient magnitudes:
- Typical gradient magnitudes: roughly 10^-7 to 10^-3
- These are well within BF16 range (min normal ≈ 1.2×10^-38) but near the FP16 minimum (min normal ≈ 6.1×10^-5; subnormals reach ≈ 6×10^-8)
- This is exactly why BF16 works and FP16 fails for training without loss scaling
Gradient accumulation (MUST use FP32):
The accumulated gradient g = sum_i g_i aggregates many small per-sample contributions. In BF16:
- Machine epsilon ≈ 2^-8 ≈ 0.0039
- Any gradient contribution smaller than ~0.4% of the running sum is lost
- Over a batch of 1024 samples, many individual contributions are small -> silently discarded
- Over millions of training steps, this precision loss causes weights to stop updating -> training stalls
In FP32:
- Machine epsilon ≈ 2^-24 ≈ 6×10^-8
- Contributions as small as ~6×10^-8 of the running sum are preserved
- 2^16 = 65,536\times more precise than BF16 - sufficient for stable training
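The effect is easy to demonstrate. Below is a small NumPy sketch (BF16 is emulated by rounding a float32's low 16 bits with round-to-nearest-even; `bf16_round` is a helper defined here, not a library function): 4,096 gradient contributions of 1e-3 are accumulated into a running sum that starts at 8.0.

```python
import numpy as np

def bf16_round(x):
    """Emulate BF16: round a float32 to the nearest bfloat16 (round-to-nearest-even)."""
    u = np.asarray(x, dtype=np.float32).copy().view(np.uint32)
    bias = ((u >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)
    return ((u + bias) & np.uint32(0xFFFF0000)).view(np.float32)

acc_bf16 = 8.0
acc_fp32 = np.float32(8.0)
g = np.float32(1e-3)                            # a small per-sample contribution
for _ in range(4096):
    acc_bf16 = float(bf16_round(acc_bf16 + g))  # 1e-3 < half a ULP of 8.0 in BF16 -> lost
    acc_fp32 = acc_fp32 + g

print(acc_bf16)   # stays exactly 8.0: every contribution was rounded away
print(acc_fp32)   # ~12.1: all 4,096 contributions preserved
```

In emulated BF16 the sum never moves off 8.0 - precisely the "weights stop updating" failure mode described above.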
Gradient clipping (FP32 norm, BF16 application):
- Global gradient norm ||g||_2: compute in FP32 (sum of squares)
- Apply clipping scale factor: can be done in BF16 (simple multiplication)
Loss computation (FP32):
- Cross-entropy loss involves log(sum_i e^{z_i}) - both log and exp are precision-sensitive
- Must use numerically stable log-sum-exp (8.5) in FP32
8.3 Mixed Precision Training - Complete Picture
The standard mixed-precision training recipe that has been used for virtually all large-scale LLM training since 2020:
MIXED PRECISION TRAINING PIPELINE
=======================================================================
INITIALISATION:
\theta_master = initialise_weights() # FP32 (4 bytes/param)
m = zeros_like(\theta_master) # FP32 Adam momentum
v = zeros_like(\theta_master) # FP32 Adam variance
FOR EACH TRAINING STEP:
+- FORWARD PASS ----------------------------------------------+
| \theta_bf16 = cast(\theta_master, BF16) # FP32 -> BF16 |
| activations = forward(x, \theta_bf16) # BF16 matmul |
| loss = cross_entropy(activations, y) # FP32 (stable) |
+-------------------------------------------------------------+
|
v
+- BACKWARD PASS ---------------------------------------------+
| grads_bf16 = backward(loss, activations) # BF16 gradients |
| grads_fp32 = accumulate(grads_bf16) # FP32 accumulate |
| clip_grad_norm(grads_fp32) # FP32 norm |
+-------------------------------------------------------------+
|
v
+- OPTIMIZER STEP --------------------------------------------+
| m = \beta_1*m + (1-\beta_1)*grads_fp32 # FP32 |
| v = \beta_2*v + (1-\beta_2)*grads_fp32^2 # FP32 |
| m = m / (1 - \beta_1^t) # FP32 |
| v = v / (1 - \beta_2^t) # FP32 |
| \theta_master -= \eta * m / (sqrt(v) + \epsilon) # FP32 update |
+-------------------------------------------------------------+
MEMORY PER PARAMETER:
\theta_master (FP32): 4 bytes
Adam m (FP32): 4 bytes
Adam v (FP32): 4 bytes
\theta_bf16 (working): 2 bytes
--------------------------
Total: 14 bytes per parameter
For 70B model: 70B \times 14 = 980 GB (explaining why large model
training requires many GPUs even with mixed precision)
8.4 Numerical Stability of Softmax
The softmax function is one of the most numerically sensitive operations in a transformer:
Naive softmax - the overflow problem:
softmax(z)_i = e^{z_i} / sum_j e^{z_j}
For z_i = 100: e^100 ≈ 2.7×10^43 exceeds BF16 max (≈ 3.4×10^38) -> overflow to Inf. FP32 has the same maximum, so e^100 overflows there too. Any attention logit above ~88 causes overflow.
Numerically stable softmax - the max-subtraction trick:
softmax(z)_i = e^{z_i - m} / sum_j e^{z_j - m}, where m = max_j z_j
Proof that this is mathematically equivalent:
e^{z_i - m} / sum_j e^{z_j - m} = (e^{z_i} e^{-m}) / (sum_j e^{z_j} e^{-m}) = e^{z_i} / sum_j e^{z_j}
The factor e^{-m} cancels in numerator and denominator. But numerically, the subtraction ensures that the largest exponent is exactly 0 (e^0 = 1) - no overflow possible. All other exponents are negative (e^{z_i - m} <= 1) - also safe.
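A minimal NumPy sketch of both versions (helper names are illustrative):

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)                       # overflows for large logits
    return e / e.sum()

def softmax_stable(z):
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())             # largest exponent is exactly 0 -> no overflow
    return e / e.sum()

z = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    bad = softmax_naive(z)              # inf / inf -> nan
good = softmax_stable(z)                # finite, sums to 1
```

The stable version returns the same distribution as softmax of [0, 1, 2], since shifting all logits by a constant leaves softmax unchanged.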
FlashAttention's online softmax:
- Standard softmax requires two passes: (1) compute m = max_i z_i; (2) compute e^{z_i - m} and sum
- FlashAttention maintains a running max and running sum, rescaling as new blocks of keys arrive
- This enables computing attention in a single pass through the KV cache, keeping all data in SRAM
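The running-max rescaling can be sketched in a few lines - a simplified single-vector illustration of the online idea, not FlashAttention itself:

```python
import numpy as np

def online_softmax_stats(blocks):
    """One pass over blocks of logits, keeping a running max m and running sum s."""
    m, s = -np.inf, 0.0
    for block in blocks:
        m_new = max(m, float(block.max()))
        s = s * np.exp(m - m_new) + np.exp(block - m_new).sum()  # rescale the old sum
        m = m_new
    return m, s                      # softmax(z)_i = exp(z_i - m) / s

z = np.array([3.0, 1.0, 9.0, 2.0, 8.0, 5.0])
m, s = online_softmax_stats([z[:2], z[2:4], z[4:]])

# identical to the two-pass computation:
m2 = z.max()
s2 = np.exp(z - m2).sum()
```

Each time a block raises the running max, the previously accumulated sum is multiplied by exp(old_max - new_max), exactly as FlashAttention rescales its partial results.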
8.5 Log-Sum-Exp Trick
The log-sum-exp (LSE) operation LSE(z) = log sum_i e^{z_i} appears in cross-entropy loss, softmax computation, and log-normaliser calculation:
The overflow problem: for large z_i, e^{z_i} overflows before the log can "undo" it.
The solution - factor out the maximum:
LSE(z) = m + log sum_i e^{z_i - m}, where m = max_i z_i
Proof:
log sum_i e^{z_i} = log (e^m sum_i e^{z_i - m}) = m + log sum_i e^{z_i - m}
The maximum is factored out. All remaining exponents are <= 0 -> no overflow. The subtracted m is re-added as a simple sum - no precision loss.
Why this matters for LLM training:
Cross-entropy loss for next-token prediction:
L = -z_y + LSE(z), where z are the logits and y is the target token
This is computed millions of times per training step (once per token in the batch). If LSE(z) overflows, the loss becomes Inf -> gradient is NaN -> training dies.
PyTorch: torch.logsumexp(z, dim) implements this correctly. Never compute torch.log(torch.sum(torch.exp(z))) directly.
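A direct NumPy translation of the trick, for illustration:

```python
import numpy as np

def logsumexp(z):
    z = np.asarray(z, dtype=np.float64)
    m = z.max()
    return m + np.log(np.exp(z - m).sum())   # all exponents <= 0 -> no overflow

z = np.array([1000.0, 1000.0])
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(z)))        # exp overflows -> inf
stable = logsumexp(z)                        # 1000 + log(2)
```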
8.6 Numerical Stability of Layer Normalisation
LayerNorm:
y = gamma * (x - mu) / sqrt(sigma^2 + eps) + beta
where mu = (1/d) sum_i x_i and sigma^2 = (1/d) sum_i (x_i - mu)^2.
Numerical concerns:
- Catastrophic cancellation in x_i - mu: when x_i ≈ mu (which is common - normalisation centres data near zero), the subtraction loses significant bits (3.6)
- Division by small sigma: if all values are nearly identical, sigma^2 -> 0; division amplifies noise
- The eps ensures we divide by at least sqrt(eps); this prevents division by zero and limits amplification
- In BF16: a typical eps ≈ 10^-5 is representable (BF16 min normal ≈ 1.2×10^-38), but the precision of sigma^2 + eps when sigma^2 ≈ eps is limited by BF16's 7-bit mantissa
- FP32 accumulation for mean and variance: even if the input is BF16, compute sum(x_i) and sum(x_i^2) in FP32 before applying the normalisation
RMSNorm - a numerically cleaner alternative:
y = gamma * x / sqrt((1/d) sum_i x_i^2 + eps)
- No mean subtraction -> no catastrophic cancellation from x_i - mu
- Only requires computing sum_i x_i^2 - a sum of positive values (no cancellation)
- Used in LLaMA, Mistral, and most 2024+ transformer architectures
- Slightly cheaper to compute (one fewer reduction) and more numerically stable
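A sketch of RMSNorm with the high-precision accumulation described above (here float64 stands in for the FP32 accumulator, float32 for the lower-precision storage format):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    x32 = np.asarray(x, dtype=np.float32)
    ms = np.mean(x32.astype(np.float64) ** 2)   # accumulate sum of squares in high precision
    y = x32 / np.float32(np.sqrt(ms + eps))     # no mean subtraction -> no cancellation
    return (y * gamma).astype(np.float32)

x = np.array([1.0, -2.0, 3.0, -4.0], dtype=np.float32)
y = rms_norm(x, gamma=np.float32(1.0))
# output has unit RMS (up to eps): sqrt(mean(y**2)) ~ 1
```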
8.7 Gradient Vanishing and Exploding - Numerical View
The chain rule through L layers:
dL/dh_1 = dL/dh_L * J_L * J_{L-1} * ... * J_2
where J_l is the Jacobian of layer l. If each Jacobian has a dominant eigenvalue lambda, the gradient magnitude scales as lambda^L:
Exploding gradients (lambda > 1):
- lambda = 1.01, L = 100 layers: 1.01^100 ≈ 2.7 - moderate growth
- lambda = 1.1, L = 100 layers: 1.1^100 ≈ 1.4×10^4 - rapid growth
- In FP32: once gradients exceed ≈ 3.4×10^38 -> Inf -> training crashes
- In BF16: same range as FP32, but precision errors amplified by large gradients -> unstable updates
Vanishing gradients (lambda < 1):
- lambda = 0.99, L = 100: 0.99^100 ≈ 0.37 - still meaningful
- lambda = 0.9, L = 100: 0.9^100 ≈ 2.7×10^-5 - tiny
- In FP16: gradients below ≈ 6×10^-8 (smallest subnormal) underflow to zero -> silent death (no error, no NaN, just zero gradients)
- In BF16: gradients below ≈ 10^-38 underflow - this almost never happens in practice
Residual connections - the numerical fix:
h_{l+1} = h_l + f(h_l)
The Jacobian is I + J_f, whose eigenvalues are 1 + lambda_i(J_f). Even if lambda_i(J_f) is small or negative, the eigenvalue stays near 1 - preventing both explosion and vanishing.
Gradient clipping - the safety net:
- Clip the global gradient norm: g <- g * min(1, c / ||g||_2)
- Prevents gradients from exceeding the FP32 range and becoming Inf
- Standard practice in LLM training: c in [0.5, 1.0] typical
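A minimal sketch of the recipe (the norm is accumulated in float64 as a stand-in for FP32; the scale is applied in the storage format):

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by min(1, max_norm / ||g||_2); norm computed in high precision."""
    sq = sum(float(np.sum(g.astype(np.float64) ** 2)) for g in grads)
    total_norm = np.sqrt(sq)
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * np.float32(scale) for g in grads], total_norm

grads = [np.array([3.0, 0.0], dtype=np.float32),
         np.array([0.0, 4.0], dtype=np.float32)]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)   # norm = 5.0 -> scale 0.2
```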
9. Quantization Mathematics
Quantization is the mathematical process of mapping values from a high-precision format (e.g., FP32, BF16) to a low-precision format (e.g., INT8, INT4, NF4). This section develops the mathematical foundations rigorously.
9.1 Uniform Quantization Formulas - Derivation and Intuition
The fundamental problem: map a continuous (or high-precision) value x to an integer q in {0, ..., 2^b - 1} (unsigned) or {-2^{b-1}, ..., 2^{b-1} - 1} (signed), where b is the bit width.
Step 1 - Define the scale factor s:
s = (x_max - x_min) / (2^b - 1)
This is the "width" of one quantization bin. For symmetric quantization around zero:
s = 2*alpha / (2^b - 1), where alpha = max|x|
Step 2 - Define the zero-point z (asymmetric only):
z = round(-x_min / s)
The zero-point ensures that the floating-point value 0.0 maps exactly to an integer. This is critical because:
- Padding zeros in convolutions must remain exactly zero after quantization
- Skip connections that add 0 must not introduce bias
For symmetric quantization: z = 0 (or z = 2^{b-1} for unsigned).
Step 3 - Quantize (FP -> INT):
q = clamp(round(x / s) + z, q_min, q_max)
Step 4 - Dequantize (INT -> FP approximation):
x_hat = s \times (q - z)
Quantization error:
For uniform quantization with round-to-nearest, the error is bounded:
|x - x_hat| <= s / 2 (for values inside the range; clipped values can exceed this)
Worked example - INT8 symmetric quantization of a weight tensor:
Given: weights = [-0.45, 0.12, -0.03, 0.67, -0.89, 0.34]
Bit width: b = 8 (signed: range [-128, 127])
Step 1: \alpha = max(|-0.89|, |0.67|) = 0.89
s = 2\alpha / (2^8 - 1) = 1.78 / 255 = 0.006980
Step 2: z = 0 (symmetric)
Step 3: Quantize each weight:
-0.45 -> round(-0.45 / 0.006980) = round(-64.47) = -64
0.12 -> round(0.12 / 0.006980) = round(17.19) = 17
-0.03 -> round(-0.03 / 0.006980) = round(-4.30) = -4
0.67 -> round(0.67 / 0.006980) = round(95.99) = 96
-0.89 -> round(-0.89 / 0.006980) = round(-127.51)= -128 <- clips to -128
0.34 -> round(0.34 / 0.006980) = round(48.71) = 49
Step 4: Dequantize:
-64 -> 0.006980 \times (-64) = -0.44672 (error: 0.00328)
17 -> 0.006980 \times 17 = 0.11866 (error: 0.00134)
-4 -> 0.006980 \times (-4) = -0.02792 (error: 0.00208)
96 -> 0.006980 \times 96 = 0.67008 (error: 0.00008)
-128 -> 0.006980 \times (-128)= -0.89344 (error: 0.00344) <- clipping error
49 -> 0.006980 \times 49 = 0.34202 (error: 0.00202)
Maximum error: 0.00344 \approx s/2 = 0.00349 OK
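The worked example can be reproduced in a few lines of NumPy (using the same rounded scale as the figures above):

```python
import numpy as np

w = np.array([-0.45, 0.12, -0.03, 0.67, -0.89, 0.34])
alpha = np.abs(w).max()                        # 0.89
s = round(float(2 * alpha / (2**8 - 1)), 6)    # 0.006980, the rounded scale used above
q = np.clip(np.round(w / s), -128, 127).astype(np.int8)   # quantize + clamp
w_hat = s * q.astype(np.float64)               # dequantize
err = np.abs(w - w_hat)

print(q.tolist())     # [-64, 17, -4, 96, -128, 49]
print(err.max())      # ~0.00344, just under s/2 ~ 0.00349
```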
9.2 Signal-to-Quantization-Noise Ratio (SQNR)
SQNR measures the quality of quantization - how much signal is preserved relative to the quantization noise introduced.
Definition:
SQNR = E[x^2] / E[(x - x_hat)^2]
In dB:
SQNR_dB = 10 * log10(SQNR)
For uniform quantization of a uniformly distributed signal:
The quantisation error is approximately uniformly distributed on [-s/2, s/2] with variance:
sigma_e^2 = s^2 / 12
For a signal uniformly distributed on [-alpha, alpha] quantized to b bits:
SQNR = 4^b, i.e. SQNR_dB ≈ 6.02 * b
The 6 dB per bit rule: each additional bit of precision adds approximately 6 dB of SQNR.
| Bit Width | SQNR (dB) | Noise Ratio | Typical Use |
|---|---|---|---|
| 2 | 12.0 | 6.25% | Binary/ternary weights |
| 4 | 24.1 | 0.39% | Inference (GPTQ, AWQ) |
| 8 | 48.2 | 0.0015% | PTQ inference standard |
| 16 | 96.3 | ~2.3×10^-8 % | Training weights |
| 32 | 192.7 | ~5.4×10^-18 % | Master weights, optimizer |
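The 6 dB per bit rule is easy to verify empirically - a sketch with a mid-rise uniform quantizer on a uniformly distributed signal:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 200_000)

def sqnr_db(x, bits):
    s = 2.0 / 2**bits                       # bin width for 2^bits levels on [-1, 1)
    x_hat = (np.floor(x / s) + 0.5) * s     # mid-rise quantizer: error in [-s/2, s/2)
    return 10 * np.log10(np.mean(x**2) / np.mean((x - x_hat)**2))

for b in (2, 4, 8):
    print(b, round(sqnr_db(x, b), 1))       # ~12.0, ~24.1, ~48.2 dB, about 6.02*b
```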
For non-uniform distributions (which neural network weights follow):
Neural network weights typically follow a bell-shaped distribution (roughly Gaussian with heavy tails after training). This means:
- Most values cluster near zero
- A few outlier values stretch the range
- Uniform quantization "wastes" most levels on the sparse tails
This is why calibration (choosing , to clip outliers) dramatically improves SQNR, and why non-uniform formats like NF4 outperform uniform INT4.
9.3 Block Floating-Point and Group Quantization
Instead of a single (s, z) pair for an entire tensor, group quantization assigns one pair per block of G elements:
BLOCK/GROUP QUANTIZATION
=========================
Full tensor T with N elements, block size G:
Block 0 Block 1 Block 2
+----------+ +----------+ +----------+
| w_0...w_{G-1}| | w_G...w_{2G-1}| | ... |
| s_0, z_0 | | s_1, z_1 | | s_2, z_2 |
+----------+ +----------+ +----------+
Each block has its own scale s_k and zero-point z_k:
s_k = (max(block_k) - min(block_k)) / (2^b - 1)
z_k = round(-min(block_k) / s_k)
Memory overhead:
Without grouping: N\timesb bits + 1\times(32+32) bits for (s,z)
With grouping: N\timesb bits + (N/G)\times(32+32) bits for (s_k,z_k)
Overhead ratio: 64/G bits per element \approx 64/G extra bits per parameter
Common group sizes and their overhead (for INT4 - 4 bits per weight):
| Group Size | Overhead per param | Effective bits/param | Used In |
|---|---|---|---|
| 32 | +0.5 bit (FP16 scale) | 4.5 | QLoRA (NF4) |
| 64 | +0.25 bit | 4.25 | Common middle ground |
| 128 | +0.125 bit | 4.125 | GPTQ default |
| Per-channel | +negligible | ~4.0 | Weight-only quant |
| Per-tensor | +negligible | ~4.0 | Least accurate |
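A minimal sketch of per-group asymmetric quantization following the formulas above (function names are illustrative):

```python
import numpy as np

def quantize_grouped(w, bits=4, group=64):
    """One (scale, zero-point) pair per group of `group` consecutive weights."""
    w = np.asarray(w, dtype=np.float32).reshape(-1, group)
    lo = w.min(axis=1, keepdims=True)
    hi = w.max(axis=1, keepdims=True)
    s = (hi - lo) / (2**bits - 1)
    s = np.where(s > 0, s, 1.0)                    # guard all-constant groups
    z = np.round(-lo / s)
    q = np.clip(np.round(w / s) + z, 0, 2**bits - 1)
    return q.astype(np.uint8), s, z

def dequantize_grouped(q, s, z):
    return (q.astype(np.float32) - z) * s

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 64)).astype(np.float32)
q, s, z = quantize_grouped(w.ravel(), bits=4, group=64)
w_hat = dequantize_grouped(q, s, z).reshape(w.shape)
# per-group reconstruction error is bounded by roughly one bin width s
```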
Double quantization (QLoRA innovation):
The scales themselves are FP16 (2 bytes each). With group size G = 64, this adds 16/64 bits = 0.25 bits per parameter. QLoRA further quantizes these FP16 scales to FP8, reducing the overhead to 0.125 bits per parameter (8 bits / 64 parameters).
9.4 Optimal Quantization Levels - Lloyd-Max Algorithm
Uniform quantization is suboptimal when the input distribution is non-uniform (which it always is for neural network weights). Lloyd-Max quantization finds the reproduction levels {r_i} and decision boundaries {b_i} that minimise the expected squared error E[(x - Q(x))^2] for a given probability density p(x).
The optimality conditions:
- Nearest-neighbour condition - each value maps to its closest reproduction level, so the boundaries are midpoints:
b_i = (r_i + r_{i+1}) / 2
- Centroid condition - each reproduction level is the conditional mean of the values in its bin:
r_i = E[x | b_{i-1} <= x < b_i]
Iterative algorithm:
- Initialise {r_i} (e.g., uniformly spaced or k-means++ style)
- Compute {b_i} as midpoints of adjacent {r_i}
- Recompute {r_i} as centroids of the bins defined by {b_i}
- Repeat until convergence (MSE decreases monotonically)
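Run on empirical samples (Lloyd's algorithm), the iteration recovers the known optimal 4-level (2-bit) reproduction levels for a standard normal, approximately ±0.45 and ±1.51; a sketch:

```python
import numpy as np

def lloyd_max(samples, n_levels, n_iters=100):
    """Alternate the two optimality conditions on empirical samples."""
    x = np.asarray(samples, dtype=np.float64)
    # initialise levels at evenly spaced quantiles
    r = np.quantile(x, (np.arange(n_levels) + 0.5) / n_levels)
    for _ in range(n_iters):
        b = (r[:-1] + r[1:]) / 2           # boundaries: midpoints (nearest-neighbour)
        idx = np.searchsorted(b, x)        # assign each sample to its bin
        for k in range(n_levels):
            sel = idx == k
            if sel.any():
                r[k] = x[sel].mean()       # levels: centroids of their bins
    return r

rng = np.random.default_rng(0)
r = lloyd_max(rng.standard_normal(200_000), n_levels=4)
# r ~ [-1.51, -0.45, 0.45, 1.51]: the Lloyd-Max levels for a standard normal
```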
Connection to NF4: the 16 NF4 levels (6.1) are the Lloyd-Max optimal quantization levels for a standard normal distribution with b = 4 bits. The weight tensor is first normalised to zero mean and unit variance, making the normal assumption a good fit.
9.5 Hadamard Transform for Quantization - QuIP and QuaRot
The outlier problem: Transformer activations and certain weight columns contain outlier channels - values 10-100\times larger than the rest. These stretch the quantization range, wasting precision for all other values.
The Hadamard solution: The Hadamard matrix H_n is an n \times n matrix with entries ±1 satisfying H_n H_n^T = n I:
Key property: H_n / sqrt(n) is orthogonal - applying it and its transpose in sequence is the identity.
How it helps quantization:
For a weight matrix W and input x, let H = H_n / sqrt(n), and define W~ = W H^T and x~ = H x. Then:
W~ x~ = W H^T H x = W x
- W~ and x~ have much more uniform magnitudes (the Hadamard rotation "spreads" outliers across all dimensions)
- Quantizing W~ and x~ yields much lower error than quantizing W and x directly
This is free at inference (for weights):
- Pre-compute W~ = W H^T once during quantization
- Apply H to activations at runtime (cost: O(d log d) via the Fast Walsh-Hadamard Transform, much cheaper than the O(d^2) matmul it enables)
QuIP (Chee et al., 2023): uses random orthogonal matrices (including Hadamard) to "incohere" weight matrices before quantization, achieving near-optimal 2-bit quantization.
QuaRot (Ashkboos et al., 2024): applies Hadamard rotations to both weights and activations, enabling 4-bit quantization of all linear layers including KV cache with minimal accuracy loss.
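The transform itself is a dozen lines. The sketch below applies an orthonormal Fast Walsh-Hadamard Transform to a vector with a single outlier, spreading its energy evenly across all dimensions:

```python
import numpy as np

def fwht(x):
    """Orthonormal Fast Walsh-Hadamard Transform, O(d log d); len(x) must be a power of 2."""
    x = np.asarray(x, dtype=np.float64).copy()
    h = 1
    while h < len(x):
        for i in range(0, len(x), 2 * h):
            a = x[i:i + h].copy()
            b = x[i + h:i + 2 * h].copy()
            x[i:i + h] = a + b
            x[i + h:i + 2 * h] = a - b
        h *= 2
    return x / np.sqrt(len(x))          # H/sqrt(d): orthogonal and self-inverse

v = np.zeros(8)
v[3] = 8.0                              # one outlier channel
t = fwht(v)
# all eight entries of t have magnitude 8/sqrt(8) ~ 2.83: a far easier range to quantize,
# and the rotation is exactly invertible: fwht(t) recovers v
```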
9.6 Quantization Error Propagation Through Layers
A critical question: if each layer introduces a quantization perturbation Delta W_l, how does the total error grow through L layers?
Single layer error model:
y_hat = (W + Delta W) x = W x + Delta W x
where Delta W is the per-layer weight perturbation introduced by quantization.
Through L layers (without activations, for intuition):
y_hat = (W_L + Delta W_L) ... (W_1 + Delta W_1) x
Expanding the product (first-order approximation, dropping cross-terms of order Delta W^2):
y_hat ≈ W_L ... W_1 x + sum_l W_L ... W_{l+1} Delta W_l W_{l-1} ... W_1 x
The total error is a sum of L terms, each amplified by the products of the weight matrices of the layers above and below it.
Practical implications:
| Layer Position | Amplification | Recommendation |
|---|---|---|
| First layer (near input) | Error amplified by all subsequent layers | Use higher precision or skip quantization |
| Middle layers | Moderate amplification | Standard quantization (INT4/INT8) is fine |
| Last layer (near output) | Error directly affects logits | Use higher precision or skip quantization |
| Attention layers | Errors in QK^T amplified by softmax | More sensitive; use INT8 minimum |
Empirical observation from GPTQ, AWQ, and similar methods:
- Quantizing all layers to INT4: moderate perplexity increase (~0.5-1.0 on WikiText)
- Keeping first and last layers in FP16: recovers most of the loss (~0.1-0.3 increase)
- This is consistent with the error amplification analysis above
10. Number Systems for Specific AI Operations
Different operations within a transformer have radically different precision requirements. This section provides a per-operation analysis.
10.1 Embedding Tables
The operation: look up a learned vector for each token in the vocabulary.
Precision analysis:
- Embedding lookup is a table read - no arithmetic; precision only matters for storage
- Embedding vectors are typically small in magnitude after weight decay
- The embedding is immediately fed into layer normalisation, which re-centres and re-scales
Storage formats:
| Format | Memory for V=128K, d=4096 | Quality Impact |
|---|---|---|
| FP32 | 2.0 GB | Baseline |
| BF16 | 1.0 GB | No measurable loss |
| INT8 | 0.5 GB + scales | < 0.1% perplexity increase |
| INT4 | 0.25 GB + scales | Slight degradation; acceptable |
Autoregressive inference: embeddings are accessed one token at a time - memory-bandwidth-bound. Smaller formats directly reduce first-token latency.
Training: always FP32/BF16 (embeddings need to receive precise gradient updates since the vocabulary is large and each token's embedding updates are sparse).
10.2 Attention Score Computation
The operation: Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
Precision requirements by component:
Q K^T matmul:
- BF16 \times BF16 with FP32 accumulation is standard and sufficient
- FP8 with FP32 accumulation is emerging (H100 tensor cores support this natively)
- INT8 \times INT8 with INT32 accumulation works for inference with calibrated quantization
1/sqrt(d_k) scaling:
- A single multiplication by 1/sqrt(d_k) - precision irrelevant
- Often fused into the Q or K projection (W_Q <- W_Q / sqrt(d_k)) to avoid a separate operation
Softmax (see 8.4):
- Must use FP32 intermediate values (or at minimum the numerically stable BF16 version)
- The exponentiation and normalisation are the precision bottleneck
Attention \times V matmul:
- Same precision profile as Q K^T: BF16 with FP32 accumulation
- The attention weights after softmax are in [0, 1] - well-suited for any format
Multi-query / grouped-query attention (MQA/GQA):
- Fewer K, V heads (typically 1 or 8 groups vs 32-128 query heads)
- Doesn't change precision requirements, but dramatically reduces KV cache memory (see 10.3)
10.3 KV Cache Precision
The KV cache stores past key and value vectors for autoregressive generation. For long-context models, it dominates inference memory:
Memory calculation: KV cache size = 2 (K and V) \times n_layers \times n_kv_heads \times d_head \times seq_len \times bytes per element
Example - LLaMA 3 70B with 128K context:
- n_layers = 80, n_kv_heads = 8 (GQA), d_head = 128
- In BF16: ≈ 40 GB per request
- In INT8: ≈ 20 GB per request
- In INT4: ≈ 10 GB per request
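The arithmetic as a quick check (dimensions as given above):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem):
    # factor 2: one set of vectors for keys, one for values
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

cfg = dict(n_layers=80, n_kv_heads=8, d_head=128, seq_len=128 * 1024)
for name, nbytes in (("BF16", 2), ("INT8", 1)):
    gib = kv_cache_bytes(**cfg, bytes_per_elem=nbytes) / 2**30
    print(name, gib)   # BF16 40.0, INT8 20.0 GiB per 128K-token request
```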
KV cache quantization is uniquely challenging:
- Keys and values are computed on-the-fly during generation; you can't calibrate offline
- Outlier channels in keys cause accuracy degradation under naive quantization
- Per-channel quantization (different scale per head dimension) works well because outlier channels are consistent across tokens
Effective approaches:
- Per-channel INT8: scale per head dimension; negligible accuracy loss for most models
- Per-token INT4 with group size 128: achievable with careful calibration; small perplexity increase
- KIVI (Liu et al., 2024): Key cache in INT2 (per-channel), Value cache in INT2 (per-token), with residual FP16 for recent tokens - 8\times compression with minimal loss
- SqueezeLLM / GEAR: mixed-precision KV cache with outlier tokens kept in higher precision
10.4 Feed-Forward Network (FFN) Precision
The FFN in modern transformers (SwiGLU variant): FFN(x) = W_down (Swish(W_gate x) ⊙ (W_up x))
Precision analysis:
| Component | Format | Rationale |
|---|---|---|
| Gate projection (W_gate) | BF16/INT8 | Standard matmul |
| Up projection (W_up) | BF16/INT8 | Standard matmul |
| Swish activation | BF16 | Smooth function; no overflow risk |
| Element-wise multiply | BF16 | Simple multiplication |
| Down projection (W_down) | BF16/INT8 | Standard matmul, but more sensitive |
The SwiGLU precision advantage over ReLU:
- ReLU creates exact zeros, which slightly help quantization (zero is perfectly representable)
- SwiGLU produces smooth, non-zero activations - slightly harder to quantize but better model quality
- In practice, the difference in quantization difficulty is negligible
FFN weight quantization sensitivity:
- W_gate and W_up: moderately sensitive (they determine what information passes through)
- W_down: most sensitive (the final projection directly affects the residual stream and all subsequent layers)
- Practical strategy: quantize W_gate and W_up to INT4; keep W_down in INT8 or higher
10.5 Optimizer State Precision
The Adam optimizer maintains two state variables per parameter:
m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t (momentum)
v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2 (variance)
Why optimizer states are precision-critical:
Exponential moving averages require cumulative precision:
- With beta_2 = 0.999, the effective averaging window is 1/(1 - beta_2) = 1000 steps
- Each step contributes only (1 - beta_2) = 0.1% of its gradient to the running average
- In BF16: contributions smaller than ~0.4% of the running sum are lost
- Since 0.1% < 0.4%, individual gradient contributions are frequently lost -> optimizer blindness
- In FP32: contributions as small as ~6×10^-8 of the running sum are preserved -> 65,000\times more room
v stores squared gradients:
- Typical gradient: ~10^-3 -> squared: ~10^-6
- These tiny values need precise accumulation over thousands of steps
- BF16 would round small squared gradients to zero -> divide-by-zero or extremely large updates in the sqrt(v) denominator
Memory-efficient optimizer approaches:
| Approach | Memory/param | Quality | Used In |
|---|---|---|---|
| Adam FP32 | 12 bytes (\theta+m+v) | Baseline | Standard training |
| Adam BF16 states | 6 bytes | Fails - training diverges | Don't use |
| 8-bit Adam (Dettmers) | 4 bytes | Minimal loss | bitsandbytes |
| Adafactor | 4-8 bytes | Slight loss for some tasks | T5, PaLM |
| LOMO | ~0 bytes | Limited quality | Fine-tuning only |
| GaLore | ~4 bytes | Near-baseline | Memory-constrained training |
8-bit Adam (bitsandbytes) - how it works:
- Stores and in INT8 with dynamic exponent (block-wise FP8 effectively)
- Maintains a per-block scale factor in FP32
- Dequantizes to FP32 for the update step, then re-quantizes
- Round-trip error is small enough not to affect training for models up to 175B parameters
10.6 Gradient Communication Precision
In distributed training, gradients are communicated between GPUs via all-reduce. For a 70B model, each gradient sync transmits 140 GB in BF16 (70B parameters \times 2 bytes).
Gradient compression techniques:
| Method | Format | Compression | Quality |
|---|---|---|---|
| Full BF16 | BF16 | 1\times | Baseline |
| FP8 all-reduce | FP8 | 2\times | < 0.1% loss |
| INT8 all-reduce | INT8 | 2\times | < 0.1% loss |
| 1-bit Adam/LAMB | 1-bit + error feedback | 16-32\times | Slight loss, converges |
| TopK + INT8 | Sparse INT8 | 10-100\times | Depends on sparsity |
Error feedback mechanism (for aggressive compression):
- Compress the gradient: g_hat = C(g)
- Compute the compression error: e = g - g_hat
- Accumulate the error in an FP32 buffer
- Next step, add the accumulated error before compressing: g' = g + e, then compress g'
- The accumulated error ensures that over many steps, the true gradient is communicated on average
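A toy simulation of the loop, with a per-tensor INT8 round-trip standing in for whatever compressor C is actually used. A telescoping identity holds: the sum of communicated gradients equals the sum of true gradients minus the final residual.

```python
import numpy as np

def compress_int8(g):
    """Toy compressor C: symmetric per-tensor INT8 round-trip."""
    s = float(np.abs(g).max()) / 127.0
    if s == 0.0:
        return np.zeros_like(g)
    return np.clip(np.round(g / s), -127, 127) * s

rng = np.random.default_rng(0)
grad_stream = rng.normal(size=(200, 32))   # per-step "true" gradients
err = np.zeros(32)                         # FP32 error-feedback buffer
sent_total = np.zeros(32)
for g in grad_stream:
    corrected = g + err                    # add the carried-over residual
    c = compress_int8(corrected)           # what is actually communicated
    err = corrected - c                    # residual carried to the next step
    sent_total += c

true_total = grad_stream.sum(axis=0)
# sent_total + err == true_total: nothing is permanently lost, only delayed
```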
DeepSpeed ZeRO and precision:
- ZeRO-1/2/3 partitions optimizer states/gradients/parameters across GPUs
- Reduces per-GPU memory but increases communication volume
- Communication precision becomes even more important at scale
- ZeRO++ uses INT8 for all-to-all communication with negligible quality loss
11. Hardware Implementation of Number Systems
Understanding hardware constraints explains why certain number formats exist and dictates the achievable performance for each format.
11.1 Integer ALUs - Simple and Fast
An integer arithmetic logic unit (ALU) for b-bit operations:
Gate count and area:
- b-bit adder: O(b) gates (carry-lookahead) -> completes in O(log b) gate delays
- b-bit multiplier: O(b^2) gates (Wallace tree) -> completes in O(log b) gate delays
- An INT8 multiplier uses ~8^2 = 64 unit cells; an INT32 multiplier uses ~32^2 = 1024 - 16\times more silicon
Relative cost (approximate, normalised to INT8 multiply):
| Operation | Relative Area | Relative Energy | Relative Latency |
|---|---|---|---|
| INT8 multiply | 1\times | 1\times | 1\times |
| INT16 multiply | 4\times | 4\times | 1.3\times |
| INT32 multiply | 16\times | 16\times | 1.6\times |
| INT4 multiply | 0.25\times | 0.25\times | 0.8\times |
| INT2 multiply | 0.0625\times | 0.0625\times | 0.6\times |
| INT1 XNOR | 0.01\times | 0.01\times | 0.2\times |
This is why 1-bit networks (BitNet) are so attractive for edge deployment: the multiply-accumulate (MAC) unit - the most replicated component - becomes nearly free.
11.2 Floating-Point Units - Complex and Power-Hungry
An FP multiply involves:
- XOR the sign bits (1 gate)
- Add the exponents (integer adder)
- Multiply the mantissas (integer multiplier of mantissa width)
- Normalise and round (shift + round logic)
The mantissa multiplier dominates cost, and multiplier area grows roughly with the square of the mantissa width:
| Format | Mantissa bits | Multiplier gates | Relative to FP32 |
|---|---|---|---|
| FP32 | 23+1=24 | ~576 | 1\times |
| BF16 | 7+1=8 | ~64 | 0.11\times (9\times cheaper) |
| FP16 | 10+1=11 | ~121 | 0.21\times (5\times cheaper) |
| TF32 | 10+1=11 | ~121 | 0.21\times |
| FP8 E4M3 | 3+1=4 | ~16 | 0.028\times (36\times cheaper) |
| FP8 E5M2 | 2+1=3 | ~9 | 0.016\times (64\times cheaper) |
Key insight: the 9\times multiplier-area savings over FP32 is why BF16 replaced FP32 as the default training format. The jump from BF16 to FP8 saves another 4\times.
11.3 Tensor Cores - Systolic Arrays for AI
Tensor cores (NVIDIA terminology; Google calls them "matrix multiply units" in TPUs) are specialised hardware that compute small matrix multiplies in a single clock cycle.
Architecture of a tensor core:
TENSOR CORE (NVIDIA H100 - one core)
====================================
Computes: D = A \times B + C
A: 16\times16 matrix (FP16/BF16/FP8/INT8) input
B: 16\times16 matrix (FP16/BF16/FP8/INT8) input
C: 16\times16 matrix (FP32/FP16) accumulator (read)
D: 16\times16 matrix (FP32/FP16) accumulator (write)
+----------------------------------------+
| 16\times16 array of fused multiply-add |
| units, pipelined as a systolic array |
| |
| For FP16: each unit has an 11-bit |
| multiplier + FP32 adder |
| |
| For INT8: each unit has an 8-bit |
| multiplier + INT32 adder |
| |
| For FP8: each unit has a 4-bit |
| multiplier + FP32 adder |
+----------------------------------------+
Number of tensor cores per GPU generation:
| GPU | Tensor Cores | Formats Supported | Peak TOPS (INT8) |
|---|---|---|---|
| V100 (2017) | 640 | FP16 | 130 |
| A100 (2020) | 432 | FP16, BF16, TF32, INT8, INT4, FP64 | 624 |
| H100 (2022) | 528 | FP16, BF16, TF32, FP8, INT8 | 1,979 |
| B200 (2024) | 640+ | FP16, BF16, TF32, FP8, FP4, INT8 | 4,500+ |
Each new generation adds support for lower-precision formats, enabling higher throughput without proportionally increasing power or die area.
11.4 Memory Bandwidth - The True Bottleneck
For large language model inference, the bottleneck is almost never compute - it's memory bandwidth. Loading model weights from HBM (High-Bandwidth Memory) to the compute units limits throughput.
The arithmetic intensity argument:
For a single token in autoregressive generation:
- Compute: one matrix-vector multiply per layer: 2d^2 FLOPs (for a d \times d weight matrix)
- Memory: load the weight matrix: d^2 \times bytes per element
- Arithmetic intensity = 2d^2 / (d^2 \times bytes) = 2 / bytes-per-element FLOP/byte
For different formats:
| Format | Bytes | Arithmetic Intensity | Bottleneck |
|---|---|---|---|
| FP32 | 4 | 0.5 FLOP/byte | Memory (severely) |
| BF16 | 2 | 1.0 FLOP/byte | Memory |
| INT8 | 1 | 2.0 FLOP/byte | Memory (still) |
| INT4 | 0.5 | 4.0 FLOP/byte | Memory or compute (depends on GPU) |
H100 SXM5 specs:
- HBM3 bandwidth: 3.35 TB/s
- FP8 tensor core compute: 1,979 TFLOPS
To be compute-bound: need arithmetic intensity >= 1,979 TFLOPS / 3.35 TB/s ≈ 590 FLOP/byte. Even INT4 (4 FLOP/byte) falls two orders of magnitude short of this threshold, so single-token (batch size 1) inference is firmly memory-bound in every weight format; only large batch sizes, which reuse each loaded weight across many tokens, push arithmetic intensity toward the compute-bound regime.
Practical implications:
- Reducing model size from BF16 -> INT4 doesn't just halve memory - it halves the time to load weights, almost doubling inference speed
- This is why quantization provides near-linear speedups for inference - the improvement comes from reduced memory traffic, not faster arithmetic
11.5 Energy Cost per Operation
Energy consumption is a critical constraint for data centre deployment and edge AI:
Energy per arithmetic operation (approximate, 7nm process):
| Operation | Energy (pJ) | Relative |
|---|---|---|
| INT8 MAC | 0.2 | 1\times |
| INT16 MAC | 0.9 | 4.5\times |
| FP16 FMA | 1.0 | 5\times |
| BF16 FMA | 0.8 | 4\times |
| FP32 FMA | 3.7 | 18.5\times |
| FP8 FMA | 0.4 | 2\times |
| INT4 MAC | 0.05 | 0.25\times |
| 1-bit XNOR+popcount | 0.02 | 0.1\times |
| HBM3 DRAM read (64 bytes) | 12.5 | 62.5\times |
Critical observation: DRAM access costs 60\times more energy than an INT8 MAC. For large models:
- Most energy is spent moving data, not computing
- Smaller number formats reduce both memory traffic and compute energy - a double benefit
- This energy analysis is the fundamental driver behind the industry's push toward lower precision
Energy for one forward pass of a 70B model (estimated):
| Format | Compute Energy | Memory Energy | Total | Relative |
|---|---|---|---|---|
| FP32 | ~80 J | ~150 J | ~230 J | 1\times |
| BF16 | ~17 J | ~75 J | ~92 J | 0.40\times |
| INT8 | ~4 J | ~38 J | ~42 J | 0.18\times |
| INT4 | ~1 J | ~19 J | ~20 J | 0.087\times |
INT4 inference uses approximately 11\times less energy than FP32 - enabling deployment at 1/11th the power budget.
11.6 Format Conversion Hardware
GPUs include dedicated hardware for converting between number formats:
Conversions that are "free" (no precision loss, handled by wiring):
- FP32 -> FP64: zero-extend mantissa, adjust exponent bias
- BF16 -> FP32: zero-extend 16 bits of mantissa, copy exponent and sign
- INT8 -> INT32: sign-extend 24 bits
Conversions that require rounding (1-cycle latency on modern GPUs):
- FP32 -> BF16: truncate 16 mantissa bits, apply rounding (3.5)
- FP32 -> FP16: truncate + check for overflow (may produce Inf/NaN)
- FP32 -> FP8: significant truncation + overflow check + rounding
- FP32 -> INT8: multiply by scale, round, clamp
Key conversion in mixed-precision training:
FP32 master weight -> BF16 working copy:
Simply drop the lower 16 mantissa bits and round
This truncation introduces up to 0.39% error per weight
But the FP32 master copy is never lost - it receives the
precise gradient update and only the BF16 working copy
is re-derived fresh each forward pass
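The cast can be written out at the bit level; a sketch of round-to-nearest-even on the low 16 bits of a float32:

```python
import numpy as np

def fp32_to_bf16(x):
    """Round float32 to bfloat16 (kept in a float32 container) with round-to-nearest-even."""
    u = np.asarray(x, dtype=np.float32).copy().view(np.uint32)
    bias = ((u >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)  # RNE bias
    return ((u + bias) & np.uint32(0xFFFF0000)).view(np.float32)

# 1 + 2**-9 sits a quarter of the way between the BF16 neighbours 1.0 and 1.0078125
lo = fp32_to_bf16(np.float32(1 + 2**-9))      # rounds down to 1.0
hi = fp32_to_bf16(np.float32(1 + 3 * 2**-9))  # rounds up to 1.0078125
```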
12. Precision and the Training Stability Connection
Training instability in large language models is intimately connected to number format limitations. This section explores the mechanisms by which precision failures manifest as training failures.
12.1 Loss Landscape and Precision
The loss landscape of a large neural network is a high-dimensional surface L(theta), theta in R^N (N ≈ 7×10^10 for a 70B model).
Curvature and precision requirements:
The second-order Taylor expansion around the current point theta gives:
L(theta + delta) ≈ L(theta) + grad L(theta)^T delta + (1/2) delta^T H delta
where H is the Hessian matrix. The curvature eigenvalues lambda_i of H determine precision requirements:
- High curvature direction (lambda_i large): the loss changes rapidly; even small perturbations from quantization cause large loss changes -> needs high precision
- Low curvature direction (lambda_i ≈ 0): the loss is flat; quantization noise barely affects loss -> can tolerate low precision
Implications for mixed precision:
- Most directions in parameter space are low-curvature (the loss landscape is approximately flat in most dimensions)
- Only a small fraction of directions are "sharp" - these correspond to critical features
- This is why BF16 training works: the quantization noise ($\sim 0.4\%$ per weight) is smaller than the gradient step size in most directions, and the few sharp directions are protected by FP32 master weights
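The curvature argument can be checked numerically with a toy quadratic. $L(t) = \tfrac{1}{2}\lambda t^2$ models one eigendirection of the Hessian with eigenvalue $\lambda$; the numbers below are illustrative, not measurements from a real model:

```python
# Toy numerical check: the same quantization perturbation moves the loss
# dramatically in a sharp direction and negligibly in a flat one.

def loss_change(lam, delta):
    """Increase in a quadratic loss after perturbing by `delta` from the minimum."""
    return 0.5 * lam * delta ** 2

delta = 0.004                       # ~0.4% quantization noise on a unit-scale weight
sharp = loss_change(1e4, delta)     # sharp direction: large eigenvalue
flat = loss_change(1e-4, delta)     # flat direction: tiny eigenvalue

assert sharp > 1e7 * flat           # loss sensitivity scales with curvature
```

Because loss sensitivity scales linearly with $\lambda_i$, a format whose noise is tolerable in flat directions can still be catastrophic along the few sharp ones - hence FP32 masters.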
12.2 Precision Cliffs - When Training Suddenly Diverges
A precision cliff occurs when training appears stable for many steps, then suddenly produces NaN or Inf losses. The mechanism:
Phase 1 - Slow drift (invisible):
- Gradient accumulation in BF16 silently drops small gradient contributions (8.2)
- Certain weight matrices slowly drift from their optimal values
- Loss appears stable because the drift is small compared to normal training noise
Phase 2 - Amplification:
- Drifted weights cause slightly larger activations in deep layers
- Attention logits grow (12.5), pushing softmax inputs toward overflow
- Gradients start hitting gradient clipping more frequently
- The model is now operating at the edge of numerical stability
Phase 3 - Collapse:
- A single unlucky batch pushes attention logits past the exp-overflow threshold of BF16/FP32 (~88.7) -> Inf
- Inf propagates through softmax -> NaN
- NaN gradients update all weights -> all weights become NaN -> training dies
This typically happens after 50K-200K training steps, which is why short training runs may appear fine in low precision while long runs fail.
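The phase-3 collapse and its standard fix can be demonstrated directly. Python floats are FP64 (exp overflows near 709.8) rather than FP32 (near 88.7), but the failure mechanism and the max-subtraction cure are identical:

```python
import math

# Sketch of the softmax overflow -> crash mechanism, and the stable variant.

def naive_softmax(logits):
    exps = [math.exp(z) for z in logits]       # overflows for large logits
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]   # all arguments are <= 0 now
    total = sum(exps)
    return [e / total for e in exps]

hot_logits = [1000.0, 999.0, 0.0]              # one unlucky batch

try:
    naive_softmax(hot_logits)
    raise AssertionError("expected an overflow")
except OverflowError:
    pass                                       # exp(1000) blows up -> crash

probs = stable_softmax(hot_logits)
assert abs(sum(probs) - 1.0) < 1e-12           # still a valid distribution
assert probs[0] > probs[1] > probs[2]          # ordering preserved
```

In a real training run the Inf does not raise an exception; it silently propagates through the division into NaN, which is why the crash surfaces only in the loss.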
Diagnostic signals:
- Gradient norm spikes (>10× baseline) increasing in frequency
- Learning rate warmup completing without issue, then instability at decay phase
- Loss spikes that don't recover to baseline
12.3 Stochastic Rounding - A Precision Amplifier
When rounding a value $x$ to a low-precision format, let $x_{\downarrow}$ and $x_{\uparrow}$ be the nearest representable values below and above $x$:
- Round-to-nearest-even (RNE): always rounds $x$ to the same value -> systematic bias when many values round in the same direction
- Stochastic rounding (SR): rounds up with probability $p = (x - x_{\downarrow}) / (x_{\uparrow} - x_{\downarrow})$, otherwise rounds down
Mathematical property of stochastic rounding:
Stochastic rounding is unbiased - the expected value of the rounded result equals the true value:
$$\mathbb{E}[\mathrm{SR}(x)] = (1 - p)\,x_{\downarrow} + p\,x_{\uparrow} = x$$
Over many steps, the rounding errors cancel out on average. This means that even if individual gradient contributions are too small to represent in BF16, stochastic rounding ensures they contribute to the weight update over time.
Convergence guarantee: With RNE, gradients smaller than half a ULP of the weight are always rounded away - the weight provably never updates if all gradients are this small. With stochastic rounding, even infinitesimally small gradients have a non-zero probability of causing an update.
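Both rounding modes can be sketched in software on a uniform grid with step `ulp` (0.0078125 mimics the BF16 ULP near 1.0; real BF16 spacing varies with the exponent). `sr_round` is illustrative, not a hardware API:

```python
import random

# Software sketch of RNE vs stochastic rounding on a uniform grid.

def rne_round(x, ulp):
    """Deterministic round-to-nearest: values below ulp/2 vanish."""
    return round(x / ulp) * ulp

def sr_round(x, ulp, rng=random):
    """Round up with probability equal to x's fractional position between
    grid points, so that E[sr_round(x)] == x (unbiased)."""
    lo = (x // ulp) * ulp
    p_up = (x - lo) / ulp
    return lo + ulp if rng.random() < p_up else lo

ulp = 0.0078125
g = 0.001                                  # a gradient below half the ULP
assert rne_round(g, ulp) == 0.0            # RNE drops it forever

rng = random.Random(0)
steps = 100_000
total = sum(sr_round(g, ulp, rng) for _ in range(steps))

# Unbiased: the mean rounded contribution approaches the true gradient.
assert abs(total / steps - g) < 1e-4
```

For simplicity the rounded contributions are accumulated in an exact FP64 sum; real hardware also rounds the running sum, but the unbiasedness argument is the same.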
Hardware support:
- NVIDIA H100 supports stochastic rounding for FP8 operations
- This is one reason why FP8 training is feasible despite the very low precision - stochastic rounding compensates for the coarse representation
12.4 Adam Optimizer Numerical Error Analysis
The Adam update rule:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2$$
$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$
Numerical failure mode 1 - $\epsilon$ too small:
- Standard $\epsilon = 10^{-8}$
- If $\hat v_t \approx 0$ (parameter with near-zero gradients): update $\approx \eta\,\hat m_t / \epsilon$
- For $\eta = 10^{-3}$ and $\epsilon = 10^{-8}$: update $\approx 10^{5}\,\hat m_t$ - a massive weight update
- This is why $\epsilon$ should be $10^{-6}$ or even $10^{-5}$ for BF16 training
Numerical failure mode 2 - overflow in low precision:
- $v_t$ accumulates $g_t^2$; if gradients are large ($g_t \sim 10$), then $v_t \sim 10^2$ - fine
- But if gradient clipping isn't applied and $g_t \sim 10^{3}$ (loss spike): $g_t^2 = 10^{6}$ -> $v_t$ grows rapidly
- In FP32: max $\approx 3.4\times10^{38}$ - fine
- In 8-bit optimizer states: overflow can occur, corrupting the optimizer state silently
Numerical failure mode 3 - bias correction precision:
- Bias correction factor $1/(1-\beta_2^t)$ for $\beta_2 = 0.999$:
- $t = 1$: $1-\beta_2^t = 0.001$ -> factor $1000$
- $t = 10$: $1-\beta_2^t \approx 0.00995$ -> factor $\approx 100$
- $t = 100$: $1-\beta_2^t \approx 0.0952$ -> factor $\approx 10.5$
- Early in training ($t = 1$), dividing by $1-\beta_2^t$ amplifies $v_t$ by 1000× - this amplification can push values out of representable range in low precision
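Failure mode 1 can be checked with a one-line model of the Adam step size; the numbers below are illustrative, not from a specific run:

```python
# Numerical check of failure mode 1: with v ~ 0, the Adam step collapses
# to lr * m / eps, which explodes when eps is tiny.

def adam_step_size(lr, m_hat, v_hat, eps):
    """Magnitude of a single Adam parameter update."""
    return lr * m_hat / (v_hat ** 0.5 + eps)

lr, m_hat = 1e-3, 1e-4        # tiny but nonzero first-moment estimate
tiny_v = 0.0                  # parameter whose gradients have been ~0

bad = adam_step_size(lr, m_hat, tiny_v, eps=1e-8)   # default eps
safe = adam_step_size(lr, m_hat, tiny_v, eps=1e-5)  # BF16-friendly eps

assert abs(bad - 10.0) < 1e-9      # 1e-3 * 1e-4 / 1e-8 = 10: a huge jump
assert abs(safe - 0.01) < 1e-9     # 1e-3 * 1e-4 / 1e-5 = 0.01: sane
```

The three-orders-of-magnitude gap between the two step sizes is exactly the $\eta/\epsilon$ amplification described above.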
12.5 Attention Logit Growth
A common failure mode in large transformer training:
The mechanism:
- Attention logits grow slowly during training as the model learns sharper attention patterns
- The softmax temperature effectively decreases: growing logits act like $\mathrm{softmax}(z/T)$ with $T$ shrinking
- As logits grow, the softmax output becomes more peaked ("spiky")
- Eventually, a single attention logit dominates: the softmax output approaches a one-hot vector
- The gradient of softmax when it's nearly one-hot is very small -> vanishing gradients for attention
- Meanwhile, the large logit values approach the overflow boundary of the number format
Numerical progression (BF16 example):
| Training Step | Max Logit | Softmax Max | Gradient Magnitude | Risk |
|---|---|---|---|---|
| 10K | 5.0 | 0.73 | Normal | Safe |
| 50K | 15.0 | 0.995 | Reduced | Low |
| 100K | 30.0 | 0.9999 | Very small | Medium |
| 200K | 60.0 | ~1.0 | Near zero | High |
| 250K | 89+ | Overflow -> NaN | N/A | Crash |
Mitigation strategies:
- QK-Norm (Dehghani et al., 2023): normalise $q$ and $k$ to unit norm before computing attention: $\text{logit} = \hat q \cdot \hat k$ with $\hat q = q/\lVert q \rVert$. By Cauchy-Schwarz this bounds logit magnitudes by 1 (up to any fixed gain), regardless of training step.
- Logit capping: clamp attention logits to a fixed range $[-c, c]$ before softmax (as in PaLM). Simple but effective.
- Softmax temperature: explicitly maintain a temperature: $\mathrm{softmax}(z/\tau)$, where $\tau$ is a learnable parameter, constrained to be positive.
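The QK-Norm bound follows directly from Cauchy-Schwarz and can be verified with plain Python lists; the helper names are ours:

```python
import math

# Sketch of QK-Norm: normalising q and k to unit length before the dot
# product bounds the attention logit by 1, no matter how large the raw
# projections have grown during training.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalise(v):
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]

def qk_norm_logit(q, k):
    """Attention logit computed from unit-normalised q and k."""
    return dot(normalise(q), normalise(k))

# Simulate "late training": raw projections have drifted to huge magnitudes.
d = 64
q = [50.0] * d
k = [50.0] * d

raw_logit = dot(q, k) / math.sqrt(d)     # unnormalised: grows without bound
capped = qk_norm_logit(q, k)             # bounded

assert raw_logit == 20000.0              # far past exp-overflow territory
assert abs(capped) <= 1.0 + 1e-12        # Cauchy-Schwarz bound holds
```

Real implementations typically apply LayerNorm/RMSNorm with a learnable gain to $q$ and $k$ rather than exact unit normalisation, but the bounding mechanism is the same.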
13. Practical Guide - Choosing Number Formats
13.1 Decision Framework for Training
TRAINING FORMAT DECISION TREE
==============================
Start: What is your model size?
+-- < 1B parameters (fits in one GPU)
| +-- Research/prototyping -> FP32 (simplest, no mixed-precision bugs)
| +-- Production training -> BF16 mixed precision (2× speedup)
|
+-- 1B-13B parameters
| +-- Full training -> BF16 mixed precision (mandatory for speed)
| +-- Fine-tuning -> QLoRA (NF4 base + BF16 adapters)
|
+-- 13B-70B parameters
| +-- Full training -> BF16 + ZeRO-3 across multiple GPUs
| +-- Fine-tuning -> QLoRA NF4 (fits on single 80GB GPU)
| +-- Continued PT -> BF16 + gradient checkpointing
|
+-- 70B+ parameters
+-- Full training -> BF16 + 3D parallelism (need cluster)
+-- FP8 training -> if H100+ hardware (emerging 2024+)
+-- Fine-tuning -> QLoRA NF4 (4-bit base + 16-bit LoRA)
KEY RULES:
1. Master weights ALWAYS in FP32
2. Optimizer states ALWAYS in FP32 (or 8-bit Adam)
3. Forward/backward: BF16 (or FP8 on H100+)
4. Loss computation: FP32
5. Gradient accumulation: FP32
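Rules 1 and 5 can be demonstrated with a simulated BF16 grid (step $2^{-7}$, the BF16 ULP near 1.0); a sketch, not a real training loop:

```python
# Why the FP32 master copy matters: BF16 near 1.0 is modelled as a uniform
# grid with step 2**-7 (its ULP there). Real BF16 spacing varies with the
# exponent, but behaves the same way locally.

ULP = 2.0 ** -7

def to_bf16(x):
    """Round-to-nearest onto the simulated BF16 grid."""
    return round(x / ULP) * ULP

update = 1e-4        # lr * gradient: far below half a BF16 ULP
steps = 200

# Naive: keep the weight itself in BF16 and update it in BF16.
w_bf16_only = 1.0
for _ in range(steps):
    w_bf16_only = to_bf16(w_bf16_only + update)   # rounds straight back

# Mixed precision: the FP32 master accumulates every update; the BF16
# working copy is re-derived from it each forward pass.
master = 1.0
for _ in range(steps):
    master += update
working = to_bf16(master)

assert w_bf16_only == 1.0    # 200 updates silently lost
assert working > 1.0         # the master copy preserved them
```

The BF16-only weight is stuck forever, while the master-weight path eventually crosses a BF16 grid point - the mechanism behind rule 1.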
13.2 Decision Framework for Inference
INFERENCE FORMAT DECISION TREE
===============================
Start: What is your latency/quality tradeoff?
+-- Maximum quality (no degradation acceptable)
| +-- BF16 / FP16 (same as training format)
|
+-- Balanced quality/speed (< 1% perplexity increase)
| +-- NVIDIA GPU -> INT8 (W8A8) with SmoothQuant
| +-- Apple Silicon -> INT8 (W8A8) via MLX
| +-- CPU -> INT8 with ONNX Runtime
|
+-- High compression (1-3% perplexity increase, 4× speedup)
| +-- Batch inference -> GPTQ INT4 (W4A16)
| +-- Real-time serving -> AWQ INT4 (W4A16)
| +-- Memory-constrained -> GGML/GGUF Q4_K_M
|
+-- Maximum compression (edge/mobile deployment)
| +-- INT3 (W3A16) -> aggressive but usable with GPTQ
| +-- INT2 (W2A16) -> significant quality loss; only for small models
| +-- 1.58-bit (ternary) -> BitNet models (trained from scratch only)
|
+-- Speculative/KV cache optimisation
+-- KV cache INT8 -> per-channel quantization
+-- KV cache INT4 -> with group quantization + recent-token FP16
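The per-channel INT8 KV-cache option from the tree above can be sketched as follows; the cache layout and helper names are illustrative:

```python
# Per-channel INT8 quantization for a KV cache laid out as
# [n_tokens][n_channels]. Each channel gets its own scale, because channel
# magnitudes in K/V tensors differ widely.

def per_channel_scales(cache):
    """One symmetric scale per channel: map the channel's max |value| to 127."""
    n_ch = len(cache[0])
    return [max(abs(row[c]) for row in cache) / 127.0 for c in range(n_ch)]

def quantize_cache(cache, scales):
    return [[max(-128, min(127, round(v / s))) for v, s in zip(row, scales)]
            for row in cache]

def dequantize_cache(qcache, scales):
    return [[q * s for q, s in zip(row, scales)] for row in qcache]

# Channel 0 carries large values, channel 1 tiny ones: one shared scale
# would crush channel 1, but per-channel scales keep both accurate.
cache = [[12.0, 0.011], [-9.5, 0.008], [11.2, -0.012]]
scales = per_channel_scales(cache)
restored = dequantize_cache(quantize_cache(cache, scales), scales)

for row, r_row in zip(cache, restored):
    for c, (v, r) in enumerate(zip(row, r_row)):
        assert abs(v - r) <= scales[c] / 2 + 1e-12   # half-step error bound
```

With a single per-tensor scale, channel 1 here would quantize to 0 everywhere; per-channel scales are what make the scheme "nearly lossless".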
13.3 Per-Layer Sensitivity Table
Different layers have different tolerance for quantization. Based on empirical findings from AWQ, GPTQ, SqueezeLLM:
| Layer Type | INT8 | INT4 | INT3 | INT2 | Notes |
|---|---|---|---|---|---|
| Embedding | OK | OK | WARNING | NO | Vocabulary coverage matters |
| Attention Q, K projections | OK | OK | WARNING | NO | Attention pattern quality degrades |
| Attention V projection | OK | OK | OK | WARNING | Slightly more robust than Q, K |
| Attention output projection | OK | OK | WARNING | NO | Affects residual stream |
| FFN gate & up projection | OK | OK | OK | WARNING | Relatively robust |
| FFN down projection | OK | WARNING | NO | NO | Most sensitive - affects residual |
| LM head (final projection) | OK | WARNING | NO | NO | Directly affects token probabilities |
| Layer norm / RMSNorm | OK | NO | NO | NO | Keep in FP32 or BF16 always |
Legend: OK = safe, WARNING = measurable degradation, NO = significant quality loss
13.4 Quantization Quality vs Bit Width
Empirical perplexity results (approximate, varies by model and method):
| Bit Width | Method | 7B Model PPL | 13B Model PPL | 70B Model PPL |
|---|---|---|---|---|
| 16 (BF16) | Baseline | 5.68 | 5.09 | 3.56 |
| 8 (INT8) | SmoothQuant | 5.69 (+0.01) | 5.10 (+0.01) | 3.56 (+0.00) |
| 4 (INT4) | GPTQ | 5.85 (+0.17) | 5.20 (+0.11) | 3.60 (+0.04) |
| 4 (NF4) | QLoRA/bitsandbytes | 5.80 (+0.12) | 5.16 (+0.07) | 3.58 (+0.02) |
| 4 (INT4) | AWQ | 5.78 (+0.10) | 5.15 (+0.06) | 3.58 (+0.02) |
| 3 (INT3) | GPTQ | 6.29 (+0.61) | 5.51 (+0.42) | 3.72 (+0.16) |
| 2 (INT2) | QuIP# | 7.85 (+2.17) | 6.43 (+1.34) | 4.15 (+0.59) |
Key observation: larger models are more tolerant of quantization - a 70B INT4 model often outperforms a 13B BF16 model. This is because larger models have more redundancy in their weight matrices.
14. Common Mistakes and Misconceptions
| # | Mistake | Why It's Wrong | Correct Understanding |
|---|---|---|---|
| 1 | "FP16 and BF16 are interchangeable" | FP16 overflows at 65504; BF16 at ~3.4×10³⁸. FP16 needs loss scaling; BF16 doesn't. | BF16 matches FP32 range; FP16 does not. Choose BF16 for training. |
| 2 | "More bits always means better quality" | 70B at INT4 outperforms 13B at FP16 on most benchmarks. | Total model capacity matters more than per-parameter precision. |
| 3 | "Quantization only saves memory" | Quantization also reduces memory bandwidth (the actual bottleneck) -> direct speedup. | Memory savings -> bandwidth savings -> latency reduction. |
| 4 | "INT8 inference loses accuracy" | With proper calibration (SmoothQuant, GPTQ), INT8 is lossless for nearly all models. | INT8 is the "free lunch" of inference optimisation. |
| 5 | "I should quantize my model during training" | Quantization-aware training is expensive and unnecessary for PTQ with 4+ bits. | Use PTQ unless deploying at INT2 or lower. |
| 6 | "FP32 master weights waste memory" | Without FP32 master weights, BF16 training diverges after ~100K steps. | FP32 masters are essential, not optional. |
| 7 | "Stochastic rounding is just noisy" | SR is an unbiased estimator that enables convergence in low precision. | SR is mathematically principled, not a hack. |
| 8 | "FLOPS determines inference speed" | Memory bandwidth is the bottleneck for autoregressive LLM inference. | Optimise for memory bandwidth, not compute. |
| 9 | "Quantizing KV cache is dangerous" | Per-channel INT8 KV cache quantization is nearly lossless. | KV cache quantization is safe and critical for long-context. |
| 10 | "All layers should use the same precision" | Output projection and FFN down projection are much more sensitive. | Use mixed-precision per-layer quantization. |
15. Exercises
Exercise 1: IEEE 754 Encoding
Encode a decimal value of your choice with both an integer and a fractional part in IEEE 754 FP32 format. Show all steps: sign bit, binary conversion, normalisation, biased exponent, and mantissa. Verify by decoding your result back to decimal.
Exercise 2: Quantization Calculation
Given a small weight tensor $W$, compute its INT8 symmetric quantization:
- (a) Calculate the scale factor $s = \max_i |W_i| / 127$
- (b) Quantize all values to INT8
- (c) Dequantize and compute the maximum absolute error
- (d) Compute the SQNR in dB
Exercise 3: BF16 Precision Limits
Two FP32 values: a large value $a$ and a much smaller value $b$, where $b$ is below one BF16 ULP of $a$.
- (a) Add in FP32. What is the result?
- (b) Cast both values to BF16, add, and cast back to FP32. What is the result?
- (c) Repeat the BF16 addition 1000 times (adding $b$ to a running sum starting at $a$). Compare with the FP32 result. What is the relative error?
- (d) How does this relate to gradient accumulation in training?
Exercise 4: Memory Budget
You have a single NVIDIA A100 80GB GPU. Calculate the maximum model size (in parameters) you can:
- (a) Train with full FP32 (weights + Adam optimizer states)
- (b) Train with BF16 mixed precision (FP32 master + BF16 working + FP32 Adam)
- (c) Fine-tune with QLoRA (NF4 base + BF16 LoRA with rank 16, hidden dim 4096, 32 layers × 4 projection matrices)
- (d) Serve for inference in INT4 (weights only, no optimizer states)
Exercise 5: Softmax Stability
Given a vector of attention logits $z$:
- (a) Compute $e^{z_i}$ for each value. Can these be represented in FP32? In BF16?
- (b) Apply the max-subtraction trick and recompute. Show that no overflow occurs.
- (c) Compute the final softmax probabilities.
- (d) What would happen if the largest logit exceeded $\ln(3.4\times10^{38}) \approx 88.7$? How does log-sum-exp help?
Exercise 6: Error Propagation
A 3-layer network with weight matrices $W_1, W_2, W_3$. Each weight is quantized to INT8 with a per-tensor scale $s$. Assuming the input norm $\lVert x \rVert$ and the operator norms $\lVert W_i \rVert$ are bounded:
- (a) Bound the quantization error for each layer.
- (b) Using the first-order error approximation from 9.6, bound the total output error.
- (c) At what bit width does the output error drop below 1% of the output magnitude?
Exercise 7: Hardware Arithmetic Intensity
For a decoder-only transformer with 32 layers, $d_{\text{model}} = 4096$, $d_{\text{ff}} = 11008$ (LLaMA-7B architecture):
- (a) Calculate the total FLOPs for generating one token (attention + FFN, ignore KV cache read).
- (b) Calculate the total weight bytes loaded from memory in INT4, INT8, and BF16.
- (c) Compute the arithmetic intensity for each format.
- (d) Given the H100's 3.35 TB/s bandwidth and 1979 TOPS (INT8), determine which format is compute-bound vs memory-bound.
Exercise 8: Stochastic Rounding Simulation
Implement a simple stochastic rounding function in Python. Given a small "true" gradient and a simulated BF16 format where values snap to multiples of 0.0078125 ($= 2^{-7}$, the BF16 ULP near 1.0):
- (a) What does round-to-nearest produce for a gradient whose magnitude is below half the ULP (i.e. below 0.00390625)?
- (b) Implement stochastic rounding. Over 10,000 steps of accumulating this gradient into a running sum, what is the expected sum? Simulate and verify.
- (c) Compare with deterministic RNE accumulation over the same 10,000 steps. What is the relative error?
16. Why This Matters for AI/ML
| Concept | Where It Appears | Why You Need It |
|---|---|---|
| IEEE 754 FP32/FP64 | Loss computation, optimizer states | Understanding overflow/underflow prevents silent training failures |
| BF16 | Default training format (2020+) | Know when and why to use it; understand its precision limits |
| FP8 | Next-gen training (H100+) | Critical for cost-efficient training at scale |
| INT8 quantization | Standard inference optimisation | 2× speedup with no quality loss - mandatory knowledge |
| INT4 quantization | High-compression inference | Enables serving 70B models on consumer hardware |
| NF4 | QLoRA fine-tuning | Enables 70B fine-tuning on a single GPU |
| Softmax stability | Every forward pass | Prevents the most common NaN crash in transformers |
| Mixed precision | Every training pipeline | Halves memory, doubles speed - used in all modern training |
| KV cache formats | Long-context inference | Enables 128K+ context without OOM |
| Error propagation | Model debugging | Understanding why certain layers can/cannot be quantized aggressively |
| Hardware constraints | Deployment planning | Choose the right format for your target hardware |
| Stochastic rounding | FP8 training, low-precision research | Key enabler for sub-8-bit training |
17. Conceptual Bridge
This chapter covered the representation of numbers - how mathematical quantities are encoded in finite hardware. The key insight is that every number format is a tradeoff between range, precision, and cost, and AI engineering is the art of choosing the right tradeoff for each part of the pipeline.
CONCEPTUAL FLOW
===============
Number Systems (this chapter)
+-- How are values represented in hardware?
+-- What are the precision limits?
+-- How do these limits affect AI systems?
|
v
Sets and Logic (next chapter)
+-- How do we formalise collections of objects?
+-- What are the logical foundations of proofs?
+-- How does set theory underpin probability and statistics?
|
v
Functions and Mappings
+-- How do we formalise input-output relationships?
+-- What properties must a function have?
+-- Neural networks as compositions of functions
|
v
Summation and Product Notation
+-- Compact notation for sums and products
+-- Used everywhere: loss functions, gradients, attention
+-- Connection to vectorised computation
How number systems connect to every subsequent chapter:
- Linear Algebra: matrix operations are chains of multiply-accumulate; the number format determines both the speed and accuracy of every matrix operation in the model
- Calculus: derivatives are computed via finite differences or analytical rules; numerical precision determines whether gradient computation is meaningful or noise
- Probability: probability values in $[0, 1]$ are represented as floating-point numbers; underflow in small probabilities (below roughly $10^{-38}$ in FP32) causes silent failures unless you use log-probabilities
- Optimisation: every optimiser update involves division, square roots, and accumulation - all precision-sensitive operations
- Information Theory: cross-entropy loss involves $\log$ and $\exp$ - the most overflow/underflow-prone elementary functions
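The log-probability point is easy to verify: multiplying many small probabilities underflows to exactly zero, while summing log-probabilities stays finite.

```python
import math

# Underflow check: a product of 100 probabilities of 1e-5 is 1e-500, far
# below the FP64 minimum (~5e-324), so it silently becomes 0.0. The sum of
# log-probabilities is perfectly representable.

probs = [1e-5] * 100     # e.g. a 100-token sequence, each token at p = 1e-5

direct = 1.0
for p in probs:
    direct *= p          # underflows partway through the loop

log_prob = sum(math.log(p) for p in probs)   # 100 * ln(1e-5) ~ -1151.3

assert direct == 0.0             # silent underflow: information destroyed
assert math.isfinite(log_prob)   # usable for loss computation and ranking
```

This is why every serious implementation works with log-probabilities (and log-sum-exp) rather than raw probabilities.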
The fundamental lesson: a deep understanding of number systems is not optional for AI practitioners. It's the foundation that determines whether your model trains at all, how much it costs, how fast it runs, and whether you can deploy it on your target hardware.
<- Back to Mathematical Foundations | Next: Sets and Logic ->