Number Systems, Part 2: Numerical Precision in Neural Networks (Section 8) to Conceptual Bridge (Section 17)

8. Numerical Precision in Neural Networks

This section connects number system theory to the practical precision requirements of each component in a transformer-based LLM.

8.1 Forward Pass Precision Requirements

Each operation in the forward pass has different sensitivity to numerical precision:

Input embeddings (BF16 sufficient):

  • Embedding lookup is a simple table read - no arithmetic precision concern for the lookup itself
  • Small errors in embedding vectors ($< 1\%$) don't compound catastrophically because they pass through layer normalisation early
  • BF16 with 7-bit mantissa provides $\sim 0.78\%$ relative precision - adequate

Attention scores $QK^T / \sqrt{d_k}$ (BF16 matmul, FP32 softmax):

  • The matmul $QK^T$ is computed in BF16 with FP32 accumulation (7.6) - no problem
  • The division by $\sqrt{d_k}$ is a simple scale; BF16 sufficient
  • However: the subsequent softmax requires FP32 for the exponentiation (see 8.4)

Softmax (must use FP32 or numerical tricks):

  • $\text{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$
  • $e^{z}$ overflows BF16 for $z > 88$ (because $e^{88} \approx 1.65 \times 10^{38}$, approaching BF16 max)
  • Attention logits can reach 50-100 late in training - already near the ~88 overflow threshold, and early layers or unstable training can exceed it
  • Solution: compute softmax intermediate values in FP32; or use the numerically stable version (8.4)

Layer normalisation (FP32 for mean/variance):

  • Mean $\mu = \frac{1}{n}\sum x_i$ and variance $\sigma^2 = \frac{1}{n}\sum (x_i - \mu)^2$: sums over $n$ values must be accurate
  • With BF16 inputs: accumulate $\mu$ and $\sigma^2$ in FP32 before applying normalisation
  • The normalised output $(x - \mu)/\sqrt{\sigma^2 + \epsilon}$ can be stored in BF16

FFN activations (BF16 sufficient):

  • SwiGLU gate and up projections: standard matmul; BF16 with FP32 accumulation
  • Swish activation function: smooth, bounded gradient; no overflow/underflow issues in BF16
  • Down projection: same as any matmul; BF16 sufficient

8.2 Backward Pass Precision Requirements

The backward pass is more precision-sensitive than the forward pass because errors in gradients compound through the optimizer over millions of steps:

Gradient magnitudes:

  • Typical gradient magnitudes: $10^{-4}$ to $10^{-7}$
  • These are well within BF16 range (min $\approx 1.2 \times 10^{-38}$) but near the FP16 minimum normal ($6.1 \times 10^{-5}$)
  • This is exactly why BF16 works and FP16 fails for training without loss scaling

Gradient accumulation (MUST use FP32):

$$\frac{\partial L}{\partial W} = \sum_{\text{batch}} x^T \delta$$

This sum aggregates many small gradient contributions. In BF16:

  • Machine epsilon $= 7.8 \times 10^{-3}$
  • Any gradient contribution smaller than $0.78\%$ of the running sum is lost
  • Over a batch of 1024 samples, many individual contributions are small -> silently discarded
  • Over millions of training steps, this precision loss causes weights to stop updating -> training stalls

In FP32:

  • Machine epsilon $= 1.19 \times 10^{-7}$
  • Contributions as small as $0.0000119\%$ of the running sum are preserved
  • 65,000× more precise than BF16 - sufficient for stable training (see the sketch below)
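The sketch below makes the effect concrete in PyTorch; the value 1e-4 stands in for a typical per-sample gradient contribution (any value under half of BF16's ULP relative to the sum behaves the same way):

```python
import torch

# Accumulate 1024 small contributions into a running sum that starts at 1.0.
# Each contribution (1e-4) is below half of BF16's ULP near 1.0 (2^-7), so
# every BF16 add rounds back to the old sum - the contribution is discarded.
contrib = 1e-4
acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(1024):
    acc_bf16 += torch.tensor(contrib, dtype=torch.bfloat16)
    acc_fp32 += contrib

print(acc_bf16.item())  # 1.0      - every contribution silently lost
print(acc_fp32.item())  # ~1.1024  - all contributions preserved
```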

Gradient clipping (FP32 norm, BF16 application):

  • Global gradient norm $\|\mathbf{g}\|_2 = \sqrt{\sum g_i^2}$: compute in FP32 (sum of squares)
  • Apply clipping scale factor: can be done in BF16 (simple multiplication)

Loss computation (FP32):

  • Cross-entropy loss involves $\log(\text{softmax}(z))$ - both log and exp are precision-sensitive
  • Must use the numerically stable log-sum-exp (8.5) in FP32

8.3 Mixed Precision Training - Complete Picture

The standard mixed-precision training recipe that has been used for virtually all large-scale LLM training since 2020:

MIXED PRECISION TRAINING PIPELINE
=======================================================================

INITIALISATION:
  θ_master = initialise_weights()           # FP32 (4 bytes/param)
  m = zeros_like(θ_master)                  # FP32 Adam momentum
  v = zeros_like(θ_master)                  # FP32 Adam variance

FOR EACH TRAINING STEP:

  +- FORWARD PASS -----------------------------------------------+
  |  θ_bf16 = cast(θ_master, BF16)           # FP32 -> BF16      |
  |  activations = forward(x, θ_bf16)        # BF16 matmul       |
  |  loss = cross_entropy(activations, y)    # FP32 (stable)     |
  +---------------------------------------------------------------+
                           |
                           v
  +- BACKWARD PASS ----------------------------------------------+
  |  grads_bf16 = backward(loss, activations)  # BF16 gradients  |
  |  grads_fp32 = accumulate(grads_bf16)       # FP32 accumulate |
  |  clip_grad_norm(grads_fp32)                # FP32 norm       |
  +---------------------------------------------------------------+
                           |
                           v
  +- OPTIMIZER STEP ---------------------------------------------+
  |  m = β1*m + (1-β1)*grads_fp32              # FP32            |
  |  v = β2*v + (1-β2)*grads_fp32^2            # FP32            |
  |  m_hat = m / (1 - β1^t)                    # FP32 bias corr. |
  |  v_hat = v / (1 - β2^t)                    # FP32 bias corr. |
  |  θ_master -= η * m_hat / (√v_hat + ε)      # FP32 update     |
  +---------------------------------------------------------------+

MEMORY PER PARAMETER:
  θ_master (FP32):     4 bytes
  Adam m (FP32):       4 bytes
  Adam v (FP32):       4 bytes
  θ_bf16 (working):    2 bytes
  --------------------------
  Total:              14 bytes per parameter

  For a 70B model: 70B × 14 = 980 GB (explaining why large-model
  training requires many GPUs even with mixed precision)
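In PyTorch, the forward/backward half of this recipe is what torch.autocast provides. A minimal sketch, assuming a CUDA GPU (the tiny linear model and MSE loss are placeholders; the FP32 master weights and Adam states live inside the standard optimizer):

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()                  # FP32 parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # FP32 m, v states

x = torch.randn(32, 1024, device="cuda")
y = torch.randn(32, 1024, device="cuda")

# Forward in BF16; no GradScaler needed since BF16 has FP32's exponent range.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), y)

loss.backward()                                          # grads in FP32
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # FP32 global norm
optimizer.step()                                         # FP32 Adam update
optimizer.zero_grad()
```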

8.4 Numerical Stability of Softmax

The softmax function is one of the most numerically sensitive operations in a transformer:

Naive softmax - the overflow problem:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

For $z_i = 100$ in BF16: $e^{100} \approx 2.69 \times 10^{43}$ - exceeds BF16 max ($3.39 \times 10^{38}$) -> overflow to Inf. Even in FP32: $e^{89} \approx 4.5 \times 10^{38}$ overflows. Any attention logit above ~88 causes overflow.

Numerically stable softmax - the max-subtraction trick:

$$\text{softmax}(z)_i = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}}, \quad m = \max_j(z_j)$$

Proof that this is mathematically equivalent:

$$\frac{e^{z_i - m}}{\sum_j e^{z_j - m}} = \frac{e^{z_i} \cdot e^{-m}}{e^{-m} \cdot \sum_j e^{z_j}} = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

The factor $e^{-m}$ cancels in numerator and denominator. But numerically, the subtraction ensures that the largest exponent is $e^{z_{\max} - z_{\max}} = e^0 = 1$ - no overflow possible. All other exponents satisfy $e^{z_i - m} \leq 1$ - also safe.
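A minimal sketch of the trick; the three-logit example is chosen so the naive version already overflows in FP32:

```python
import torch

def stable_softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Subtract the row max so the largest exponent is e^0 = 1: no overflow.
    m = z.max(dim=dim, keepdim=True).values
    e = torch.exp(z - m)
    return e / e.sum(dim=dim, keepdim=True)

logits = torch.tensor([100.0, 99.0, 98.0])
print(torch.exp(logits))       # tensor([inf, inf, inf]) - naive version dies
print(stable_softmax(logits))  # tensor([0.6652, 0.2447, 0.0900])
```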

FlashAttention's online softmax:

  • Standard softmax requires two passes: (1) compute $m = \max(z)$; (2) compute $e^{z_i - m}$ and sum
  • FlashAttention maintains a running max and running sum, rescaling as new blocks of keys arrive
  • This enables computing attention in a single pass through the KV cache, keeping all data in SRAM

8.5 Log-Sum-Exp Trick

The log-sum-exp (LSE) operation appears in cross-entropy loss, softmax computation, and log-normaliser calculation:

$$\text{LSE}(z) = \log \sum_i e^{z_i}$$

The overflow problem: for large $z_i$, $e^{z_i}$ overflows before the log can "undo" it.

The solution - factor out the maximum:

$$\text{LSE}(z) = m + \log \sum_i e^{z_i - m}, \quad m = \max_i(z_i)$$

Proof:

$$\log \sum_i e^{z_i} = \log \left(e^m \sum_i e^{z_i - m}\right) = \log(e^m) + \log \sum_i e^{z_i - m} = m + \log \sum_i e^{z_i - m}$$

The maximum $m$ is factored out. All remaining exponents satisfy $z_i - m \leq 0$ -> no overflow. The subtracted $m$ is re-added as a simple sum - no precision loss.

Why this matters for LLM training:

Cross-entropy loss for next-token prediction:

$$L = -\log P(t_{\text{target}}) = -\log \text{softmax}(z)_{\text{target}} = -(z_{\text{target}} - \text{LSE}(z)) = \text{LSE}(z) - z_{\text{target}}$$

This is computed millions of times per training step (once per token in the batch). If $\text{LSE}(z)$ overflows, the loss becomes Inf -> gradient is NaN -> training dies.

PyTorch: torch.logsumexp(z, dim) implements this correctly. Never compute torch.log(torch.sum(torch.exp(z))) directly.
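A two-line comparison makes the failure mode concrete (values chosen to overflow FP32):

```python
import torch

z = torch.tensor([200.0, 200.1, 199.9])
naive = torch.log(torch.exp(z).sum())  # exp(200) -> inf, so the result is inf
stable = torch.logsumexp(z, dim=0)     # ~201.1, computed without overflow
print(naive.item(), stable.item())
```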

8.6 Numerical Stability of Layer Normalisation

LayerNorm:

$$y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

where $\mu = \frac{1}{n}\sum x_i$ and $\sigma^2 = \frac{1}{n}\sum (x_i - \mu)^2$.

Numerical concerns:

  1. Catastrophic cancellation in $x - \mu$: when $x_i \approx \mu$ (which is common - normalisation centres data near zero), $x_i - \mu$ loses significant bits (3.6)

  2. Division by small $\sigma$: if all values are nearly identical, $\sigma \approx 0$; division amplifies noise

    • The $\epsilon = 10^{-5}$ ensures we divide by at least $\sqrt{10^{-5}} \approx 0.00316$; this prevents division by zero and limits amplification
    • In BF16: $\epsilon = 10^{-5}$ is representable (BF16 min normal $\approx 1.2 \times 10^{-38}$), but the precision of $\sigma^2 + \epsilon$ when $\sigma^2 \approx 0$ is limited by BF16's 7-bit mantissa

  3. FP32 accumulation for mean and variance: even if the input $x$ is BF16, compute $\mu$ and $\sigma^2$ by accumulating in FP32 before applying the normalisation

RMSNorm - a numerically cleaner alternative:

$$y = \frac{x}{\sqrt{\frac{1}{n}\sum x_i^2 + \epsilon}} \cdot \gamma$$

  • No mean subtraction -> no catastrophic cancellation from $x - \mu$
  • Only requires computing $\text{RMS}(x) = \sqrt{\text{mean}(x^2)}$ - a sum of positive values (no cancellation)
  • Used in LLaMA, Mistral, and most 2024+ transformer architectures
  • Slightly cheaper to compute (one fewer reduction) and more numerically stable (see the sketch below)
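A minimal RMSNorm sketch showing the FP32-accumulation pattern (the eps value is an assumption, mirroring the LayerNorm discussion above):

```python
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulate the mean of squares in FP32 even for BF16 inputs,
        # then cast back; no mean subtraction, so no cancellation.
        x32 = x.float()
        rms = torch.sqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x32 / rms).to(x.dtype) * self.gamma
```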

8.7 Gradient Vanishing and Exploding - Numerical View

The chain rule through LL layers:

$$\frac{\partial L}{\partial x_0} = \prod_{l=1}^{L} \frac{\partial x_l}{\partial x_{l-1}} = \prod_{l=1}^{L} J_l$$

where $J_l$ is the Jacobian of layer $l$. If each Jacobian has a dominant eigenvalue $\lambda$:

$$\left\|\frac{\partial L}{\partial x_0}\right\| \sim \lambda^L$$

Exploding gradients ($\lambda > 1$):

  • $\lambda = 1.01$, $L = 200$ layers: $1.01^{200} \approx 7.32$ - moderate growth
  • $\lambda = 1.1$, $L = 200$ layers: $1.1^{200} \approx 1.9 \times 10^{8}$ - rapid growth
  • In FP32: gradients exceed $3.4 \times 10^{38}$ -> Inf -> training crashes
  • In BF16: same range as FP32, but precision errors are amplified by large gradients -> unstable updates

Vanishing gradients ($\lambda < 1$):

  • $\lambda = 0.99$, $L = 200$: $0.99^{200} \approx 0.134$ - still meaningful
  • $\lambda = 0.9$, $L = 200$: $0.9^{200} \approx 7 \times 10^{-10}$ - tiny
  • In FP16: gradients below $6.1 \times 10^{-5}$ underflow to zero -> silent death (no error, no NaN, just zero gradients)
  • In BF16: gradients below $1.2 \times 10^{-38}$ underflow - this almost never happens in practice

Residual connections - the numerical fix:

$$x_l = x_{l-1} + f_l(x_{l-1})$$
$$\frac{\partial x_l}{\partial x_{l-1}} = I + \frac{\partial f_l}{\partial x_{l-1}}$$

The Jacobian is $I + J_f$, whose eigenvalues are $1 + \lambda_f$. Even if $\lambda_f$ is small or negative, the eigenvalue stays near 1 - preventing both explosion and vanishing.

Gradient clipping - the safety net:

  • Clip the global gradient norm: $\hat{g} = g \cdot \min\!\left(1, \frac{C}{\|g\|}\right)$
  • Prevents Inf from exceeding FP32 range
  • Standard practice in LLM training: $C = 1.0$ is typical

9. Quantization Mathematics

Quantization is the mathematical process of mapping values from a high-precision format (e.g., FP32, BF16) to a low-precision format (e.g., INT8, INT4, NF4). This section develops the mathematical foundations rigorously.

9.1 Uniform Quantization Formulas - Derivation and Intuition

The fundamental problem: map a continuous (or high-precision) value $x \in [x_{\min}, x_{\max}]$ to an integer $x_q \in \{0, 1, \ldots, 2^b - 1\}$ (unsigned) or $x_q \in \{-2^{b-1}, \ldots, 2^{b-1}-1\}$ (signed), where $b$ is the bit width.

Step 1 - Define the scale factor $s$:

$$s = \frac{x_{\max} - x_{\min}}{2^b - 1}$$

This is the "width" of one quantization bin. For symmetric quantization around zero:

$$s = \frac{2 \cdot \max(|x_{\min}|, |x_{\max}|)}{2^b - 1}$$

Step 2 - Define the zero-point $z$ (asymmetric only):

$$z = \text{round}\!\left(-\frac{x_{\min}}{s}\right)$$

The zero-point ensures that the floating-point value 0.0 maps exactly to an integer. This is critical because:

  • Padding zeros in convolutions must remain exactly zero after quantization
  • Skip connections that add 0 must not introduce bias

For symmetric quantization: $z = 0$ (or $z = 2^{b-1}$ for unsigned).

Step 3 - Quantize (FP -> INT):

$$x_q = \text{clamp}\!\left(\text{round}\!\left(\frac{x}{s} + z\right),\; 0,\; 2^b - 1\right)$$

Step 4 - Dequantize (INT -> FP approximation):

$$\hat{x} = s \cdot (x_q - z)$$

Quantization error:

$$\epsilon_q = x - \hat{x} = x - s \cdot (\text{round}(x/s + z) - z)$$

For uniform quantization with round-to-nearest, the error is bounded:

$$|\epsilon_q| \leq \frac{s}{2} = \frac{x_{\max} - x_{\min}}{2(2^b - 1)}$$

Worked example - INT8 symmetric quantization of a weight tensor:

Given:  weights = [-0.45, 0.12, -0.03, 0.67, -0.89, 0.34]
Bit width: b = 8 (signed: range [-128, 127])

Step 1: α = max(|-0.89|, |0.67|) = 0.89
        s = 2α / (2^8 - 1) = 1.78 / 255 = 0.006980

Step 2: z = 0 (symmetric)

Step 3: Quantize each weight:
  -0.45 -> round(-0.45 / 0.006980) = round(-64.47) = -64
   0.12 -> round(0.12 / 0.006980)  = round(17.19)  =  17
  -0.03 -> round(-0.03 / 0.006980) = round(-4.30)  =  -4
   0.67 -> round(0.67 / 0.006980)  = round(95.99)  =  96
  -0.89 -> round(-0.89 / 0.006980) = round(-127.51)= -128  <- clips to -128
   0.34 -> round(0.34 / 0.006980)  = round(48.71)  =  49

Step 4: Dequantize:
  -64 -> 0.006980 × (-64) = -0.44672   (error: 0.00328)
   17 -> 0.006980 × 17    =  0.11866   (error: 0.00134)
   -4 -> 0.006980 × (-4)  = -0.02792   (error: 0.00208)
   96 -> 0.006980 × 96    =  0.67008   (error: 0.00008)
 -128 -> 0.006980 × (-128)= -0.89344   (error: 0.00344)  <- clipping error
   49 -> 0.006980 × 49    =  0.34202   (error: 0.00202)

Maximum error: 0.00344 ≈ s/2 = 0.00349  OK
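The same computation as a reusable NumPy sketch (np.round rounds half to even, matching the worked example):

```python
import numpy as np

def quantize_int8_symmetric(w: np.ndarray):
    # Symmetric INT8: scale from the max magnitude, zero-point fixed at 0.
    alpha = np.abs(w).max()
    s = 2 * alpha / (2**8 - 1)
    q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
    return q, s

w = np.array([-0.45, 0.12, -0.03, 0.67, -0.89, 0.34])
q, s = quantize_int8_symmetric(w)
w_hat = s * q.astype(np.float64)    # dequantize
print(q.tolist())                   # [-64, 17, -4, 96, -128, 49]
print(np.abs(w - w_hat).max())      # ~0.0035, close to the s/2 bound
```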

9.2 Signal-to-Quantization-Noise Ratio (SQNR)

SQNR measures the quality of quantization - how much signal is preserved relative to the quantization noise introduced.

Definition:

$$\text{SQNR} = \frac{\mathbb{E}[x^2]}{\mathbb{E}[(x - \hat{x})^2]} = \frac{\text{signal power}}{\text{quantization noise power}}$$

In dB:

$$\text{SQNR}_{\text{dB}} = 10 \log_{10} \text{SQNR}$$

For uniform quantization of a uniformly distributed signal:

The quantisation error is approximately uniformly distributed on $[-s/2, s/2]$ with variance:

$$\sigma_{\epsilon}^2 = \frac{s^2}{12}$$

For a signal uniformly distributed on $[-\alpha, \alpha]$ quantized to $b$ bits:

$$s = \frac{2\alpha}{2^b - 1} \approx \frac{2\alpha}{2^b}, \qquad \sigma_x^2 = \frac{\alpha^2}{3}$$

$$\text{SQNR} = \frac{\sigma_x^2}{\sigma_\epsilon^2} = \frac{\alpha^2/3}{s^2/12} = \frac{4\alpha^2}{s^2} = \frac{4\alpha^2}{(2\alpha/2^b)^2} = 2^{2b}$$

$$\text{SQNR}_{\text{dB}} = 10 \log_{10}(2^{2b}) = 20b \log_{10}(2) \approx 6.02b$$

The 6 dB per bit rule: each additional bit of precision adds approximately 6 dB of SQNR.

| Bit Width | SQNR (dB) | Noise Ratio | Typical Use |
|---|---|---|---|
| 2 | 12.0 | 6.25% | Binary/ternary weights |
| 4 | 24.1 | 0.39% | Inference (GPTQ, AWQ) |
| 8 | 48.2 | 0.0015% | PTQ inference standard |
| 16 | 96.3 | $< 10^{-7}\%$ | Training weights |
| 32 | 192.7 | $< 10^{-17}\%$ | Master weights, optimizer |

For non-uniform distributions (which neural network weights follow):

Neural network weights typically follow a bell-shaped distribution (roughly Gaussian with heavy tails after training). This means:

  • Most values cluster near zero
  • A few outlier values stretch the range
  • Uniform quantization "wastes" most levels on the sparse tails

This is why calibration (choosing $x_{\min}$, $x_{\max}$ to clip outliers) dramatically improves SQNR, and why non-uniform formats like NF4 outperform uniform INT4.
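The 6 dB per bit rule above is easy to verify empirically; a minimal sketch for a uniform signal (where the rule holds up to sampling noise):

```python
import numpy as np

def sqnr_db(x: np.ndarray, b: int) -> float:
    # Uniform symmetric quantization to b bits, then signal/noise power ratio.
    alpha = np.abs(x).max()
    s = 2 * alpha / (2**b - 1)
    x_hat = np.clip(np.round(x / s), -(2**(b - 1)), 2**(b - 1) - 1) * s
    return 10 * np.log10((x**2).mean() / ((x - x_hat)**2).mean())

x = np.random.uniform(-1, 1, 100_000)
for b in (2, 4, 8, 16):
    print(b, round(sqnr_db(x, b), 1))   # roughly 6.02*b dB at each width
```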

9.3 Block Floating-Point and Group Quantization

Instead of a single $(s, z)$ pair for an entire tensor, group quantization assigns one $(s, z)$ pair per block of $G$ elements:

BLOCK/GROUP QUANTIZATION
=========================

Full tensor T with N elements, block size G:

  Block 0               Block 1               Block 2
  +------------------+  +------------------+  +------------------+
  | w_0 ... w_{G-1}  |  | w_G ... w_{2G-1} |  | ...              |
  | s_0, z_0         |  | s_1, z_1         |  | s_2, z_2         |
  +------------------+  +------------------+  +------------------+

Each block has its own scale s_k and zero-point z_k:

  s_k = (max(block_k) - min(block_k)) / (2^b - 1)
  z_k = round(-min(block_k) / s_k)

Memory overhead:
  Without grouping:  N×b bits + 1×(32+32) bits for (s,z)
  With grouping:     N×b bits + (N/G)×(32+32) bits for (s_k,z_k)
  Overhead ratio:    64/G extra bits per parameter

Common group sizes and their overhead (for INT4 - 4 bits per weight):

| Group Size $G$ | Overhead per param | Effective bits/param | Used In |
|---|---|---|---|
| 32 | +0.5 bit (FP16 scale) | 4.5 | QLoRA (NF4) |
| 64 | +0.25 bit | 4.25 | Common middle ground |
| 128 | +0.125 bit | 4.125 | GPTQ default |
| Per-channel | negligible | ~4.0 | Weight-only quant |
| Per-tensor | negligible | ~4.0 | Least accurate |

Double quantization (QLoRA innovation):

The scales $s_k$ themselves are FP16 (2 bytes each). With group size $G = 64$, this adds $2/64 = 0.03125$ bytes $= 0.25$ bits per parameter. QLoRA further quantizes these FP16 scales to FP8, reducing the overhead to 0.125 bits per parameter:

$$\text{Effective bits per param (QLoRA)} = 4 + \frac{16}{64} \cdot \frac{8}{16} = 4 + 0.125 = 4.125$$

9.4 Optimal Quantization Levels - Lloyd-Max Algorithm

Uniform quantization is suboptimal when the input distribution is non-uniform (which it always is for neural network weights). Lloyd-Max quantization finds the $2^b$ reproduction levels $\{r_i\}$ and decision boundaries $\{d_i\}$ that minimise $\text{MSE} = \mathbb{E}[(x - \hat{x})^2]$ for a given distribution $f(x)$.

The optimality conditions:

  1. Nearest-neighbour condition - each value maps to its closest reproduction level:
$$d_i = \frac{r_i + r_{i+1}}{2}$$
  2. Centroid condition - each reproduction level is the conditional mean of values in its bin:
$$r_i = \frac{\int_{d_{i-1}}^{d_i} x \, f(x) \, dx}{\int_{d_{i-1}}^{d_i} f(x) \, dx} = \mathbb{E}[x \mid d_{i-1} \leq x < d_i]$$

Iterative algorithm:

  1. Initialise $\{r_i\}$ (e.g., uniformly spaced or k-means++ style)
  2. Compute $\{d_i\}$ as midpoints of adjacent $\{r_i\}$
  3. Recompute $\{r_i\}$ as centroids of the bins defined by $\{d_i\}$
  4. Repeat until convergence (MSE decreases monotonically)

Connection to NF4: the 16 NF4 levels (6.1) are the Lloyd-Max optimal quantization levels for a standard normal distribution $\mathcal{N}(0, 1)$ with $b = 4$ bits. The weight tensor is first normalised to zero mean and unit variance, making the normal assumption a good fit.
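A minimal Lloyd-Max sketch on Gaussian samples (the quantile initialisation, sample count, and iteration count are arbitrary choices; with 16 levels the result lands close to the normal-distribution code points discussed above):

```python
import numpy as np

def lloyd_max(samples: np.ndarray, levels: int, iters: int = 50) -> np.ndarray:
    r = np.quantile(samples, np.linspace(0.05, 0.95, levels))  # initial levels
    for _ in range(iters):
        d = (r[:-1] + r[1:]) / 2            # nearest-neighbour boundaries
        bins = np.digitize(samples, d)      # assign each sample to a bin
        for i in range(levels):
            members = samples[bins == i]
            if members.size:
                r[i] = members.mean()       # centroid condition
    return r

rng = np.random.default_rng(0)
print(np.round(lloyd_max(rng.standard_normal(200_000), 16), 3))
```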

9.5 Hadamard Transform for Quantization - QuIP and QuaRot

The outlier problem: Transformer activations and certain weight columns contain outlier channels - values 10-100× larger than the rest. These stretch the quantization range, wasting precision for all other values.

The Hadamard solution: The Hadamard matrix $H_n$ is an $n \times n$ orthogonal matrix with entries $\pm 1/\sqrt{n}$:

$$H_1 = [1], \quad H_2 = \frac{1}{\sqrt{2}}\begin{bmatrix} 1 & 1 \\ 1 & -1 \end{bmatrix}, \quad H_n = \frac{1}{\sqrt{2}}\begin{bmatrix} H_{n/2} & H_{n/2} \\ H_{n/2} & -H_{n/2} \end{bmatrix}$$

Key property: $H_n H_n^T = I$ (orthogonal) - applying $H_n$ and $H_n^T$ in sequence is an identity.

How it helps quantization:

For a weight matrix $W$ and input $x$:

$$Wx = (WH_n^T)(H_n x)$$

Let $W' = WH_n^T$ and $x' = H_n x$. Then:

  • $W'$ and $x'$ have much more uniform magnitudes (Hadamard "spreads" outliers across all dimensions)
  • Quantizing $W'$ and $x'$ yields much lower error than quantizing $W$ and $x$ directly

This is free at inference (for weights):

  • Pre-compute $W' = WH_n^T$ once during quantization
  • Apply $H_n$ to activations at runtime (cost: $O(n \log n)$ via the Fast Walsh-Hadamard Transform, much cheaper than the $O(n^2)$ matmul it enables); a small demonstration follows below
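A NumPy demonstration of the outlier-spreading property (scipy.linalg.hadamard builds the Sylvester construction, so n must be a power of two; the outlier position and magnitude are arbitrary):

```python
import numpy as np
from scipy.linalg import hadamard

n = 64
H = hadamard(n) / np.sqrt(n)              # orthogonal: H @ H.T == I
x = np.random.randn(n)
x[7] *= 100.0                             # one outlier channel
W = np.random.randn(n, n)

x_rot, W_rot = H @ x, W @ H.T
print(np.abs(x).max() / np.abs(x).mean())          # large: outlier dominates
print(np.abs(x_rot).max() / np.abs(x_rot).mean())  # small: magnitudes evened out
print(np.allclose(W @ x, W_rot @ x_rot))           # True: the product is exact
```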

QuIP (Chee et al., 2023): uses random orthogonal matrices (including Hadamard) to "incohere" weight matrices before quantization, achieving near-optimal 2-bit quantization.

QuaRot (Ashkboos et al., 2024): applies Hadamard rotations to both weights and activations, enabling 4-bit quantization of all linear layers including KV cache with minimal accuracy loss.

9.6 Quantization Error Propagation Through Layers

A critical question: if each layer introduces quantization error $\epsilon_l$, how does the total error grow through $L$ layers?

Single layer error model:

$$\hat{y}_l = W_l^q x_l = (W_l + \Delta W_l) x_l = W_l x_l + \Delta W_l x_l = y_l + \epsilon_l$$

where $\Delta W_l = W_l^q - W_l$ is the per-layer weight perturbation and $\epsilon_l = \Delta W_l x_l$.

Through $L$ layers (without activations, for intuition):

$$\hat{y}_L = \prod_{l=1}^{L} (W_l + \Delta W_l) \cdot x_0$$

Expanding the product (first-order approximation, dropping cross-terms of $\Delta W$):

$$\hat{y}_L \approx \left(\prod_{l=1}^{L} W_l\right) x_0 + \sum_{l=1}^{L} \left(\prod_{k=l+1}^{L} W_k\right) \Delta W_l \left(\prod_{k=1}^{l-1} W_k\right) x_0$$

The total error is a sum of $L$ terms, each amplified by the products of the weight matrices of the layers above and below.

Practical implications:

| Layer Position | Amplification | Recommendation |
|---|---|---|
| First layer (near input) | Error amplified by all subsequent layers | Use higher precision or skip quantization |
| Middle layers | Moderate amplification | Standard quantization (INT4/INT8) is fine |
| Last layer (near output) | Error directly affects logits | Use higher precision or skip quantization |
| Attention layers | Errors in $QK^T$ amplified by softmax | More sensitive; use INT8 minimum |

Empirical observation from GPTQ, AWQ, and similar methods:

  • Quantizing all layers to INT4: moderate perplexity increase (~0.5-1.0 on WikiText)
  • Keeping first and last layers in FP16: recovers most of the loss (~0.1-0.3 increase)
  • This is consistent with the error amplification analysis above

10. Number Systems for Specific AI Operations

Different operations within a transformer have radically different precision requirements. This section provides a per-operation analysis.

10.1 Embedding Tables

The operation: look up a learned vector $\mathbf{e}_t \in \mathbb{R}^d$ for each token $t$ in the vocabulary.

Precision analysis:

  • Embedding lookup is a table read - no arithmetic; precision only matters for storage
  • Embedding vectors are typically small in magnitude ($|e_{t,i}| < 1$ after weight decay)
  • The embedding is immediately fed into layer normalisation, which re-centres and re-scales

Storage formats:

| Format | Memory for V=128K, d=4096 | Quality Impact |
|---|---|---|
| FP32 | 2.0 GB | Baseline |
| BF16 | 1.0 GB | No measurable loss |
| INT8 | 0.5 GB + scales | < 0.1% perplexity increase |
| INT4 | 0.25 GB + scales | Slight degradation; acceptable |

Autoregressive inference: embeddings are accessed one token at a time - memory-bandwidth-bound. Smaller formats directly reduce first-token latency.

Training: always FP32/BF16 (embeddings need to receive precise gradient updates since the vocabulary is large and each token's embedding updates are sparse).

10.2 Attention Score Computation

The operation:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Precision requirements by component:

$QK^T$ matmul:

  • BF16 × BF16 with FP32 accumulation is standard and sufficient
  • FP8 with FP32 accumulation is emerging (H100 tensor cores support this natively)
  • INT8 × INT8 with INT32 accumulation works for inference with calibrated quantization

$\div \sqrt{d_k}$ scaling:

  • A single multiplication by $1/\sqrt{d_k}$ - precision irrelevant
  • Often fused into the Q or K projection ($W_Q \leftarrow W_Q / \sqrt{d_k}$) to avoid a separate operation

Softmax (see 8.4):

  • Must use FP32 intermediate values (or at minimum the numerically stable BF16 version)
  • The exponentiation and normalisation are the precision bottleneck

Attention × V matmul:

  • Same precision profile as $QK^T$: BF16 with FP32 accumulation
  • The attention weights after softmax are in $[0, 1]$ - well-suited for any format

Multi-query / grouped-query attention (MQA/GQA):

  • Fewer K, V heads (typically 1 or 8 groups vs 32-128 query heads)
  • Doesn't change precision requirements, but dramatically reduces KV cache memory (see 10.3)

10.3 KV Cache Precision

The KV cache stores past key and value vectors for autoregressive generation. For long-context models, it dominates inference memory:

Memory calculation:

$$\text{KV cache} = 2 \times L \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{batch} \times \text{bytes\_per\_element}$$

Example - LLaMA 3 70B with 128K context:

  • $L = 80$ layers, $n_{\text{kv\_heads}} = 8$ (GQA), $d_{\text{head}} = 128$
  • In BF16: $2 \times 80 \times 8 \times 128 \times 131072 \times 2 \approx 42.9\text{ GB}$ per request
  • In INT8: $\approx 21.5\text{ GB}$ per request
  • In INT4: $\approx 10.7\text{ GB}$ per request (see the sketch below)
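The calculation above as a sketch (configuration values from the text; results in decimal GB):

```python
def kv_cache_gb(layers, kv_heads, d_head, seq_len, batch, bytes_per_el):
    # Factor of 2 covers keys and values.
    return 2 * layers * kv_heads * d_head * seq_len * batch * bytes_per_el / 1e9

for fmt, b in [("BF16", 2), ("INT8", 1), ("INT4", 0.5)]:
    print(fmt, round(kv_cache_gb(80, 8, 128, 131072, 1, b), 1))
# BF16 42.9 / INT8 21.5 / INT4 10.7
```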

KV cache quantization is uniquely challenging:

  • Keys and values are computed on-the-fly during generation; you can't calibrate offline
  • Outlier channels in keys cause accuracy degradation under naive quantization
  • Per-channel quantization (different scale per head dimension) works well because outlier channels are consistent across tokens

Effective approaches:

  • Per-channel INT8: scale per head dimension; negligible accuracy loss for most models
  • Per-token INT4 with group size 128: achievable with careful calibration; small perplexity increase
  • KIVI (Liu et al., 2024): Key cache in INT2 (per-channel), Value cache in INT2 (per-token), with residual FP16 for recent tokens - 8× compression with minimal loss
  • SqueezeLLM / GEAR: mixed-precision KV cache with outlier tokens kept in higher precision

10.4 Feed-Forward Network (FFN) Precision

The FFN in modern transformers (SwiGLU variant):

$$\text{FFN}(x) = (\text{Swish}(xW_{\text{gate}}) \odot xW_{\text{up}}) W_{\text{down}}$$

Precision analysis:

| Component | Operation | Format | Rationale |
|---|---|---|---|
| Gate projection | $xW_{\text{gate}}$ | BF16/INT8 | Standard matmul |
| Up projection | $xW_{\text{up}}$ | BF16/INT8 | Standard matmul |
| Swish activation | $x \cdot \sigma(x)$ | BF16 | Smooth function; no overflow risk |
| Element-wise multiply | $\odot$ | BF16 | Simple multiplication |
| Down projection | $W_{\text{down}}$ | BF16/INT8 | Standard matmul, but more sensitive |

The SwiGLU precision advantage over ReLU:

  • ReLU creates exact zeros, which slightly help quantization (zero is perfectly representable)
  • SwiGLU produces smooth, non-zero activations - slightly harder to quantize but better model quality
  • In practice, the difference in quantization difficulty is negligible

FFN weight quantization sensitivity:

  • $W_{\text{gate}}$ and $W_{\text{up}}$: moderately sensitive (they determine what information passes through)
  • $W_{\text{down}}$: most sensitive (the final projection directly affects the residual stream and all subsequent layers)
  • Practical strategy: quantize $W_{\text{gate}}$, $W_{\text{up}}$ to INT4; keep $W_{\text{down}}$ in INT8 or higher

10.5 Optimizer State Precision

The Adam optimizer maintains two state variables per parameter:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad \text{(first moment - momentum)}$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \quad \text{(second moment - adaptive learning rate)}$$

Why optimizer states are precision-critical:

  1. Exponential moving average requires cumulative precision:

    • With $\beta_2 = 0.999$, the effective window is $\sim 1000$ steps
    • Each step contributes $(1 - \beta_2) = 0.001$ of its gradient to the running average
    • In BF16: contributions smaller than $7.8 \times 10^{-3}$ relative to the running sum are lost
    • Since $0.001 < 7.8 \times 10^{-3}$, individual gradient contributions are frequently lost -> optimizer blindness
    • In FP32: contributions as small as $1.19 \times 10^{-7}$ relative to the running sum are preserved -> 65,000× more room

  2. $v_t$ stores squared gradients:

    • Typical gradient: $10^{-4}$ -> squared: $10^{-8}$
    • These tiny values need precise accumulation over thousands of steps
    • BF16 would round $v_t$ values to zero for small gradients -> divide-by-zero or extremely large updates

Memory-efficient optimizer approaches:

| Approach | Memory/param | Quality | Used In |
|---|---|---|---|
| Adam FP32 | 12 bytes (θ+m+v) | Baseline | Standard training |
| Adam BF16 states | 6 bytes | Fails - training diverges | Don't use |
| 8-bit Adam (Dettmers) | 4 bytes | Minimal loss | bitsandbytes |
| Adafactor | 4-8 bytes | Slight loss for some tasks | T5, PaLM |
| LOMO | ~0 bytes | Limited quality | Fine-tuning only |
| GaLore | ~4 bytes | Near-baseline | Memory-constrained training |

8-bit Adam (bitsandbytes) - how it works:

  • Stores $m$ and $v$ in INT8 with a dynamic exponent (effectively block-wise FP8)
  • Maintains a per-block scale factor in FP32
  • Dequantizes to FP32 for the update step, then re-quantizes
  • Round-trip error is small enough not to affect training for models up to 175B parameters

10.6 Gradient Communication Precision

In distributed training, gradients are communicated between GPUs via all-reduce. For a 70B model, each gradient sync transmits $\sim 70\text{B} \times 2 = 140\text{ GB}$ in BF16.

Gradient compression techniques:

| Method | Format | Compression | Quality |
|---|---|---|---|
| Full BF16 | BF16 | 1× | Baseline |
| FP8 all-reduce | FP8 | 2× | < 0.1% loss |
| INT8 all-reduce | INT8 | 2× | < 0.1% loss |
| 1-bit Adam/LAMB | 1-bit + error feedback | 16-32× | Slight loss, converges |
| TopK + INT8 | Sparse INT8 | 10-100× | Depends on sparsity |

Error feedback mechanism (for aggressive compression):

  1. Compress the gradient: $\tilde{g} = \text{compress}(g)$
  2. Compute the error: $e = g - \tilde{g}$
  3. Accumulate the error in an FP32 buffer: $e_{\text{acc}} \mathrel{+}= e$
  4. Next step, add the accumulated error: $g' = g + e_{\text{acc}}$, then compress $g'$
  5. The accumulated error ensures that over many steps, the true gradient is communicated on average

DeepSpeed ZeRO and precision:

  • ZeRO-1/2/3 partitions optimizer states/gradients/parameters across GPUs
  • Reduces per-GPU memory but increases communication volume
  • Communication precision becomes even more important at scale
  • ZeRO++ uses INT8 for all-to-all communication with negligible quality loss

11. Hardware Implementation of Number Systems

Understanding hardware constraints explains why certain number formats exist and dictates the achievable performance for each format.

11.1 Integer ALUs - Simple and Fast

An integer arithmetic logic unit (ALU) for bb-bit operations:

Gate count and area:

  • $b$-bit adder: $\sim 5b$ gates (carry-lookahead) -> completes in $O(\log b)$ gate delays
  • $b$-bit multiplier: $\sim b^2$ gates (Wallace tree) -> completes in $O(\log^2 b)$ gate delays
  • An INT8 multiplier uses $\sim 64$ gates; an INT32 multiplier uses $\sim 1024$ gates - 16× more silicon

Relative cost (approximate, normalised to INT8 multiply):

| Operation | Relative Area | Relative Energy | Relative Latency |
|---|---|---|---|
| INT8 multiply | 1× | 1× | 1× |
| INT16 multiply | 4× | 4× | 1.3× |
| INT32 multiply | 16× | 16× | 1.6× |
| INT4 multiply | 0.25× | 0.25× | 0.8× |
| INT2 multiply | 0.0625× | 0.0625× | 0.6× |
| INT1 XNOR | 0.01× | 0.01× | 0.2× |

This is why 1-bit networks (BitNet) are so attractive for edge deployment: the multiply-accumulate (MAC) unit - the most replicated component - becomes nearly free.

11.2 Floating-Point Units - Complex and Power-Hungry

An FP multiply involves:

  1. XOR the sign bits (1 gate)
  2. Add the exponents (integer adder)
  3. Multiply the mantissas (integer multiplier of mantissa width)
  4. Normalise and round (shift + round logic)

The mantissa multiplier dominates cost. Since mantissa width determines silicon area:

| Format | Mantissa bits | Multiplier gates | Relative to FP32 |
|---|---|---|---|
| FP32 | 23+1=24 | ~576 | 1× |
| BF16 | 7+1=8 | ~64 | 0.11× (9× cheaper) |
| FP16 | 10+1=11 | ~121 | 0.21× (5× cheaper) |
| TF32 | 10+1=11 | ~121 | 0.21× |
| FP8 E4M3 | 3+1=4 | ~16 | 0.028× (36× cheaper) |
| FP8 E5M2 | 2+1=3 | ~9 | 0.016× (64× cheaper) |

Key insight: BF16 is only slightly more expensive than FP8 in multiplier area, but the 9× savings over FP32 is why BF16 replaced FP32 as the default training format. The jump from BF16 to FP8 saves another 4×.

11.3 Tensor Cores - Systolic Arrays for AI

Tensor cores (NVIDIA terminology; Google calls them "matrix multiply units" in TPUs) are specialised hardware that compute small matrix multiplies in a single clock cycle.

Architecture of a tensor core:

TENSOR CORE (NVIDIA H100 - one core)
====================================

  Computes: D = A × B + C

  A: 16×16 matrix (FP16/BF16/FP8/INT8)    input
  B: 16×16 matrix (FP16/BF16/FP8/INT8)    input
  C: 16×16 matrix (FP32/FP16)             accumulator (read)
  D: 16×16 matrix (FP32/FP16)             accumulator (write)

  +----------------------------------------+
  |  16×16 array of fused multiply-add     |
  |  units, pipelined as a systolic array  |
  |                                        |
  |  For FP16: each unit has an 11-bit     |
  |  multiplier + FP32 adder               |
  |                                        |
  |  For INT8: each unit has an 8-bit      |
  |  multiplier + INT32 adder              |
  |                                        |
  |  For FP8: each unit has a 4-bit        |
  |  multiplier + FP32 adder               |
  +----------------------------------------+

Number of tensor cores per GPU generation:

| GPU | Tensor Cores | Formats Supported | Peak TOPS (INT8) |
|---|---|---|---|
| V100 (2017) | 640 | FP16 | 130 |
| A100 (2020) | 432 | FP16, BF16, TF32, INT8, INT4, FP64 | 624 |
| H100 (2022) | 528 | FP16, BF16, TF32, FP8, INT8 | 1,979 |
| B200 (2024) | 640+ | FP16, BF16, TF32, FP8, FP4, INT8 | 4,500+ |

Each new generation adds support for lower-precision formats, enabling higher throughput without proportionally increasing power or die area.

11.4 Memory Bandwidth - The True Bottleneck

For large language model inference, the bottleneck is almost never compute - it's memory bandwidth. Loading model weights from HBM (High-Bandwidth Memory) to the compute units limits throughput.

The arithmetic intensity argument:

For a single token in autoregressive generation:

  • Compute: one matrix-vector multiply per layer: $2 \times d^2$ FLOPs (for a square weight matrix)
  • Memory: load the weight matrix: $d^2 \times \text{bytes\_per\_weight}$
  • Arithmetic intensity $= \frac{\text{FLOPs}}{\text{bytes}} = \frac{2d^2}{d^2 \times B} = \frac{2}{B}$

For different formats:

| Format | Bytes $B$ | Arithmetic Intensity | Bottleneck |
|---|---|---|---|
| FP32 | 4 | 0.5 FLOP/byte | Memory (severely) |
| BF16 | 2 | 1.0 FLOP/byte | Memory |
| INT8 | 1 | 2.0 FLOP/byte | Memory (still) |
| INT4 | 0.5 | 4.0 FLOP/byte | Memory (still, at batch size 1) |

H100 SXM5 specs:

  • HBM3 bandwidth: 3.35 TB/s
  • FP8 tensor core compute: 1,979 TFLOPS

To be compute-bound, a kernel needs arithmetic intensity $> 1979 \times 10^{12} / 3.35 \times 10^{12} \approx 590$ FLOP/byte. Every format in the table above (even INT4 at 4 FLOP/byte) is far below this threshold, so single-token autoregressive inference is severely memory-bound: at batch size 1, the time to load the full weight matrix dominates. Only large batches, which reuse each loaded weight across many tokens, push the workload toward the compute-bound regime.

Practical implications:

  • Reducing model size from BF16 -> INT4 doesn't just cut memory 4× - it cuts the time to load weights by the same factor, speeding up memory-bound inference almost proportionally
  • This is why quantization provides near-linear speedups for inference - the improvement comes from reduced memory traffic, not faster arithmetic (see the roofline sketch below)
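A back-of-the-envelope roofline check with the H100 numbers above (a format is memory-bound when its arithmetic intensity falls below the compute-to-bandwidth ridge point):

```python
PEAK_FLOPS = 1979e12     # H100 INT8 tensor-core ops/s (from the text)
BANDWIDTH = 3.35e12      # HBM3 bytes/s
ridge = PEAK_FLOPS / BANDWIDTH          # ~591 FLOP/byte

for fmt, bytes_per_weight in [("FP32", 4), ("BF16", 2), ("INT8", 1), ("INT4", 0.5)]:
    intensity = 2 / bytes_per_weight    # 2 FLOPs per weight loaded (matvec)
    verdict = "memory-bound" if intensity < ridge else "compute-bound"
    print(f"{fmt}: {intensity:.1f} FLOP/byte -> {verdict}")
```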

11.5 Energy Cost per Operation

Energy consumption is a critical constraint for data centre deployment and edge AI:

Energy per arithmetic operation (approximate, 7nm process):

| Operation | Energy (pJ) | Relative |
|---|---|---|
| INT8 MAC | 0.2 | 1× |
| INT16 MAC | 0.9 | 4.5× |
| FP16 FMA | 1.0 | 5× |
| BF16 FMA | 0.8 | 4× |
| FP32 FMA | 3.7 | 18.5× |
| FP8 FMA | 0.4 | 2× |
| INT4 MAC | 0.05 | 0.25× |
| 1-bit XNOR+popcount | 0.02 | 0.1× |
| HBM3 DRAM read (64 bytes) | 12.5 | 62.5× |
HBM3 DRAM read (64 bytes)12.562.5\times

Critical observation: a DRAM access costs ~60× more energy than an INT8 MAC. For large models:

  • Most energy is spent moving data, not computing
  • Smaller number formats reduce both memory traffic and compute energy - a double benefit
  • This energy analysis is the fundamental driver behind the industry's push toward lower precision

Energy for one forward pass of a 70B model (estimated):

| Format | Compute Energy | Memory Energy | Total | Relative |
|---|---|---|---|---|
| FP32 | ~80 J | ~150 J | ~230 J | 1× |
| BF16 | ~17 J | ~75 J | ~92 J | 0.40× |
| INT8 | ~4 J | ~38 J | ~42 J | 0.18× |
| INT4 | ~1 J | ~19 J | ~20 J | 0.087× |

INT4 inference uses approximately 11× less energy than FP32 - enabling deployment at roughly 1/11th the power budget.

11.6 Format Conversion Hardware

GPUs include dedicated hardware for converting between number formats:

Conversions that are "free" (no precision loss, handled by wiring):

  • FP32 -> FP64: zero-extend mantissa, adjust exponent bias
  • BF16 -> FP32: zero-extend 16 bits of mantissa, copy exponent and sign
  • INT8 -> INT32: sign-extend 24 bits

Conversions that require rounding (1-cycle latency on modern GPUs):

  • FP32 -> BF16: truncate 16 mantissa bits, apply rounding (3.5)
  • FP32 -> FP16: truncate + check for overflow (may produce Inf/NaN)
  • FP32 -> FP8: significant truncation + overflow check + rounding
  • FP32 -> INT8: multiply by scale, round, clamp

Key conversion in mixed-precision training:

FP32 master weight -> BF16 working copy:
  Simply drop the lower 16 mantissa bits and round
  This truncation introduces up to 0.39% error per weight

  But the FP32 master copy is never lost - it receives the
  precise gradient update and only the BF16 working copy
  is re-derived fresh each forward pass

12. Precision and the Training Stability Connection

Training instability in large language models is intimately connected to number format limitations. This section explores the mechanisms by which precision failures manifest as training failures.

12.1 Loss Landscape and Precision

The loss landscape of a large neural network is a high-dimensional surface $L(\theta)$ where $\theta \in \mathbb{R}^N$ ($N \sim 10^{10}$ for a 70B model).

Curvature and precision requirements:

The second-order Taylor expansion around the current point $\theta$ gives:

$$L(\theta + \delta) \approx L(\theta) + \nabla L^T \delta + \frac{1}{2} \delta^T H \delta$$

where $H$ is the Hessian matrix. The curvature eigenvalues $\{\lambda_i\}$ of $H$ determine precision requirements:

  • High-curvature direction ($\lambda_i \gg 1$): the loss changes rapidly; even small perturbations from quantization cause large loss changes -> needs high precision
  • Low-curvature direction ($\lambda_i \approx 0$): the loss is flat; quantization noise barely affects the loss -> can tolerate low precision

Implications for mixed precision:

  • Most directions in parameter space are low-curvature (the loss landscape is approximately flat in most dimensions)
  • Only a small fraction of directions are "sharp" - these correspond to critical features
  • This is why BF16 training works: the quantization noise ($\sim 0.4\%$ per weight) is smaller than the gradient step size in most directions, and the few sharp directions are protected by FP32 master weights

12.2 Precision Cliffs - When Training Suddenly Diverges

A precision cliff occurs when training appears stable for many steps, then suddenly produces NaN or Inf losses. The mechanism:

Phase 1 - Slow drift (invisible):

  • Gradient accumulation in BF16 silently drops small gradient contributions (8.2)
  • Certain weight matrices slowly drift from their optimal values
  • Loss appears stable because the drift is small compared to normal training noise

Phase 2 - Amplification:

  • Drifted weights cause slightly larger activations in deep layers
  • Attention logits grow (12.5), pushing softmax inputs toward overflow
  • Gradients start hitting gradient clipping more frequently
  • The model is now operating at the edge of numerical stability

Phase 3 - Collapse:

  • A single unlucky batch pushes attention logits past BF16/FP32 overflow -> Inf
  • Inf propagates through softmax -> NaN
  • NaN gradients update all weights -> all weights become NaN -> training dies

This typically happens after 50K-200K training steps, which is why short training runs may appear fine in low precision while long runs fail.

Diagnostic signals:

  • Gradient norm spikes (> 10× baseline) increasing in frequency
  • Learning rate warmup completing without issue, then instability at decay phase
  • Loss spikes that don't recover to baseline

12.3 Stochastic Rounding - A Precision Amplifier

When rounding xx to the nearest representable value in format FF:

  • Round-to-nearest-even (RNE): always rounds to the same value -> systematic bias when many values round in the same direction
  • Stochastic rounding: rounds up with probability $p = (x - \lfloor x \rfloor_F) / (\lceil x \rceil_F - \lfloor x \rfloor_F)$, otherwise rounds down

Mathematical property of stochastic rounding:

$$\mathbb{E}[\text{SR}(x)] = x$$

Stochastic rounding is unbiased - the expected value of the rounded result equals the true value. Over many steps, the rounding errors cancel out on average:

$$\mathbb{E}\left[\sum_{t=1}^{T} \text{SR}(g_t)\right] = \sum_{t=1}^{T} g_t$$

This means that even if individual gradient contributions are too small to represent in BF16, stochastic rounding ensures they contribute to the weight update over time.

Convergence guarantee: with RNE, gradients smaller than $\frac{1}{2}$ ULP are always rounded to zero - the weight provably never updates if all gradients are this small. With stochastic rounding, even infinitesimally small gradients have a non-zero probability of causing an update.
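A minimal simulation of this difference (the ULP value models BF16 spacing near 1.0, the same setting as Exercise 8 later):

```python
import random

def stochastic_round(x: float, ulp: float) -> float:
    # Round down to a multiple of ulp, then round up with probability equal
    # to the fractional position, so that E[SR(x)] == x.
    lo = (x // ulp) * ulp
    return lo + ulp if random.random() < (x - lo) / ulp else lo

ulp = 2**-7          # BF16 ULP near 1.0
g = 0.001            # below ulp/2: round-to-nearest always yields 0
total = sum(stochastic_round(g, ulp) for _ in range(10_000))
print(total)         # ~10.0 on average (10_000 * g), despite g < ulp/2
```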

Hardware support:

  • NVIDIA H100 supports stochastic rounding for FP8 operations
  • This is one reason why FP8 training is feasible despite the very low precision - stochastic rounding compensates for the coarse representation

12.4 Adam Optimizer Numerical Error Analysis

The Adam update rule:

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} \cdot \frac{m_t}{1 - \beta_1^t}$$

Numerical failure mode 1 - $\epsilon$ too small:

  • Standard $\epsilon = 10^{-8}$
  • If $v_t \approx 0$ (a parameter with near-zero gradients): update $\approx \eta m_t / \epsilon$
  • For $\eta = 10^{-4}$ and $m_t = 10^{-4}$: update $= 10^{-4} \times 10^{-4} / 10^{-8} = 1.0$ - a massive weight update
  • This is why $\epsilon$ should be $10^{-6}$ or even $10^{-4}$ for BF16 training

Numerical failure mode 2 - $v_t$ overflow in low precision:

  • $v_t$ accumulates $g_t^2$; if gradients are moderate ($g \sim 1$), then $v_t \sim 1$ - fine
  • But if gradient clipping isn't applied and $g \sim 100$ (loss spike): $g^2 = 10^4$ -> $v_t$ grows rapidly
  • In FP32: max $\approx 3.4 \times 10^{38}$ - fine
  • In 8-bit optimizer states: overflow can occur, corrupting the optimizer state silently

Numerical failure mode 3 - bias correction precision:

  • Bias correction factor $(1 - \beta_2^t)$ for $\beta_2 = 0.999$:
    • $t = 1$: $1 - 0.999 = 0.001$
    • $t = 100$: $1 - 0.999^{100} = 1 - 0.905 = 0.095$
    • $t = 10000$: $1 - 0.999^{10000} = 1 - 0.0000454 \approx 1.0$
  • Early in training ($t < 100$), dividing by $0.001$ amplifies $v_t$ by up to 1000× - this amplification can push values out of the representable range in low precision

12.5 Attention Logit Growth

A common failure mode in large transformer training:

The mechanism:

  1. Attention logits $a = QK^T / \sqrt{d_k}$ grow slowly during training as the model learns sharper attention patterns
  2. The softmax temperature effectively decreases: $\text{softmax}(a/T)$ with decreasing $T$
  3. As logits grow, the softmax output becomes more peaked ("spiky")
  4. Eventually, a single attention logit dominates: $\text{softmax}([50, 1, 1, \ldots]) \approx [1, 0, 0, \ldots]$
  5. The gradient of softmax when it's nearly one-hot is very small -> vanishing gradients for attention
  6. Meanwhile, the large logit values approach the overflow boundary of the number format

Numerical progression (BF16 example):

| Training Step | Max Logit | Softmax Max | Gradient Magnitude | Risk |
|---|---|---|---|---|
| 10K | 5.0 | 0.73 | Normal | Safe |
| 50K | 15.0 | 0.995 | Reduced | Low |
| 100K | 30.0 | 0.9999 | Very small | Medium |
| 200K | 60.0 | ~1.0 | Near zero | High |
| 250K | 89+ | Overflow -> NaN | N/A | Crash |

Mitigation strategies:

  • QK-Norm (Dehghani et al., 2023): normalise $Q$ and $K$ before computing attention: $a = \text{norm}(Q) \cdot \text{norm}(K)^T / \sqrt{d_k}$. This bounds logit magnitudes by $\sqrt{d_k}$ regardless of training step.
  • Logit capping: clamp attention logits to $[-C, C]$ before softmax ($C = 50$ in PaLM). Simple but effective.
  • Softmax temperature: explicitly maintain a temperature, $\text{softmax}(a / T_{\text{learned}})$, where $T$ is a learnable parameter constrained to be positive.

13. Practical Guide - Choosing Number Formats

13.1 Decision Framework for Training

TRAINING FORMAT DECISION TREE
==============================

Start: What is your model size?

+-- < 1B parameters (fits in one GPU)
|   +-- Research/prototyping -> FP32 (simplest, no mixed-precision bugs)
|   +-- Production training  -> BF16 mixed precision (2\times speedup)
|
+-- 1B-13B parameters
|   +-- Full training  -> BF16 mixed precision (mandatory for speed)
|   +-- Fine-tuning    -> QLoRA (NF4 base + BF16 adapters)
|
+-- 13B-70B parameters
|   +-- Full training  -> BF16 + ZeRO-3 across multiple GPUs
|   +-- Fine-tuning    -> QLoRA NF4 (fits on single 80GB GPU)
|   +-- Continued PT   -> BF16 + gradient checkpointing
|
+-- 70B+ parameters
    +-- Full training  -> BF16 + 3D parallelism (need cluster)
    +-- FP8 training   -> if H100+ hardware (emerging 2024+)
    +-- Fine-tuning    -> QLoRA NF4 (4-bit base + 16-bit LoRA)

KEY RULES:
  1. Master weights ALWAYS in FP32
  2. Optimizer states ALWAYS in FP32 (or 8-bit Adam)
  3. Forward/backward: BF16 (or FP8 on H100+)
  4. Loss computation: FP32
  5. Gradient accumulation: FP32

13.2 Decision Framework for Inference

INFERENCE FORMAT DECISION TREE
===============================

Start: What is your latency/quality tradeoff?

+-- Maximum quality (no degradation acceptable)
|   +-- BF16 / FP16 (same as training format)
|
+-- Balanced quality/speed (< 1% perplexity increase)
|   +-- NVIDIA GPU -> INT8 (W8A8) with SmoothQuant
|   +-- Apple Silicon -> INT8 (W8A8) via MLX
|   +-- CPU -> INT8 with ONNX Runtime
|
+-- High compression (1-3% perplexity increase, 4× speedup)
|   +-- Batch inference    -> GPTQ INT4 (W4A16)
|   +-- Real-time serving  -> AWQ INT4 (W4A16)
|   +-- Memory-constrained -> GGML/GGUF Q4_K_M
|
+-- Maximum compression (edge/mobile deployment)
|   +-- INT3 (W3A16) -> aggressive but usable with GPTQ
|   +-- INT2 (W2A16) -> significant quality loss; only for small models
|   +-- 1.58-bit (ternary) -> BitNet models (trained from scratch only)
|
+-- Speculative/KV cache optimisation
    +-- KV cache INT8 -> per-channel quantization
    +-- KV cache INT4 -> with group quantization + recent-token FP16

13.3 Per-Layer Sensitivity Table

Different layers have different tolerance for quantization. Based on empirical findings from AWQ, GPTQ, SqueezeLLM:

| Layer Type | INT8 | INT4 | INT3 | INT2 | Notes |
|---|---|---|---|---|---|
| Embedding | OK | OK | WARNING | NO | Vocabulary coverage matters |
| Attention Q, K projections | OK | OK | WARNING | NO | Attention pattern quality degrades |
| Attention V projection | OK | OK | OK | WARNING | Slightly more robust than Q, K |
| Attention output projection | OK | OK | WARNING | NO | Affects residual stream |
| FFN gate & up projection | OK | OK | OK | WARNING | Relatively robust |
| FFN down projection | OK | WARNING | NO | NO | Most sensitive - affects residual |
| LM head (final projection) | OK | WARNING | NO | NO | Directly affects token probabilities |
| Layer norm / RMSNorm | OK | NO | NO | NO | Keep in FP32 or BF16 always |

Legend: OK = safe, WARNING = measurable degradation, NO = significant quality loss

13.4 Quantization Quality vs Bit Width

Empirical perplexity results (approximate, varies by model and method):

| Bit Width | Method | 7B Model PPL | 13B Model PPL | 70B Model PPL |
|---|---|---|---|---|
| 16 (BF16) | Baseline | 5.68 | 5.09 | 3.56 |
| 8 (INT8) | SmoothQuant | 5.69 (+0.01) | 5.10 (+0.01) | 3.56 (+0.00) |
| 4 (INT4) | GPTQ | 5.85 (+0.17) | 5.20 (+0.11) | 3.60 (+0.04) |
| 4 (NF4) | QLoRA/bitsandbytes | 5.80 (+0.12) | 5.16 (+0.07) | 3.58 (+0.02) |
| 4 (INT4) | AWQ | 5.78 (+0.10) | 5.15 (+0.06) | 3.58 (+0.02) |
| 3 (INT3) | GPTQ | 6.29 (+0.61) | 5.51 (+0.42) | 3.72 (+0.16) |
| 2 (INT2) | QuIP# | 7.85 (+2.17) | 6.43 (+1.34) | 4.15 (+0.59) |

Key observation: larger models are more tolerant of quantization - a 70B INT4 model often outperforms a 13B BF16 model. This is because larger models have more redundancy in their weight matrices.


14. Common Mistakes and Misconceptions

| # | Mistake | Why It's Wrong | Correct Understanding |
|---|---|---|---|
| 1 | "FP16 and BF16 are interchangeable" | FP16 overflows at 65504; BF16 at 3.4×10^38. FP16 needs loss scaling; BF16 doesn't. | BF16 matches FP32 range; FP16 does not. Choose BF16 for training. |
| 2 | "More bits always means better quality" | 70B at INT4 outperforms 13B at FP16 on most benchmarks. | Total model capacity matters more than per-parameter precision. |
| 3 | "Quantization only saves memory" | Quantization also reduces memory bandwidth (the actual bottleneck) -> direct speedup. | Memory savings -> bandwidth savings -> latency reduction. |
| 4 | "INT8 inference loses accuracy" | With proper calibration (SmoothQuant, GPTQ), INT8 is lossless for nearly all models. | INT8 is the "free lunch" of inference optimisation. |
| 5 | "I should quantize my model during training" | Quantization-aware training is expensive and unnecessary for PTQ with 4+ bits. | Use PTQ unless deploying at INT2 or lower. |
| 6 | "FP32 master weights waste memory" | Without FP32 master weights, BF16 training diverges after ~100K steps. | FP32 masters are essential, not optional. |
| 7 | "Stochastic rounding is just noisy" | SR is an unbiased estimator that enables convergence in low precision. | SR is mathematically principled, not a hack. |
| 8 | "FLOPS determines inference speed" | Memory bandwidth is the bottleneck for autoregressive LLM inference. | Optimise for memory bandwidth, not compute. |
| 9 | "Quantizing KV cache is dangerous" | Per-channel INT8 KV cache quantization is nearly lossless. | KV cache quantization is safe and critical for long-context. |
| 10 | "All layers should use the same precision" | Output projection and FFN down projection are much more sensitive. | Use mixed-precision per-layer quantization. |

15. Exercises

Exercise 1: IEEE 754 Encoding

Encode the decimal value $-13.625$ in IEEE 754 FP32 format. Show all steps: sign bit, binary conversion, normalisation, biased exponent, and mantissa. Verify by decoding your result back to decimal.

Exercise 2: Quantization Calculation

Given a weight tensor $W = [-1.2, 0.3, -0.7, 0.9, 0.0, -0.1, 0.5, -0.4]$, compute the INT8 symmetric quantization:

  • (a) Calculate the scale factor $s$
  • (b) Quantize all values to INT8
  • (c) Dequantize and compute the maximum absolute error
  • (d) Compute the SQNR in dB

Exercise 3: BF16 Precision Limits

Two FP32 values: $a = 1.0$ and $b = 0.001$.

  • (a) Add $a + b$ in FP32. What is the result?
  • (b) Cast both values to BF16, add, and cast back to FP32. What is the result?
  • (c) Repeat the BF16 addition 1000 times (adding $b$ to a running sum starting at $a$). Compare with the FP32 result. What is the relative error?
  • (d) How does this relate to gradient accumulation in training?

Exercise 4: Memory Budget

You have a single NVIDIA A100 80GB GPU. Calculate the maximum model size (in parameters) you can:

  • (a) Train with full FP32 (weights + Adam optimizer states)
  • (b) Train with BF16 mixed precision (FP32 master + BF16 working + FP32 Adam)
  • (c) Fine-tune with QLoRA (NF4 base + BF16 LoRA with rank 16, hidden dim 4096, 32 layers × 4 projection matrices)
  • (d) Serve for inference in INT4 (weights only, no optimizer states)

Exercise 5: Softmax Stability

Given attention logits $z = [88.5, 88.7, 88.3, 88.6]$:

  • (a) Compute $e^{z_i}$ for each value. Can these be represented in FP32? In BF16?
  • (b) Apply the max-subtraction trick and recompute. Show that no overflow occurs.
  • (c) Compute the final softmax probabilities.
  • (d) What would happen if $z = [200, 200.1, 199.9]$? How does log-sum-exp help?

Exercise 6: Error Propagation

A 3-layer network has weight matrices $W_1, W_2, W_3$ (each $d \times d$, where $d = 1024$). Each weight is quantized to INT8 with scale $s = 0.01$. Assuming input $\|x\| = 1$ and $\|W_i\| = 1$ (operator norm):

  • (a) Bound the quantization error $\|\Delta W_i\|$ for each layer.
  • (b) Using the first-order error approximation from 9.6, bound the total output error.
  • (c) At what bit width does the output error drop below 1% of the output magnitude?

Exercise 7: Hardware Arithmetic Intensity

For a decoder-only transformer with $L = 32$ layers, $d = 4096$, $d_{\text{ffn}} = 11008$ (LLaMA-7B architecture):

  • (a) Calculate the total FLOPs for generating one token (attention + FFN, ignore KV cache read).
  • (b) Calculate the total weight bytes loaded from memory in INT4, INT8, and BF16.
  • (c) Compute the arithmetic intensity for each format.
  • (d) Given the H100's 3.35 TB/s bandwidth and 1979 TOPS (INT8), determine which format is compute-bound vs memory-bound.

Exercise 8: Stochastic Rounding Simulation

Implement a simple stochastic rounding function in Python. Given a "true" gradient $g = 0.001$ and a simulated BF16 format where values snap to multiples of 0.0078125 ($= 2^{-7}$, the BF16 ULP near 1.0):

  • (a) What does round-to-nearest produce for $g$? (Hint: $g$ is below half the ULP.)
  • (b) Implement stochastic rounding. Over 10,000 steps of accumulating $g$ into a running sum, what is the expected sum? Simulate and verify.
  • (c) Compare with deterministic RNE accumulation over the same 10,000 steps. What is the relative error?

16. Why This Matters for AI/ML

| Concept | Where It Appears | Why You Need It |
|---|---|---|
| IEEE 754 FP32/FP64 | Loss computation, optimizer states | Understanding overflow/underflow prevents silent training failures |
| BF16 | Default training format (2020+) | Know when and why to use it; understand its precision limits |
| FP8 | Next-gen training (H100+) | Critical for cost-efficient training at scale |
| INT8 quantization | Standard inference optimisation | 2× speedup with no quality loss - mandatory knowledge |
| INT4 quantization | High-compression inference | Enables serving 70B models on consumer hardware |
| NF4 | QLoRA fine-tuning | Enables 70B fine-tuning on a single GPU |
| Softmax stability | Every forward pass | Prevents the most common NaN crash in transformers |
| Mixed precision | Every training pipeline | Halves memory, doubles speed - used in all modern training |
| KV cache formats | Long-context inference | Enables 128K+ context without OOM |
| Error propagation | Model debugging | Understanding why certain layers can/cannot be quantized aggressively |
| Hardware constraints | Deployment planning | Choose the right format for your target hardware |
| Stochastic rounding | FP8 training, low-precision research | Key enabler for sub-8-bit training |

17. Conceptual Bridge

This chapter covered the representation of numbers - how mathematical quantities are encoded in finite hardware. The key insight is that every number format is a tradeoff between range, precision, and cost, and AI engineering is the art of choosing the right tradeoff for each part of the pipeline.

CONCEPTUAL FLOW
===============

  Number Systems (this chapter)
  +-- How are values represented in hardware?
  +-- What are the precision limits?
  +-- How do these limits affect AI systems?
          |
          v
  Sets and Logic (next chapter)
  +-- How do we formalise collections of objects?
  +-- What are the logical foundations of proofs?
  +-- How does set theory underpin probability and statistics?
          |
          v
  Functions and Mappings
  +-- How do we formalise input-output relationships?
  +-- What properties must a function have?
  +-- Neural networks as compositions of functions
          |
          v
  Summation and Product Notation
  +-- Compact notation for sums and products
  +-- Used everywhere: loss functions, gradients, attention
  +-- Connection to vectorised computation

How number systems connect to every subsequent chapter:

  • Linear Algebra: matrix operations are chains of multiply-accumulate; the number format determines both the speed and accuracy of every matrix operation in the model
  • Calculus: derivatives are computed via finite differences or analytical rules; numerical precision determines whether gradient computation is meaningful or noise
  • Probability: probability values in $[0, 1]$ are represented as floating-point numbers; underflow in small probabilities ($P < 10^{-38}$) causes silent failures unless you use log-probabilities
  • Optimisation: every optimiser update involves division, square roots, and accumulation - all precision-sensitive operations
  • Information Theory: cross-entropy loss involves $\log$ and $\exp$ - the most overflow/underflow-prone elementary functions

The fundamental lesson: a deep understanding of number systems is not optional for AI practitioners. It's the foundation that determines whether your model trains at all, how much it costs, how fast it runs, and whether you can deploy it on your target hardware.

