Converted from theory.ipynb for web reading.

Number Systems — From Bits to LLM Training

Every number in a computer is a compromise between range, precision, and cost. Understanding these tradeoffs is the foundation of efficient AI engineering.

This notebook is the interactive companion to notes.md. It demonstrates the key concepts from all 17 sections with runnable Python, NumPy, and PyTorch code.

| Section | Topic | What You'll Build |
|---------|-------|-------------------|
| 1 | Positional Number Systems | Binary/hex converters, two's complement, fixed-point |
| 2 | IEEE 754 Deep Dive | Bit-level float decoder, special values, epsilon demo |
| 3 | Floating-Point Formats for AI | BF16/FP16/FP8/TF32 comparison, range/precision analysis |
| 4 | Integer Formats & Quantization | INT8/INT4 quantization, per-channel vs per-tensor |
| 5 | Non-Uniform Formats | NF4 quantile levels, ternary weight simulation |
| 6 | Floating-Point Arithmetic | Catastrophic cancellation, Kahan summation, FMA |
| 7 | Numerical Stability | Stable softmax, log-sum-exp, RMSNorm vs LayerNorm |
| 8 | Quantization Mathematics | SQNR, group quantization, Lloyd-Max, Hadamard transform |
| 9 | Mixed Precision Training | Full mixed-precision pipeline, BF16 precision limits |
| 10 | Hardware & Memory Analysis | Arithmetic intensity, memory budget calculator |
| 11 | Training Stability | Stochastic rounding, Adam errors, attention logit growth |
| 12 | Practical Guide | Format selector, per-layer sensitivity, error propagation |

Prerequisites: Python, NumPy, SciPy (used once for the NF4 demo in §5.1). PyTorch optional but recommended.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import struct
import warnings
from typing import Tuple, List

np.random.seed(42)
np.set_printoptions(precision=8, suppress=True)

# Optional: PyTorch for real implementations
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    torch.manual_seed(42)
    HAS_TORCH = True
    print(f'NumPy {np.__version__} | PyTorch {torch.__version__}')
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name()}')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        print('Device: Apple Silicon (MPS)')
    else:
        print('Device: CPU')
except ImportError:
    HAS_TORCH = False
    print(f'NumPy {np.__version__} | PyTorch not installed (NumPy demos still work)')

1. Positional Number Systems

Every number system uses position to determine value. A digit $d$ at position $k$ in base $\beta$ contributes $d \times \beta^k$:

$$N = \sum_{k} d_k \cdot \beta^k$$

Why this matters for AI

  • Binary (base 2): how every value is stored in hardware
  • Hexadecimal (base 16): how we inspect memory and weights
  • Two's complement: how INT8/INT4 quantization encodes negative values
  • Fixed-point: an alternative to floating-point used in some accelerators

Code cell 5

# ══════════════════════════════════════════════════════════════════
# 1.1 Base Conversion Engine
# ══════════════════════════════════════════════════════════════════

def decimal_to_base(n: int, base: int) -> str:
    """Convert a non-negative integer to any base (2-16)."""
    if n == 0:
        return '0'
    digits = '0123456789ABCDEF'
    result = []
    while n > 0:
        result.append(digits[n % base])
        n //= base
    return ''.join(reversed(result))

def show_positional_breakdown(n: int, base: int):
    """Show how positional notation builds a number."""
    rep = decimal_to_base(n, base)
    print(f'  {n} in base {base}: {rep}')
    terms = []
    for i, d in enumerate(reversed(rep)):
        val = int(d, 16)  # handles hex digits
        if val > 0:
            terms.append(f'{d}×{base}^{i}={val * base**i}')
    print(f'    = {" + ".join(terms)} = {n}')

# Demonstrate the universality of positional notation
print('Positional Number System Demonstrations:')
print('=' * 55)
for val in [42, 255, 1024]:
    for base in [2, 10, 16]:
        show_positional_breakdown(val, base)
    print()

Code cell 6

# ══════════════════════════════════════════════════════════════════
# 1.2 Two's Complement — How INT8/INT4 Work
# ══════════════════════════════════════════════════════════════════

def twos_complement(n: int, bits: int = 8) -> str:
    """Show two's complement representation of a signed integer."""
    if n >= 0:
        return f'{n:0{bits}b}'
    else:
        # Two's complement: flip bits of |n|, add 1
        pos_bits = f'{abs(n):0{bits}b}'
        flipped = ''.join('1' if b == '0' else '0' for b in pos_bits)
        result = int(flipped, 2) + 1
        return f'{result:0{bits}b}'

print("Two's Complement — the encoding behind INT8 quantization")
print('=' * 60)
print(f'{"Decimal":>8}  {"8-bit Binary":>12}  {"Hex":>5}  Step-by-Step')
print('-' * 60)

for val in [42, -42, 127, -128, 0, 1, -1]:
    bits = twos_complement(val)
    hex_val = f'{int(bits, 2):02X}'
    if val < 0:
        step = f'flip {abs(val):08b} → {"".join("1" if b == "0" else "0" for b in f"{abs(val):08b}")} + 1'
    else:
        step = 'direct binary'
    print(f'{val:>8}  {bits:>12}  0x{hex_val:>3}  {step}')

# Show the critical ranges for quantization
print(f'\nQuantization ranges:')
for bits in [8, 4, 2, 1]:
    signed_min = -(2**(bits-1))
    signed_max = 2**(bits-1) - 1
    unsigned_max = 2**bits - 1
    print(f'  INT{bits}: [{signed_min:>5}, {signed_max:>4}]  ({2**bits} levels)'
          f'   UINT{bits}: [0, {unsigned_max:>3}]  ({2**bits} levels)')

Code cell 7

# ══════════════════════════════════════════════════════════════════
# 1.3 Fixed-Point Representation (Q-format)
# ══════════════════════════════════════════════════════════════════

def float_to_fixed(value: float, int_bits: int = 3, frac_bits: int = 4) -> Tuple[int, str]:
    """Convert float to fixed-point Q{int_bits}.{frac_bits} format."""
    scale = 2 ** frac_bits
    total_bits = 1 + int_bits + frac_bits  # sign + integer + fraction
    q = int(round(value * scale))
    # Clamp to representable range
    max_val = 2**(total_bits - 1) - 1
    min_val = -(2**(total_bits - 1))
    q = max(min_val, min(max_val, q))
    # Show binary
    if q < 0:
        binary = f'{(2**total_bits + q):0{total_bits}b}'
    else:
        binary = f'{q:0{total_bits}b}'
    formatted = f'{binary[0]}|{binary[1:1+int_bits]}.{binary[1+int_bits:]}'
    return q, formatted

def fixed_to_float(q: int, frac_bits: int = 4) -> float:
    return q / (2 ** frac_bits)

print('Fixed-Point Q3.4 Format (1 sign + 3 int + 4 frac = 8 bits)')
print('=' * 60)
print(f'{"Value":>8}  {"Q3.4 Binary":>14}  {"Dequantized":>12}  {"Error":>8}')
print('-' * 60)

for val in [2.75, -3.5, 0.0625, 7.9375, -8.0, 0.1, np.pi]:
    q, binary = float_to_fixed(val)
    recon = fixed_to_float(q)
    error = abs(val - recon)
    print(f'{val:>8.4f}  {binary:>14}  {recon:>12.4f}  {error:>8.4f}')

print(f'\nQ3.4 properties:')
print(f'  Range: [{fixed_to_float(-128)}, {fixed_to_float(127)}]')
print(f'  Resolution: {1/16} = 2^(-4)')
print(f'  Compare: FP32 has variable resolution (higher near 0, lower far from 0)')

2. IEEE 754 Floating-Point — Deep Dive

Every float is $(-1)^{\text{sign}} \times (1 + \text{mantissa}) \times 2^{\text{exponent} - \text{bias}}$:

FP32: [S|EEEEEEEE|MMMMMMMMMMMMMMMMMMMMMMM]   1+8+23 = 32 bits
FP16: [S|EEEEE|MMMMMMMMMM]                    1+5+10 = 16 bits
BF16: [S|EEEEEEEE|MMMMMMM]                    1+8+7  = 16 bits
FP8:  [S|EEEE|MMM] (E4M3) or [S|EEEEE|MM] (E5M2)  = 8 bits
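
The BF16 row above is literally the top half of the FP32 layout: keeping only the upper 16 bits of an FP32 value yields (up to rounding) its BF16 approximation. A minimal NumPy sketch of that bit-level relationship, added here for illustration — it truncates rather than rounding to nearest even, so it can differ by one ULP from a real converter:

import numpy as np

x = np.float32(3.14159265)
bits32 = np.frombuffer(x.tobytes(), dtype=np.uint32)[0]       # raw FP32 bit pattern
bits_bf16 = bits32 & np.uint32(0xFFFF0000)                    # keep sign + 8 exp + top 7 mantissa bits
x_trunc = np.frombuffer(np.uint32(bits_bf16).tobytes(), dtype=np.float32)[0]
print(f'{x:.7f} -> {x_trunc:.7f}  (FP32 bits truncated to the BF16 layout)')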

Code cell 9

# ══════════════════════════════════════════════════════════════════
# 2.1 IEEE 754 Bit-Level Decoder
# ══════════════════════════════════════════════════════════════════

def decode_fp32(value: float) -> dict:
    """Fully decode an FP32 value into its IEEE 754 components."""
    packed = struct.pack('>f', value)
    bits = ''.join(f'{byte:08b}' for byte in packed)

    sign_bit = int(bits[0])
    exp_bits = bits[1:9]
    man_bits = bits[9:]

    biased_exp = int(exp_bits, 2)
    true_exp = biased_exp - 127

    # Compute mantissa value
    mantissa_val = sum(int(b) * 2**(-i-1) for i, b in enumerate(man_bits))

    # Determine special cases
    if biased_exp == 0:
        if mantissa_val == 0:
            category = 'Zero'
        else:
            category = 'Subnormal'
    elif biased_exp == 255:
        if mantissa_val == 0:
            category = 'Infinity'
        else:
            category = 'NaN'
    else:
        category = 'Normal'

    return {
        'value': value,
        'bits': f'{bits[0]} | {exp_bits} | {man_bits}',
        'hex': packed.hex(),
        'sign': sign_bit,
        'biased_exp': biased_exp,
        'true_exp': true_exp,
        'mantissa_val': mantissa_val,
        'implicit_1_plus_m': 1.0 + mantissa_val,
        'category': category,
        'formula': f'(-1)^{sign_bit} × {1+mantissa_val:.6f} × 2^{true_exp}'
    }

# Decode a series of important values
print('IEEE 754 FP32 Decoder')
print('=' * 80)

test_values = [5.0, -13.625, 0.1, 1.0, 0.0, float('inf'), float('-inf'), float('nan')]

for val in test_values:
    d = decode_fp32(val)
    print(f'\n  Value: {val}')
    print(f'  Bits:  {d["bits"]}')
    print(f'  Hex:   0x{d["hex"]}')
    print(f'  Type:  {d["category"]}')
    if d['category'] == 'Normal':
        print(f'  Sign={d["sign"]}, Exp={d["biased_exp"]}-127={d["true_exp"]}, '
              f'1+M={d["implicit_1_plus_m"]:.6f}')
        print(f'  = {d["formula"]} = {val}')

Code cell 10

# ══════════════════════════════════════════════════════════════════
# 2.2 Machine Epsilon & Precision Limits
# ══════════════════════════════════════════════════════════════════

print('Machine Epsilon — the fundamental precision limit')
print('=' * 70)
print(f'{"Format":<12} {"ε":>14} {"Decimal digits":>15} {"Relative precision":>20}')
print('-' * 70)
for dtype, name in [(np.float16, 'FP16'), (np.float32, 'FP32'), (np.float64, 'FP64')]:
    eps = np.finfo(dtype).eps
    digits = int(-np.log10(eps))
    print(f'{name:<12} {eps:>14.2e} {digits:>15} {f"1 part in {int(1/eps):,}":>20}')

# BF16 epsilon (manually — not in NumPy)
bf16_eps = 2**-7  # 7 mantissa bits
print(f'{"BF16":<12} {bf16_eps:>14.2e} {"~2":>15} {f"1 part in {int(1/bf16_eps):,}":>20}')
fp8_eps = 2**-3   # E4M3: 3 mantissa bits
print(f'{"FP8 E4M3":<12} {fp8_eps:>14.2e} {"~1":>15} {f"1 part in {int(1/fp8_eps):,}":>20}')

# Demonstrate precision loss in gradient accumulation
print('\n--- Precision Loss Demo: Adding small to large ---')
print(f'{"":<14} {"1.0 + ε":>14} {"1.0 + ε/2":>14} {"ε/2 lost?":>10}')
print('-' * 56)
for dtype in [np.float16, np.float32, np.float64]:
    eps = np.finfo(dtype).eps
    one = dtype(1.0)
    result1 = float(one + dtype(eps))
    result2 = float(one + dtype(eps / 2))
    lost = result2 == 1.0
    print(f'{dtype.__name__:<14} {result1:>14.10f} {result2:>14.10f} {"YES" if lost else "no":>10}')

print('\n⚠ In BF16 training: any gradient < 0.78% of the running weight sum is silently lost!')
print('⚠ This is why FP32 master weights are mandatory for stable training.')

Code cell 11

# ══════════════════════════════════════════════════════════════════
# 2.3 Floating-Point Arithmetic is NOT Associative
# ══════════════════════════════════════════════════════════════════

# The classic proof from §3.4 of notes.md
a = np.float32(1e8)
b = np.float32(1.0)
c = np.float32(-1e8)

left = np.float32(np.float32(a + b) + c)   # (a + b) + c
right = np.float32(a + np.float32(b + c))   # a + (b + c)

print('Floating-Point Associativity Failure')
print('=' * 55)
print(f'  a = {a:.0f},  b = {b:.0f},  c = {c:.0f}')
print(f'  (a + b) + c = ({a:.0f} + {b:.0f}) + {c:.0f}')
print(f'              = {np.float32(a+b):.0f} + {c:.0f}')  # b is absorbed!
print(f'              = {left:.0f}  ← b was absorbed into a!')
print(f'  a + (b + c) = {a:.0f} + ({b:.0f} + {c:.0f})')
print(f'              = {a:.0f} + {np.float32(b+c):.0f}')
print(f'              = {right:.0f}  ← correct!')
print(f'\n  (a+b)+c = {left}  ≠  a+(b+c) = {right}')
print(f'  → Associativity FAILS in floating-point!')
print(f'  → This is why summation order matters for gradient accumulation.')

Code cell 12

# ══════════════════════════════════════════════════════════════════
# 2.4 Special Values: Subnormals, Infinity, NaN
# ══════════════════════════════════════════════════════════════════

print('IEEE 754 Special Values')
print('=' * 65)

# Subnormal numbers — the "gradual underflow" zone
print('\n--- Subnormal Numbers (gradual underflow) ---')
normal_min = np.finfo(np.float32).tiny  # smallest normal
subnormal = normal_min / 2  # enters subnormal territory
print(f'  Smallest normal FP32: {normal_min:.6e}')
print(f'  Half of that (subnormal): {subnormal:.6e}')
print(f'  Is subnormal? biased_exp == 0: ', end='')
d = decode_fp32(subnormal)
print(f'{d["category"]} (biased_exp = {d["biased_exp"]})')

# Infinity arithmetic
print('\n--- Infinity Arithmetic ---')
inf = float('inf')
for expr, result in [
    ('inf + 1', inf + 1),
    ('inf + inf', inf + inf),
    ('inf * 2', inf * 2),
    ('1 / inf', 1 / inf),
    ('inf / inf', inf / inf),  # NaN
    ('inf - inf', inf - inf),  # NaN
    ('0 * inf', 0 * inf),     # NaN
]:
    print(f'  {expr:>12} = {result}')

# NaN propagation — the training killer
print('\n--- NaN Propagation (why one NaN kills training) ---')
nan = float('nan')
print(f'  nan + 1 = {nan + 1}')
print(f'  nan * 0 = {nan * 0}')
print(f'  nan == nan: {nan == nan}  ← NaN is not equal to itself!')
print(f'  np.isnan(nan): {np.isnan(nan)}  ← use this to detect NaN')
print(f'  → If a single weight becomes NaN, ALL subsequent matmuls produce NaN')
print(f'  → The entire model is corrupted in one forward pass')

3. Floating-Point Formats for AI

The AI industry uses a zoo of number formats. Each trades precision for speed/memory:

| Format | Bits | Exp | Mantissa | Range | ML Role |
|--------|------|-----|----------|-------|---------|
| FP64 | 64 | 11 | 52 | ±10³⁰⁸ | Scientific computing |
| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | Master weights, optimizer |
| TF32 | 19 | 8 | 10 | ±3.4×10³⁸ | Auto matmul on A100+ |
| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | Training default (2020+) |
| FP16 | 16 | 5 | 10 | ±65504 | Legacy training, inference |
| FP8 E4M3 | 8 | 4 | 3 | ±448 | Forward pass (H100+) |
| FP8 E5M2 | 8 | 5 | 2 | ±57344 | Backward pass (H100+) |
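
These table entries can be cross-checked directly against PyTorch's format metadata — a quick sketch, assuming PyTorch is installed (the float8 dtypes exist only in recent builds, hence the guard):

import torch

for dt in [torch.float32, torch.bfloat16, torch.float16]:
    fi = torch.finfo(dt)
    print(f'{str(dt):<16} max={fi.max:.3e}  eps={fi.eps:.3e}')

if hasattr(torch, 'float8_e4m3fn'):   # PyTorch >= 2.1
    fi = torch.finfo(torch.float8_e4m3fn)
    print(f'float8_e4m3fn    max={fi.max:.3e}  eps={fi.eps:.3e}')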

Code cell 14

# ══════════════════════════════════════════════════════════════════
# 3.1 Format Range & Precision Comparison
# ══════════════════════════════════════════════════════════════════

# Define format properties (exp_bits, man_bits, max_exp_bias)
formats = {
    'FP64':     {'exp': 11, 'man': 52, 'max_val': 1.8e308,  'eps': 2**-52,  'bytes': 8},
    'FP32':     {'exp':  8, 'man': 23, 'max_val': 3.4e38,   'eps': 2**-23,  'bytes': 4},
    'TF32':     {'exp':  8, 'man': 10, 'max_val': 3.4e38,   'eps': 2**-10,  'bytes': 2.375},
    'BF16':     {'exp':  8, 'man':  7, 'max_val': 3.4e38,   'eps': 2**-7,   'bytes': 2},
    'FP16':     {'exp':  5, 'man': 10, 'max_val': 65504,    'eps': 2**-10,  'bytes': 2},
    'FP8 E4M3': {'exp':  4, 'man':  3, 'max_val': 448,      'eps': 2**-3,   'bytes': 1},
    'FP8 E5M2': {'exp':  5, 'man':  2, 'max_val': 57344,    'eps': 2**-2,   'bytes': 1},
}

print(f'{"Format":<12} {"Bits":>5} {"Max Value":>12} {"Epsilon":>12} {"Precision":>12} {"70B Model":>10}')
print('=' * 75)
for name, f in formats.items():
    total_bits = 1 + f['exp'] + f['man']
    mem_70b = 70e9 * f['bytes'] / 1e9
    precision = f'~{int(-np.log10(f["eps"]))} digits'
    print(f'{name:<12} {total_bits:>5} {f["max_val"]:>12.2e} {f["eps"]:>12.2e} '
          f'{precision:>12} {mem_70b:>8.1f} GB')

print(f'\nKey insight: BF16 has the SAME range as FP32 (8-bit exponent) but only ~2 digits precision.')
print(f'This is why BF16 is preferred over FP16 for training — it never overflows on typical gradients.')

Code cell 15

# ══════════════════════════════════════════════════════════════════
# 3.2 BF16 vs FP16: The Critical Difference
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    print('BF16 vs FP16 — Range and Overflow Behaviour')
    print('=' * 65)

    # Test values near FP16 overflow boundary
    test_vals = [100, 1000, 10000, 65504, 65505, 70000, 100000]
    print(f'{"Value":>10} {"FP16":>12} {"BF16":>12} {"FP32":>12}')
    print('-' * 50)
    for v in test_vals:
        fp32 = torch.tensor(float(v), dtype=torch.float32)
        fp16 = fp32.half()
        bf16 = fp32.bfloat16()
        print(f'{v:>10} {float(fp16):>12.1f} {float(bf16):>12.1f} {float(fp32):>12.1f}')

    # Precision comparison near 1.0
    print(f'\n--- Precision near 1.0 ---')
    base = torch.tensor(1.0)
    for delta_exp in range(-1, -8, -1):
        delta = 2.0 ** delta_exp
        fp16_result = float((base + delta).half())
        bf16_result = float((base + delta).bfloat16())
        fp16_ok = fp16_result != 1.0
        bf16_ok = bf16_result != 1.0
        print(f'  1.0 + 2^{delta_exp} ({delta:.6f}): '
              f'FP16={"✓" if fp16_ok else "✗ LOST":<8} '
              f'BF16={"✓" if bf16_ok else "✗ LOST":<8}')

    print(f'\n→ FP16 has MORE precision (10 vs 7 mantissa bits) but LESS range')
    print(f'→ BF16 has LESS precision but the SAME range as FP32')
    print(f'→ For training: range wins (gradient magnitudes vary hugely)')
else:
    print('PyTorch required for BF16/FP16 comparison')

Code cell 16

# ══════════════════════════════════════════════════════════════════
# 3.3 FP8 Format Deep Dive
# ══════════════════════════════════════════════════════════════════

def enumerate_fp8_values(exp_bits: int, man_bits: int) -> list:
    """Enumerate all representable positive normal values in an FP8-like format."""
    bias = 2**(exp_bits - 1) - 1
    values = []
    # Note: real E4M3 also uses the all-1s exponent for finite values (max 448); this simplified scan stops at 240
    for e in range(1, 2**exp_bits - 1):  # skip 0 (subnormal) and all-1s (special)
        for m in range(2**man_bits):
            mantissa = 1.0 + m / (2**man_bits)
            val = mantissa * (2 ** (e - bias))
            values.append(val)
    return sorted(values)

e4m3_vals = enumerate_fp8_values(4, 3)
e5m2_vals = enumerate_fp8_values(5, 2)

print('FP8 Representable Values — E4M3 vs E5M2')
print('=' * 60)
print(f'  E4M3: {len(e4m3_vals)} positive normal values, range [{e4m3_vals[0]}, {e4m3_vals[-1]}]')
print(f'  E5M2: {len(e5m2_vals)} positive normal values, range [{e5m2_vals[0]}, {e5m2_vals[-1]}]')

# Show values between 1.0 and 2.0 to compare precision
print(f'\nValues between 1.0 and 2.0 (precision comparison):')
e4m3_1to2 = [v for v in e4m3_vals if 1.0 <= v < 2.0]
e5m2_1to2 = [v for v in e5m2_vals if 1.0 <= v < 2.0]
print(f'  E4M3 ({len(e4m3_1to2)} levels): {e4m3_1to2}')
print(f'  E5M2 ({len(e5m2_1to2)} levels): {e5m2_1to2}')
print(f'  E4M3 spacing: {e4m3_1to2[1]-e4m3_1to2[0]:.3f}')
print(f'  E5M2 spacing: {e5m2_1to2[1]-e5m2_1to2[0]:.3f}')

print(f'\n→ E4M3 has 2× more levels (precision) → better for weights/activations')
print(f'→ E5M2 has wider range → better for gradients (which vary more in magnitude)')

4. Integer Formats for AI — Quantization in Practice

Quantization maps floating-point values to integers:

$$x_q = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{x}{s} + z\right),\ 0,\ 2^b - 1\right), \qquad \hat{x} = s \cdot (x_q - z) \quad \text{(dequantize)}$$
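
The formula above is the asymmetric (affine) variant with an explicit zero-point $z$; the cells below concentrate on the symmetric special case ($z = 0$). For reference, a minimal sketch of the affine version — not one of the original cells:

import numpy as np

def quantize_asymmetric(x: np.ndarray, bits: int = 8):
    """Affine quantisation: x_q = clamp(round(x/s + z), 0, 2^b - 1)."""
    qmin, qmax = 0, 2**bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    x_q = np.clip(np.round(x / scale + zero_point), qmin, qmax).astype(np.uint8)
    return x_q, scale, zero_point

x = np.array([-0.2, 0.0, 0.5, 1.3], dtype=np.float32)
x_q, s, z = quantize_asymmetric(x)
x_hat = s * (x_q.astype(np.float32) - z)   # dequantise
print(x_q, s, z)    # integer codes, scale, zero-point
print(x_hat)        # ≈ original values

Asymmetric quantisation uses the full [0, 255] range even when the data is not centred on zero, at the cost of carrying a zero-point alongside the scale.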

Code cell 18

# ══════════════════════════════════════════════════════════════════
# 4.1 Complete INT8 Quantization with Error Analysis
# ══════════════════════════════════════════════════════════════════

def quantize_symmetric(x: np.ndarray, bits: int = 8) -> Tuple[np.ndarray, float]:
    """Symmetric quantization: zero maps to zero, scale based on max abs value."""
    alpha = np.max(np.abs(x))
    qmax = 2**(bits - 1) - 1  # 127 for INT8
    scale = alpha / qmax if alpha > 0 else 1.0
    x_q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8 if bits == 8 else np.int32)
    return x_q, scale

def dequantize(x_q: np.ndarray, scale: float) -> np.ndarray:
    """Dequantize: INT → float approximation."""
    return x_q.astype(np.float32) * scale

# Worked example from §9.1 of notes.md
weights = np.array([-0.45, 0.12, -0.03, 0.67, -0.89, 0.34], dtype=np.float32)
x_q, scale = quantize_symmetric(weights, bits=8)
reconstructed = dequantize(x_q, scale)
errors = np.abs(weights - reconstructed)

print('INT8 Symmetric Quantization — Worked Example')
print('=' * 70)
print(f'Scale factor s = max(|w|) / 127 = {np.max(np.abs(weights)):.2f} / 127 = {scale:.6f}')
print(f'\n{"Original":>10} {"w/s":>10} {"Rounded":>8} {"INT8":>6} {"Deq":>10} {"Error":>10}')
print('-' * 60)
for w, q, r, e in zip(weights, weights/scale, x_q, errors):
    print(f'{w:>10.4f} {q:>10.2f} {np.round(q):>8.0f} {int(r):>6} {dequantize(np.array([r]), scale)[0]:>10.5f} {e:>10.5f}')

print(f'\nMax error: {np.max(errors):.5f}  ≈  s/2 = {scale/2:.5f}')
print(f'MSE: {np.mean(errors**2):.8f}')
print(f'Memory: {weights.nbytes} bytes (FP32) → {x_q.nbytes} bytes (INT8) = {weights.nbytes/x_q.nbytes:.0f}× compression')

Code cell 19

# ══════════════════════════════════════════════════════════════════
# 4.2 Per-Tensor vs Per-Channel vs Per-Group Quantization
# ══════════════════════════════════════════════════════════════════

np.random.seed(42)
# Simulate a weight matrix with channels of VERY different magnitudes
# (this is realistic — transformer weights often have outlier channels)
W = np.random.randn(8, 256).astype(np.float32)
W[0] *= 10      # Channel 0: 10× larger (outlier)
W[1] *= 0.01    # Channel 1: 100× smaller

# Per-tensor: single scale for entire matrix
q_tensor, s_tensor = quantize_symmetric(W.flatten())
recon_tensor = dequantize(q_tensor, s_tensor).reshape(W.shape)

# Per-channel: one scale per output channel (row)
recon_channel = np.zeros_like(W)
for c in range(W.shape[0]):
    q_c, s_c = quantize_symmetric(W[c])
    recon_channel[c] = dequantize(q_c, s_c)

# Per-group: one scale per group of G elements
G = 64  # group size
recon_group = np.zeros_like(W)
for c in range(W.shape[0]):
    for start in range(0, W.shape[1], G):
        end = min(start + G, W.shape[1])
        q_g, s_g = quantize_symmetric(W[c, start:end])
        recon_group[c, start:end] = dequantize(q_g, s_g)

print('Quantization Granularity Comparison')
print('=' * 70)
print(f'{"Channel":>4} {"Weight σ":>9} {"Per-Tensor MSE":>16} {"Per-Channel MSE":>16} {"Per-Group MSE":>14}')
print('-' * 65)
for c in range(W.shape[0]):
    mse_t = np.mean((W[c] - recon_tensor[c])**2)
    mse_c = np.mean((W[c] - recon_channel[c])**2)
    mse_g = np.mean((W[c] - recon_group[c])**2)
    flag = ' ← outlier!' if c in [0, 1] else ''
    print(f'{c:>4} {np.std(W[c]):>9.4f} {mse_t:>16.8f} {mse_c:>16.8f} {mse_g:>14.8f}{flag}')

total_t = np.mean((W - recon_tensor)**2)
total_c = np.mean((W - recon_channel)**2)
total_g = np.mean((W - recon_group)**2)
print(f'\nTotal MSE: Per-tensor={total_t:.6f}  Per-channel={total_c:.6f}  Per-group(64)={total_g:.6f}')
print(f'Per-channel is {total_t/total_c:.1f}× better than per-tensor')
print(f'Per-group is {total_t/total_g:.1f}× better than per-tensor')
print(f'\n→ Channel 1 (tiny weights) is crushed by per-tensor: the global scale is too large')
print(f'→ Per-group gives the best results with modest overhead ({W.size // G} extra scale factors)')

Code cell 20

# ══════════════════════════════════════════════════════════════════
# 4.3 INT4 Quantization — W4A16 Pipeline
# ══════════════════════════════════════════════════════════════════

def quantize_int4_symmetric(x: np.ndarray) -> Tuple[np.ndarray, float]:
    """Symmetric INT4 quantization: range [-8, 7]."""
    alpha = np.max(np.abs(x))
    scale = alpha / 7.0 if alpha > 0 else 1.0
    x_q = np.clip(np.round(x / scale), -8, 7).astype(np.int8)  # store in int8 container
    return x_q, scale

# Simulate INT4 weight quantization (W4A16 = 4-bit weights, 16-bit activations)
np.random.seed(42)
W = np.random.randn(128, 128).astype(np.float32) * 0.02  # typical weight scale
x = np.random.randn(1, 128).astype(np.float32)  # input activation

# Full precision matmul
y_fp32 = x @ W.T

# INT4 quantized matmul (per-group, G=32)
G = 32
y_int4 = np.zeros((1, 128), dtype=np.float32)
for col_start in range(0, 128, G):
    col_end = col_start + G
    W_group = W[:, col_start:col_end]
    x_group = x[:, col_start:col_end]
    # Quantize weights to INT4 per group
    for row in range(128):
        q, s = quantize_int4_symmetric(W_group[row])
        W_deq = dequantize(q, s)
        y_int4[0, row] += (x_group @ W_deq.reshape(-1, 1)).item()

# Compare
mse = np.mean((y_fp32 - y_int4)**2)
relative_error = np.mean(np.abs(y_fp32 - y_int4) / (np.abs(y_fp32) + 1e-10))

print('W4A16 (INT4 weight, FP16 activation) Pipeline')
print('=' * 55)
print(f'Weight matrix: {W.shape}, stored in INT4 with group_size={G}')
print(f'FP32 memory: {W.nbytes:,} bytes')
print(f'INT4 memory: ~{W.size // 2:,} bytes (+ scale overhead)')
print(f'Compression: ~{W.nbytes / (W.size // 2):.0f}×')
print(f'\nOutput MSE: {mse:.8f}')
print(f'Mean relative error: {relative_error:.4%}')
print(f'Max |y_fp32|: {np.max(np.abs(y_fp32)):.4f}')
print(f'Max |error|: {np.max(np.abs(y_fp32 - y_int4)):.6f}')
print(f'\n→ Only 16 quantization levels! Yet relative error is small because')
print(f'  the group-wise scale adapts to local weight distribution.')

5. Non-Uniform and Specialised Formats

Neural network weights follow a bell-shaped distribution (roughly Gaussian). Uniform quantization wastes levels on the sparse tails. Non-uniform formats place more levels where the data is dense.
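
A quick way to see that waste: with a symmetric range set by max|w|, half of the uniform levels cover magnitudes that almost no Gaussian weight ever reaches. A small added sketch (not one of the original cells):

import numpy as np

np.random.seed(0)
w = np.random.randn(100_000)
r = np.abs(w).max()                  # symmetric quantisation range is ±r
outer = np.mean(np.abs(w) > r / 2)   # weights that would land in the outer half of the levels
print(f'range = ±{r:.2f}; only {outer:.2%} of weights fall in the outer half of that range')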

Code cell 22

# ══════════════════════════════════════════════════════════════════
# 5.1 NF4 — Normal Float 4-bit (QLoRA)
# ══════════════════════════════════════════════════════════════════

from scipy.stats import norm

def compute_nf4_levels(bits: int = 4) -> np.ndarray:
    """Compute NF4 quantisation levels as quantiles of N(0,1).
    Each level covers an equal probability mass of N(0,1) — the quantile construction
    behind QLoRA's NF4 (the real NF4 table additionally pins an exact zero level)."""
    n_levels = 2**bits
    # Quantiles at evenly spaced probabilities
    probs = np.linspace(1/(2*n_levels), 1 - 1/(2*n_levels), n_levels)
    levels = norm.ppf(probs)  # inverse CDF of N(0,1)
    # Normalise to [-1, 1]
    levels = levels / np.max(np.abs(levels))
    return levels

nf4_levels = compute_nf4_levels(4)

print('NF4 (Normal Float 4-bit) — The 16 Quantisation Levels')
print('=' * 65)
print('These are the optimal quantisation levels for normally-distributed data:')
print(f'\nLevels: {np.array2string(nf4_levels, precision=4, separator=", ")}')

# Visualise the distribution of levels
print(f'\nLevel distribution (density near zero is highest):')
for i, level in enumerate(nf4_levels):
    bar = '█' * int((level + 1) * 25)
    print(f'  {i:>2}: {level:>7.4f}  {bar}')

# Compare NF4 vs uniform INT4 on normally-distributed weights
np.random.seed(42)
weights = np.random.randn(10000).astype(np.float32)

# NF4 quantization
w_norm = weights / np.max(np.abs(weights))  # normalise to [-1, 1]
nf4_q = np.array([nf4_levels[np.argmin(np.abs(nf4_levels - w))] for w in w_norm])
nf4_recon = nf4_q * np.max(np.abs(weights))
nf4_mse = np.mean((weights - nf4_recon)**2)

# Uniform INT4 quantization
int4_q, int4_s = quantize_int4_symmetric(weights)
int4_recon = dequantize(int4_q, int4_s)
int4_mse = np.mean((weights - int4_recon)**2)

print(f'\nQuantisation quality comparison (10K normally-distributed weights):')
print(f'  NF4 MSE:    {nf4_mse:.6f}')
print(f'  INT4 MSE:   {int4_mse:.6f}')
print(f'  NF4 is {int4_mse/nf4_mse:.1f}× better for Gaussian data')
print(f'  → NF4 places more levels near zero where most weights live')

Code cell 23

# ══════════════════════════════════════════════════════════════════
# 5.2 Ternary Weights — BitNet b1.58 {-1, 0, +1}
# ══════════════════════════════════════════════════════════════════

def ternarize_weights(W: np.ndarray) -> Tuple[np.ndarray, float]:
    """Quantise weights to {-1, 0, +1} using mean absolute value as threshold."""
    alpha = np.mean(np.abs(W))  # scale factor
    W_ternary = np.zeros_like(W, dtype=np.int8)
    W_ternary[W > alpha * 0.5] = 1
    W_ternary[W < -alpha * 0.5] = -1
    return W_ternary, alpha

# Simulate a ternary matmul vs full-precision
np.random.seed(42)
d = 512
W = np.random.randn(d, d).astype(np.float32) * 0.02
x = np.random.randn(1, d).astype(np.float32)

# Full precision
y_full = x @ W.T

# Ternary
W_tern, alpha = ternarize_weights(W)
y_tern = x @ (W_tern.astype(np.float32) * alpha).T

# Statistics
n_zero = np.sum(W_tern == 0)
n_pos = np.sum(W_tern == 1)
n_neg = np.sum(W_tern == -1)
total = W_tern.size
relative_error = np.mean(np.abs(y_full - y_tern) / (np.abs(y_full) + 1e-10))

print('Ternary Weights — BitNet b1.58 Simulation')
print('=' * 55)
print(f'Weight matrix: {W.shape}')
print(f'Scale factor α = mean(|W|) = {alpha:.6f}')
print(f'\nWeight distribution:')
print(f'  -1: {n_neg:>6} ({n_neg/total:>6.1%})')
print(f'   0: {n_zero:>6} ({n_zero/total:>6.1%})')
print(f'  +1: {n_pos:>6} ({n_pos/total:>6.1%})')
print(f'\nAverage entropy per weight: {-sum(p*np.log2(p+1e-10) for p in [n_neg/total, n_zero/total, n_pos/total]):.2f} bits')
print(f'  → log₂(3) = {np.log2(3):.2f} bits (1.58-bit, hence the name)')
print(f'\nOutput relative error: {relative_error:.4%}')
print(f'\nMatmul cost comparison:')
print(f'  FP32: {d*d} multiply-accumulate operations')
print(f'  Ternary: {n_pos + n_neg} additions only (no multiplications!)')
print(f'  FLOPs reduction: {(total) / (n_pos + n_neg):.1f}× (only additions, no multiplies)')

6. Floating-Point Arithmetic Deep Dive

Understanding the mechanics of FP addition and multiplication explains why catastrophic cancellation occurs and why FMA (fused multiply-add) matters.
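
A quick illustration of the FMA point before the cells below: a fused multiply-add forms the exact product and rounds only once, whereas a separate multiply-then-add rounds twice. NumPy has no FMA primitive, so this added sketch emulates the single rounding by carrying the product in FP64 and rounding the final result to FP32:

import numpy as np

a = np.float32(1 + 2**-12)
b = np.float32(1 + 2**-12)
c = np.float32(-(1 + 2**-11))

# Separate multiply and add: the 2^-24 term of a*b is rounded away before the add
two_roundings = np.float32(np.float32(a * b) + c)

# FMA-style: keep the product exact (FP64 here), round once at the end
one_rounding = np.float32(np.float64(a) * np.float64(b) + np.float64(c))

print(f'two roundings: {two_roundings}')                 # 0.0 — the answer is lost entirely
print(f'single rounding (FMA-style): {one_rounding}')     # 5.96e-08 = 2^-24, the exact result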

Code cell 25

# ══════════════════════════════════════════════════════════════════
# 6.1 Catastrophic Cancellation
# ══════════════════════════════════════════════════════════════════

print('Catastrophic Cancellation — when subtraction destroys precision')
print('=' * 65)

# Example: computing variance with naive formula vs stable formula
# Naive: var = E[x²] - (E[x])²  ← catastrophic cancellation when E[x] >> std(x)
# Stable: var = E[(x - mean(x))²]  ← no cancellation

np.random.seed(42)
# Data with large mean, small variance (extreme case)
data = np.float32(1e6) + np.random.randn(10000).astype(np.float32) * 0.01
true_var = np.float64(np.var(data.astype(np.float64)))  # ground truth in FP64

# Naive formula in FP32
mean_sq = np.float32(np.mean(data.astype(np.float32)**2))
sq_mean = np.float32(np.mean(data.astype(np.float32)))**2
naive_var = np.float32(mean_sq - sq_mean)

# Stable formula in FP32
centered = data - np.float32(np.mean(data))
stable_var = np.float32(np.mean(centered**2))

print(f'Data: {len(data)} values ~ N(1000000, 0.01²)')
print(f'True variance (FP64): {true_var:.10f}')
print(f'\nNaive var = E[x²] - E[x]²:')
print(f'  E[x²] = {mean_sq:.6f}')
print(f'  E[x]² = {sq_mean:.6f}')
print(f'  Difference = {naive_var:.6f}')  # garbage!
print(f'  Relative error: {abs(naive_var - true_var) / true_var:.2%}')
print(f'\nStable var = E[(x - μ)²]:')
print(f'  Result = {stable_var:.10f}')
print(f'  Relative error: {abs(stable_var - true_var) / true_var:.2%}')
print(f'\n→ The naive formula subtracts two nearly-equal large numbers')
print(f'→ Most significant bits cancel, leaving only rounding noise')
print(f'→ This is exactly what happens in LayerNorm if not implemented carefully')

Code cell 26

# ══════════════════════════════════════════════════════════════════
# 6.2 Kahan Summation Algorithm
# ══════════════════════════════════════════════════════════════════

def naive_sum_fp32(values: np.ndarray) -> float:
    """Simple left-to-right summation in FP32."""
    total = np.float32(0.0)
    for v in values:
        total = np.float32(total + np.float32(v))
    return float(total)

def kahan_sum(values: np.ndarray) -> float:
    """Kahan compensated summation — tracks and corrects rounding error."""
    total = np.float32(0.0)
    comp = np.float32(0.0)  # running compensation for lost bits
    for v in values:
        v = np.float32(v)
        y = np.float32(v - comp)          # add back last error
        temp = np.float32(total + y)       # large + small → error here
        comp = np.float32(np.float32(temp - total) - y)  # capture what was lost
        total = temp
    return float(total)

def pairwise_sum(values: np.ndarray) -> float:
    """Pairwise summation (what NumPy uses internally)."""
    if len(values) <= 2:
        return float(np.float32(np.sum(values.astype(np.float32))))
    mid = len(values) // 2
    return float(np.float32(pairwise_sum(values[:mid]) + pairwise_sum(values[mid:])))

# Test: sum 1,000,000 small values
n = 1_000_000
vals = np.full(n, 1e-5, dtype=np.float32)
true_sum = n * 1e-5  # exact: 10.0

r_naive = naive_sum_fp32(vals)
r_kahan = kahan_sum(vals)
r_numpy = float(np.sum(vals))

print(f'Summation of {n:,} × 1e-5 (expected: {true_sum})')
print('=' * 55)
print(f'{"Method":<20} {"Result":>12} {"Error":>12} {"Rel Error":>12}')
print('-' * 58)
for name, result in [('Naive FP32', r_naive), ('Kahan (compensated)', r_kahan), ('NumPy (pairwise)', r_numpy)]:
    err = abs(result - true_sum)
    rel = err / true_sum
    print(f'{name:<20} {result:>12.6f} {err:>12.6f} {rel:>12.2e}')

print(f'\n→ Kahan summation reduces error by {abs(r_naive - true_sum) / max(abs(r_kahan - true_sum), 1e-15):.0f}×')
print(f'→ This is essential for gradient accumulation across large batches')

Code cell 27

# ══════════════════════════════════════════════════════════════════
# 6.3 BF16 Dot Product Accumulation Error
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    # Demonstrate why FP32 accumulation is critical for BF16 matmul
    d = 4096  # typical transformer hidden dimension
    torch.manual_seed(42)

    a = torch.randn(d, dtype=torch.float32)
    b = torch.randn(d, dtype=torch.float32)

    # Ground truth in FP32
    dot_fp32 = torch.dot(a, b)

    # BF16 with BF16 accumulation (BAD)
    a_bf16 = a.bfloat16()
    b_bf16 = b.bfloat16()
    # Simulate BF16 accumulation (Python loop)
    acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
    for i in range(min(d, 512)):  # first 512 for speed
        acc_bf16 += a_bf16[i] * b_bf16[i]
    # Scale up for full vector
    dot_bf16_acc = float(acc_bf16) * (d / 512)

    # BF16 with FP32 accumulation (GOOD — what tensor cores do)
    acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
    for i in range(min(d, 512)):
        product = float(a_bf16[i]) * float(b_bf16[i])  # BF16 multiply
        acc_fp32 += product  # FP32 accumulate
    dot_bf16_fp32_acc = float(acc_fp32) * (d / 512)

    print('Dot Product Accumulation Precision')
    print('=' * 55)
    print(f'Vector dimension: {d}')
    print(f'FP32 dot product (truth):     {float(dot_fp32):.4f}')
    print(f'BF16 with BF16 accumulation:  {dot_bf16_acc:.4f}  '
          f'(error: {abs(dot_bf16_acc - float(dot_fp32)) / abs(float(dot_fp32)):.1%})')
    print(f'BF16 with FP32 accumulation:  {dot_bf16_fp32_acc:.4f}  '
          f'(error: {abs(dot_bf16_fp32_acc - float(dot_fp32)) / abs(float(dot_fp32)):.1%})')
    print(f'\n→ BF16 accumulation has ~{abs(dot_bf16_acc - float(dot_fp32)) / max(abs(dot_bf16_fp32_acc - float(dot_fp32)), 1e-10):.0f}× more error than FP32 accumulation')
    print(f'→ This is why tensor cores always accumulate in FP32')
else:
    print('PyTorch required for BF16 dot product demo')

7. Numerical Stability in Neural Networks

The most common sources of training crashes in LLMs are:

  1. Softmax overflow — attention logits too large
  2. Log-sum-exp overflow — cross-entropy loss computation
  3. LayerNorm cancellation — subtracting mean from nearly-identical values
  4. Gradient vanishing — underflow in FP16 during backward pass

Code cell 29

# ══════════════════════════════════════════════════════════════════
# 7.1 Numerically Stable Softmax — The Max-Subtraction Trick
# ══════════════════════════════════════════════════════════════════

def naive_softmax(z: np.ndarray) -> np.ndarray:
    """Naive softmax — overflows for large logits."""
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

def stable_softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax — used in all production code."""
    m = np.max(z)
    exp_z = np.exp(z - m)  # max exponent is e^0 = 1
    return exp_z / np.sum(exp_z)

# Normal case — both work
z_normal = np.array([1.0, 2.0, 3.0, 4.0])
print('Softmax: Naive vs Stable')
print('=' * 60)
print(f'Normal logits {z_normal}:')
print(f'  Naive:  {naive_softmax(z_normal)}')
print(f'  Stable: {stable_softmax(z_normal)}')

# Large logits — naive breaks!
z_large = np.array([88.5, 88.7, 88.3, 88.6], dtype=np.float32)
print(f'\nLarge logits {z_large} (exp() of these sums past FP32 max):')
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    naive_result = naive_softmax(z_large)
broken = not np.allclose(np.sum(naive_result), 1.0)
print(f'  Naive:  {naive_result}  {"← broken! the sum of exps overflowed FP32" if broken else ""}')
print(f'  Stable: {stable_softmax(z_large)}')

# Extreme logits — definitely breaks
z_extreme = np.array([1000.0, 1001.0, 999.0])
print(f'\nExtreme logits {z_extreme}:')
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    naive_result = naive_softmax(z_extreme)
print(f'  Naive:  {naive_result}  ← NaN from inf/inf!')
print(f'  Stable: {stable_softmax(z_extreme)}')

# Mathematical proof
print(f'\n--- Mathematical equivalence proof ---')
print(f'  softmax(z-m)_i = exp(z_i - m) / Σ exp(z_j - m)')
print(f'                 = exp(z_i)·e^(-m) / [e^(-m) · Σ exp(z_j)]')
print(f'                 = exp(z_i) / Σ exp(z_j)  ← e^(-m) cancels!')
print(f'  The trick is mathematically exact, numerically essential.')

Code cell 30

# ══════════════════════════════════════════════════════════════════
# 7.2 Log-Sum-Exp Trick — Essential for Cross-Entropy Loss
# ══════════════════════════════════════════════════════════════════

def naive_logsumexp(z: np.ndarray) -> float:
    """Naive: log(sum(exp(z))) — overflows."""
    return float(np.log(np.sum(np.exp(z))))

def stable_logsumexp(z: np.ndarray) -> float:
    """Stable: m + log(sum(exp(z - m))) — no overflow."""
    m = np.max(z)
    return float(m + np.log(np.sum(np.exp(z - m))))

# Normal case
z = np.array([1.0, 2.0, 3.0])
print('Log-Sum-Exp Trick')
print('=' * 55)
print(f'z = {z}')
print(f'  Naive:  {naive_logsumexp(z):.6f}')
print(f'  Stable: {stable_logsumexp(z):.6f}')

# Overflow case
z_big = np.array([1000.0, 1001.0, 1002.0])
print(f'\nz = {z_big}')
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    print(f'  Naive:  {naive_logsumexp(z_big)}  ← overflow!')
print(f'  Stable: {stable_logsumexp(z_big):.6f}')

# Cross-entropy loss uses log-sum-exp
print(f'\n--- Cross-Entropy Loss = LSE(z) - z_target ---')
vocab_size = 50000
np.random.seed(42)
logits = np.random.randn(vocab_size).astype(np.float32) * 5  # logits with std=5
target = 42  # target token

lse = stable_logsumexp(logits)
loss = lse - logits[target]
print(f'  Vocab size: {vocab_size:,}')
print(f'  LSE(logits): {lse:.4f}')
print(f'  logits[target]: {logits[target]:.4f}')
print(f'  Cross-entropy loss: {loss:.4f}')
print(f'  → This computation happens millions of times per training step')
print(f'  → NEVER compute log(sum(exp(z))) directly!')

Code cell 31

# ══════════════════════════════════════════════════════════════════
# 7.3 LayerNorm vs RMSNorm — Numerical Comparison
# ══════════════════════════════════════════════════════════════════

def layernorm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Standard LayerNorm: y = (x - μ) / √(σ² + ε) · γ + β"""
    mu = np.mean(x)
    var = np.mean((x - mu)**2)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """RMSNorm: y = x / √(mean(x²) + ε) · γ  (no mean subtraction, no beta)"""
    rms = np.sqrt(np.mean(x**2) + eps)
    return gamma * x / rms

d = 128
gamma = np.ones(d, dtype=np.float32)
beta = np.zeros(d, dtype=np.float32)

print('LayerNorm vs RMSNorm — Numerical Properties')
print('=' * 65)

# Test 1: Normal input — both work fine
np.random.seed(42)
x_normal = np.random.randn(d).astype(np.float32)
ln_out = layernorm(x_normal, gamma, beta)
rms_out = rmsnorm(x_normal, gamma)
print(f'Normal input (μ≈0, σ≈1):')
print(f'  LayerNorm output mean: {np.mean(ln_out):.6f}, std: {np.std(ln_out):.6f}')
print(f'  RMSNorm output mean:   {np.mean(rms_out):.6f}, std: {np.std(rms_out):.6f}')

# Test 2: Large-offset input — catastrophic cancellation risk
x_offset = x_normal + 1e6  # huge mean, small variance

# FP32: both OK
ln_offset = layernorm(x_offset, gamma, beta)
rms_offset = rmsnorm(x_offset, gamma)
print(f'\nLarge-offset input (μ≈10⁶, σ≈1) in FP32:')
print(f'  LayerNorm mean: {np.mean(ln_offset):.6f}, std: {np.std(ln_offset):.6f}')
print(f'  RMSNorm mean:   {np.mean(rms_offset):.8f}, std: {np.std(rms_offset):.8f}')

# Test 3: Simulate BF16 precision for mean computation
# BF16 has ~7 mantissa bits → relative precision 2^(-7) ≈ 0.78%
def simulate_bf16_mean(x):
    """Simulate BF16 precision loss in mean computation."""
    # Each value rounded to ~2 decimal digits of precision
    precision = 2**(-7)  # BF16 mantissa precision
    x_bf16 = np.round(x / (np.abs(x) * precision + 1e-38)) * (np.abs(x) * precision + 1e-38)
    return np.mean(x_bf16)

print(f'\nWhy RMSNorm is numerically superior:')
print(f'  LayerNorm computes x - μ → catastrophic cancellation when x ≈ μ')
print(f'  RMSNorm computes x² → sum of positive values → no cancellation')
print(f'  → RMSNorm is used in LLaMA, Mistral, Gemma, and most 2024+ architectures')

Code cell 32

# ══════════════════════════════════════════════════════════════════
# 7.4 Gradient Vanishing/Exploding — Numerical View
# ══════════════════════════════════════════════════════════════════

print('Gradient Flow Through L Layers')
print('=' * 65)
print(f'If each layer\'s Jacobian has dominant eigenvalue λ:')
print(f'  Gradient magnitude ∝ λ^L\n')

print(f'{"λ":>6} {"L=10":>12} {"L=50":>12} {"L=100":>12} {"L=200":>12} {"Status":>20}')
print('-' * 80)

for lam in [1.1, 1.01, 1.001, 1.0, 0.999, 0.99, 0.9]:
    vals = [lam**L for L in [10, 50, 100, 200]]
    if vals[-1] > 1e38:
        status = '💥 EXPLODES (overflow)'
    elif vals[-1] < 1e-10:
        status = '💀 VANISHES (underflow)'
    elif vals[-1] < 0.01:
        status = '⚠ Very small'
    else:
        status = '✓ Stable'
    print(f'{lam:>6.3f} {vals[0]:>12.4e} {vals[1]:>12.4e} {vals[2]:>12.4e} {vals[3]:>12.4e} {status:>20}')

print(f'\n--- Residual Connections Fix This ---')
print(f'  Without residual: x_l = f(x_{{l-1}})       → Jacobian eigenvalue = λ_f')
print(f'  With residual:    x_l = x_{{l-1}} + f(x_{{l-1}}) → Jacobian eigenvalue = 1 + λ_f')
print(f'  Even if λ_f is small, 1 + λ_f ≈ 1 → gradient flows stably')

# FP16 underflow demonstration
print(f'\n--- FP16 Underflow Zone ---')
fp16_min = np.finfo(np.float16).tiny  # smallest positive normal
print(f'  FP16 smallest positive normal: {fp16_min:.2e}')
print(f'  BF16 smallest positive normal: ~1.2e-38')
print(f'  Gradient of 1e-5 in FP16: {np.float16(1e-5)} → {"OK (but subnormal, low precision)" if np.float16(1e-5) != 0 else "LOST!"}')
print(f'  Gradient of 1e-8 in FP16: {np.float16(1e-8)} → {"OK" if np.float16(1e-8) != 0 else "LOST!"}')
print(f'  → FP16 loses precision below ~6e-5 (subnormal range) and silently zeroes gradients below ~6e-8')
print(f'  → training stalls without any error message — this is why FP16 training needs loss scaling')

8. Quantization Mathematics

Beyond the mechanics of quantization (§4), this section covers the mathematical theory:

  • Signal-to-quantization-noise ratio (SQNR)
  • Lloyd-Max optimal quantization
  • Hadamard transform for outlier suppression
  • Error propagation through multiple layers

Code cell 34

# ══════════════════════════════════════════════════════════════════
# 8.1 Signal-to-Quantization-Noise Ratio (SQNR)
# ══════════════════════════════════════════════════════════════════

def compute_sqnr(x: np.ndarray, x_hat: np.ndarray) -> float:
    """Compute SQNR in dB."""
    signal_power = np.mean(x**2)
    noise_power = np.mean((x - x_hat)**2)
    if noise_power == 0:
        return float('inf')
    return 10 * np.log10(signal_power / noise_power)

print('SQNR vs Bit Width — The 6 dB/bit Rule')
print('=' * 60)

np.random.seed(42)
# Uniformly distributed signal (theoretical case)
x_uniform = np.random.uniform(-1, 1, 100000).astype(np.float32)
# Normally distributed signal (realistic weights)
x_normal = np.random.randn(100000).astype(np.float32)

print(f'{"Bits":>5} {"Theory (dB)":>12} {"Uniform (dB)":>13} {"Normal (dB)":>12} {"Noise ratio":>12}')
print('-' * 60)

for bits in [1, 2, 3, 4, 6, 8, 16]:
    theory = 6.02 * bits

    # Quantize uniform signal
    q_u, s_u = quantize_symmetric(x_uniform, bits=min(bits, 8))
    if bits > 8:
        # For >8 bits, simulate with scaled INT16
        alpha = np.max(np.abs(x_uniform))
        qmax = 2**(bits-1) - 1
        scale = alpha / qmax
        q_u = np.clip(np.round(x_uniform / scale), -qmax-1, qmax)
        r_u = q_u * scale
    else:
        r_u = dequantize(q_u, s_u)
    sqnr_u = compute_sqnr(x_uniform, r_u)

    # Quantize normal signal
    if bits <= 8:
        q_n, s_n = quantize_symmetric(x_normal, bits=bits)
        r_n = dequantize(q_n, s_n)
    else:
        alpha = np.max(np.abs(x_normal))
        qmax = 2**(bits-1) - 1
        scale = alpha / qmax
        q_n = np.clip(np.round(x_normal / scale), -qmax-1, qmax)
        r_n = q_n * scale
    sqnr_n = compute_sqnr(x_normal, r_n)

    noise_ratio = 10**(-sqnr_u / 10) * 100
    print(f'{bits:>5} {theory:>12.1f} {sqnr_u:>13.1f} {sqnr_n:>12.1f} {noise_ratio:>11.4f}%')

print(f'\n→ Each additional bit adds ~6 dB of SQNR (= 4× less noise power)')
print(f'→ Normal distribution SQNR is slightly worse because outliers waste quantisation levels')

Code cell 35

# ══════════════════════════════════════════════════════════════════
# 8.2 Lloyd-Max Optimal Quantisation
# ══════════════════════════════════════════════════════════════════

def lloyd_max(data: np.ndarray, n_levels: int, max_iter: int = 100) -> Tuple[np.ndarray, np.ndarray]:
    """Lloyd-Max algorithm: find optimal quantisation levels for given data.
    Returns (levels, boundaries)."""
    # Initialise levels uniformly
    levels = np.linspace(np.min(data), np.max(data), n_levels)

    for iteration in range(max_iter):
        # Step 1: boundaries = midpoints between adjacent levels
        boundaries = (levels[:-1] + levels[1:]) / 2

        # Step 2: levels = centroids of data in each bin
        new_levels = np.zeros_like(levels)
        all_boundaries = np.concatenate([[-np.inf], boundaries, [np.inf]])
        for i in range(n_levels):
            mask = (data >= all_boundaries[i]) & (data < all_boundaries[i+1])
            if np.sum(mask) > 0:
                new_levels[i] = np.mean(data[mask])
            else:
                new_levels[i] = levels[i]

        if np.allclose(levels, new_levels, atol=1e-8):
            break
        levels = new_levels

    return levels, boundaries

# Optimal quantisation for normal distribution
np.random.seed(42)
data = np.random.randn(100000).astype(np.float64)

print('Lloyd-Max Optimal Quantisation for N(0,1)')
print('=' * 65)

for bits in [2, 3, 4]:
    n_levels = 2**bits
    levels, boundaries = lloyd_max(data, n_levels)

    # Quantise and compute MSE
    indices = np.digitize(data, boundaries)
    quantized = levels[indices]
    mse_lm = np.mean((data - quantized)**2)

    # Compare with uniform quantisation
    alpha = np.max(np.abs(data))
    uniform_levels = np.linspace(-alpha, alpha, n_levels)
    uniform_boundaries = (uniform_levels[:-1] + uniform_levels[1:]) / 2
    u_indices = np.digitize(data, uniform_boundaries)
    u_quantized = uniform_levels[u_indices]
    mse_uniform = np.mean((data - u_quantized)**2)

    print(f'\n{bits}-bit ({n_levels} levels):')
    print(f'  Lloyd-Max levels: {np.array2string(levels, precision=3, separator=", ")}')
    print(f'  Lloyd-Max MSE:  {mse_lm:.6f}')
    print(f'  Uniform MSE:    {mse_uniform:.6f}')
    print(f'  Improvement:    {mse_uniform/mse_lm:.2f}× better')

print(f'\n→ At 4-bit, Lloyd-Max levels closely match the NF4 levels used in QLoRA!')
print(f'→ Non-uniform quantisation significantly outperforms uniform for bell-shaped distributions')

Code cell 36

# ══════════════════════════════════════════════════════════════════
# 8.3 Hadamard Transform for Outlier Suppression (QuIP/QuaRot)
# ══════════════════════════════════════════════════════════════════

def hadamard_matrix(n: int) -> np.ndarray:
    """Construct normalised Hadamard matrix of size n (n must be power of 2)."""
    if n == 1:
        return np.array([[1.0]])
    H_half = hadamard_matrix(n // 2)
    H = np.block([[H_half, H_half], [H_half, -H_half]]) / np.sqrt(2)
    return H

# Demonstrate outlier suppression
np.random.seed(42)
d = 64  # dimension

# Create a weight vector with outliers (realistic: some channels are 10-100× larger)
w = np.random.randn(d).astype(np.float64) * 0.1
w[0] = 10.0   # massive outlier
w[1] = -8.0   # another outlier
w[2] = 5.0    # moderate outlier

H = hadamard_matrix(d)

# Rotate weights
w_rotated = H @ w

print('Hadamard Transform for Outlier Suppression')
print('=' * 65)
print(f'Original weights:')
print(f'  Max |w|: {np.max(np.abs(w)):.4f}')
print(f'  Min |w|: {np.min(np.abs(w)):.4f}')
print(f'  Max/Min ratio: {np.max(np.abs(w)) / np.min(np.abs(w[w != 0])):.1f}×')
print(f'  Std of |w|: {np.std(np.abs(w)):.4f}')

print(f'\nAfter Hadamard rotation (w\' = Hw):')
print(f'  Max |w\'|: {np.max(np.abs(w_rotated)):.4f}')
print(f'  Min |w\'|: {np.min(np.abs(w_rotated)):.4f}')
print(f'  Max/Min ratio: {np.max(np.abs(w_rotated)) / np.min(np.abs(w_rotated[w_rotated != 0])):.1f}×')
print(f'  Std of |w\'|: {np.std(np.abs(w_rotated)):.4f}')

# Verify orthogonality: H @ H^T = I
identity_check = H @ H.T
print(f'\nOrthogonality check: ||HH^T - I|| = {np.linalg.norm(identity_check - np.eye(d)):.2e} (should be ~0)')

# Quantise both and compare
q_orig, s_orig = quantize_int4_symmetric(w.astype(np.float32))
r_orig = dequantize(q_orig, s_orig)
mse_orig = np.mean((w.astype(np.float32) - r_orig)**2)

q_rot, s_rot = quantize_int4_symmetric(w_rotated.astype(np.float32))
r_rot = dequantize(q_rot, s_rot)
w_recon = (H.T @ r_rot).astype(np.float32)  # rotate back
mse_hadamard = np.mean((w.astype(np.float32) - w_recon)**2)

print(f'\nINT4 quantisation comparison:')
print(f'  Direct quantisation MSE:   {mse_orig:.6f}')
print(f'  Hadamard + quantise MSE:   {mse_hadamard:.6f}')
print(f'  Improvement: {mse_orig / mse_hadamard:.1f}×')
print(f'\n→ Hadamard "spreads" outliers across all dimensions → more uniform range → better quantisation')

Code cell 37

# ══════════════════════════════════════════════════════════════════
# 8.4 Error Propagation Through Layers
# ══════════════════════════════════════════════════════════════════

np.random.seed(42)
d = 256
n_layers = 12

# Create a simple multi-layer network: y = W_L ... W_2 W_1 x
# Each weight matrix is quantised to INT8
x = np.random.randn(d).astype(np.float32)
x = x / np.linalg.norm(x)  # normalise input

# Create weight matrices (orthogonal for stability)
weights = []
for _ in range(n_layers):
    W = np.random.randn(d, d).astype(np.float32)
    U, _, Vt = np.linalg.svd(W, full_matrices=False)
    weights.append(U @ Vt)  # orthogonal matrix (preserves norms)

# Forward pass: exact vs quantised for different bit widths
print('Error Propagation Through Quantised Layers')
print('=' * 70)
print(f'{n_layers} layers, dimension {d}')
print(f'\n{"Bits":>5} {"After L=1":>12} {"After L=4":>12} {"After L=8":>12} {"After L=12":>13}')
print('-' * 60)

for bits in [2, 4, 8, 16]:
    # Exact forward pass
    y_exact = x.copy()
    for W in weights:
        y_exact = W @ y_exact

    # Quantised forward pass with error tracking
    y_quant = x.copy()
    errors = []
    y_ex = x.copy()
    for i, W in enumerate(weights):
        # Quantise weight matrix
        if bits <= 8:
            W_flat_q, W_s = quantize_symmetric(W.flatten(), bits=bits)
            W_deq = dequantize(W_flat_q, W_s).reshape(W.shape)
        else:
            alpha = np.max(np.abs(W))
            qmax = 2**(bits-1) - 1
            scale = alpha / qmax
            W_q = np.clip(np.round(W / scale), -qmax-1, qmax)
            W_deq = W_q * scale

        y_quant = W_deq @ y_quant
        y_ex = W @ y_ex

        if (i+1) in [1, 4, 8, 12]:
            rel_error = np.linalg.norm(y_quant - y_ex) / (np.linalg.norm(y_ex) + 1e-10)
            errors.append(rel_error)

    error_strs = [f'{e:>12.4e}' for e in errors]
    print(f'{bits:>5} {" ".join(error_strs)}')

print(f'\n→ Error grows roughly linearly with number of layers (first-order approximation)')
print(f'→ INT4 after 12 layers: significant error → keep first/last layers in higher precision')
print(f'→ INT8 error stays manageable even through 12 layers')

9. Mixed Precision Training — Complete Pipeline

The standard recipe for all large-scale LLM training since 2020:

  • FP32 master weights (never lost, never quantised)
  • BF16 forward/backward (fast, good range)
  • FP32 optimizer states (Adam m and v need high precision)
  • FP32 loss (cross-entropy needs log/exp precision)
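
In practice this recipe is what PyTorch's autocast machinery automates. A minimal added sketch of the standard pattern (assumes a CUDA device; the model, data, and hyperparameters are placeholders, not taken from this notebook):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(1024, 1024).cuda()          # parameters stay FP32 (the master copy)
optimizer = torch.optim.AdamW(model.parameters())   # so Adam's m and v stay FP32 too

x = torch.randn(32, 1024, device='cuda')
y = torch.randn(32, 1024, device='cuda')

for step in range(10):
    optimizer.zero_grad(set_to_none=True)
    # Matmuls inside this region run in BF16; numerically sensitive ops stay in FP32
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss = F.mse_loss(model(x), y)
    loss.backward()    # gradients come back in FP32 (they follow the parameter dtype)
    optimizer.step()   # FP32 update; no loss scaling needed with BF16
    # (Swap in FP16 and a torch.cuda.amp.GradScaler if loss scaling is required.)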

Code cell 39

# ══════════════════════════════════════════════════════════════════
# 9.1 Simulating Mixed Precision Training
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    def simulate_training(model, dtype_forward, use_fp32_master=True, steps=200, lr=0.01):
        """Simulate training with different precision configurations."""
        torch.manual_seed(42)
        # Simple target function: y = sin(x)
        x_train = torch.linspace(-3, 3, 100).unsqueeze(1)
        y_train = torch.sin(x_train)

        if use_fp32_master:
            master_params = [p.clone().float() for p in model.parameters()]
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)

        losses = []
        for step in range(steps):
            if use_fp32_master:
                # Copy FP32 master → working precision
                for p, mp in zip(model.parameters(), master_params):
                    p.data = mp.data.to(dtype_forward)

            # Forward pass in specified precision
            x_in = x_train.to(dtype_forward)
            y_pred = model(x_in)
            loss = F.mse_loss(y_pred, y_train.to(dtype_forward))
            losses.append(loss.float().item())

            # Backward
            optimizer.zero_grad()
            loss.backward()

            if use_fp32_master:
                # Update FP32 master weights with FP32 gradients
                for mp, p in zip(master_params, model.parameters()):
                    mp.data -= lr * p.grad.float()
            else:
                optimizer.step()

        return losses

    # Compare different precision configurations
    configs = [
        ('FP32 (baseline)', torch.float32, True),
        ('BF16 + FP32 master', torch.bfloat16, True),
        ('BF16 without master', torch.bfloat16, False),
        ('FP16 + FP32 master', torch.float16, True),
    ]

    print('Mixed Precision Training Comparison')
    print('=' * 70)
    print(f'{"Config":<25} {"Final Loss":>12} {"Min Loss":>12} {"Converged?":>12}')
    print('-' * 65)

    for name, dtype, use_master in configs:
        model = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
        losses = simulate_training(model, dtype, use_master, steps=200, lr=0.005)
        final = losses[-1]
        min_loss = min(losses)
        converged = final < 0.1
        print(f'{name:<25} {final:>12.6f} {min_loss:>12.6f} {"✓" if converged else "✗":>12}')

    print(f'\n→ BF16 + FP32 master matches FP32 baseline')
    print(f'→ BF16 without FP32 master may diverge or stall on longer training')
    print(f'→ FP16 + FP32 master also works (with loss scaling for larger models)')
else:
    print('PyTorch required for mixed precision training demo')

Code cell 40

# ══════════════════════════════════════════════════════════════════
# 9.2 Memory Budget Calculator
# ══════════════════════════════════════════════════════════════════

def memory_budget(n_params: float, mode: str = 'train_bf16') -> dict:
    """Calculate memory requirements for different configurations."""
    configs = {
        'train_fp32': {
            'weights': 4,       # FP32
            'adam_m': 4,        # FP32
            'adam_v': 4,        # FP32
            'gradients': 4,    # FP32
            'activations': 4,  # FP32 (estimate: ~1× params)
        },
        'train_bf16': {
            'weights_master': 4,  # FP32 master
            'weights_working': 2, # BF16
            'adam_m': 4,          # FP32
            'adam_v': 4,          # FP32
            'gradients': 2,      # BF16
            'activations': 2,    # BF16
        },
        'qlora_nf4': {
            'base_weights': 0.5,  # NF4 (4 bits)
            'lora_weights': 0.01, # BF16 but << 1% of params
            'adam_m': 0.01,       # only for LoRA params
            'adam_v': 0.01,       # only for LoRA params
            'gradients': 0.01,
            'activations': 2,    # BF16
        },
        'inference_bf16': {'weights': 2},
        'inference_int8': {'weights': 1},
        'inference_int4': {'weights': 0.5},
    }

    config = configs[mode]
    total_bytes = sum(v * n_params for v in config.values())
    return {
        'components': {k: v * n_params / 1e9 for k, v in config.items()},
        'total_gb': total_bytes / 1e9
    }

# Calculate for common model sizes
print('Memory Budget Calculator')
print('=' * 80)

for model_name, n_params in [('7B', 7e9), ('13B', 13e9), ('70B', 70e9), ('405B', 405e9)]:
    print(f'\n--- {model_name} Parameters ---')
    print(f'{"Mode":<22} {"Total (GB)":>10} {"Fits in 80GB?":>14} {"GPUs needed":>12}')
    print('-' * 62)
    for mode in ['train_fp32', 'train_bf16', 'qlora_nf4', 'inference_bf16', 'inference_int8', 'inference_int4']:
        budget = memory_budget(n_params, mode)
        total = budget['total_gb']
        fits = total <= 80
        gpus = max(1, int(np.ceil(total / 80)))
        print(f'{mode:<22} {total:>10.1f} {"✓" if fits else "✗":>14} {gpus:>12}')

print(f'\n→ QLoRA brings 70B fine-tuning within reach of a single 80GB GPU (the activation estimate')
print(f'  above is pessimistic — gradient checkpointing shrinks it dramatically)')
print(f'→ INT4 inference enables 70B serving on a single GPU')
print(f'→ Full FP32 training of 70B needs ~1.1 TB for weights, gradients, and optimizer state alone (14+ GPUs)')

10. Hardware & Memory Analysis

For autoregressive LLM inference at small batch sizes, memory bandwidth (not compute) is the bottleneck. Smaller number formats reduce the bytes that must stream from HBM for every token → a direct, nearly linear speedup.

Code cell 42

# ══════════════════════════════════════════════════════════════════
# 10.1 Arithmetic Intensity Analysis
# ══════════════════════════════════════════════════════════════════

print('Arithmetic Intensity — Why Quantisation Speeds Up Inference')
print('=' * 70)

# For autoregressive generation (batch_size=1), each token requires:
# - Loading full weight matrix: d² × bytes_per_weight
# - Computing matrix-vector multiply: 2 × d² FLOPs
# Arithmetic intensity = FLOPs / bytes = 2 / bytes_per_weight

# H100 SXM5 specs (approximate, dense i.e. without 2:4 sparsity)
h100_bandwidth = 3.35e12    # bytes/s (HBM3)
h100_fp16_flops = 989.5e12  # FLOP/s (FP16 tensor core, ≈ 989.5 TFLOPS)
h100_int8_ops = 1979e12     # OP/s (INT8 tensor core, ≈ 1979 TOPS)

# Compute roofline
print(f'H100 SXM5 Specifications:')
print(f'  HBM3 Bandwidth: {h100_bandwidth/1e12:.2f} TB/s')
print(f'  FP16 Compute:   {h100_fp16_flops/1e12:.1f} TFLOPS')
print(f'  INT8 Compute:   {h100_int8_ops/1e12:.0f} TOPS')

ridge_point = h100_fp16_flops / h100_bandwidth
print(f'  Ridge point: {ridge_point:.1f} FLOPs/byte')

print(f'\n{"Format":>10} {"Bytes":>6} {"Arith Intensity":>16} {"Bound":>10} {"Theoretical speedup":>20}')
print('-' * 68)
for name, bpw in [('FP32', 4), ('BF16', 2), ('INT8', 1), ('INT4', 0.5), ('INT2', 0.25)]:
    ai = 2 / bpw
    bound = 'Compute' if ai >= ridge_point else 'Memory'
    # Speedup relative to FP32 (memory-bound)
    speedup = 4 / bpw  # linear with compression ratio when memory-bound
    print(f'{name:>10} {bpw:>6.1f} {ai:>16.1f} {bound:>10} {speedup:>18.0f}×')

print(f'\n→ For single-token generation, ALL formats are memory-bound on H100')
print(f'→ This means reducing weight precision from BF16→INT4 gives ~4× real speedup')
print(f'→ The speedup comes from less data to load, not faster compute')

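As a back-of-envelope companion to the table above: when decoding is memory-bound, the per-token latency floor is simply weight bytes divided by bandwidth. This is my own rough estimate (it reuses the H100 bandwidth figure from the cell above, ignores KV-cache reads, and assumes the whole model sits in one GPU's HBM).

h100_bandwidth = 3.35e12   # bytes/s, same figure as above
n_params = 70e9
print(f'{"Format":>8} {"Weights (GB)":>13} {"Latency floor":>16} {"Max tok/s":>11}')
for name, bpw in [('FP32', 4), ('BF16', 2), ('INT8', 1), ('INT4', 0.5)]:
    weight_bytes = n_params * bpw
    latency_s = weight_bytes / h100_bandwidth   # every weight must be read once per generated token
    print(f'{name:>8} {weight_bytes/1e9:>13.0f} {latency_s*1e3:>13.1f} ms {1/latency_s:>11.1f}')
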
Code cell 43

# ══════════════════════════════════════════════════════════════════
# 10.2 Energy Cost Analysis
# ══════════════════════════════════════════════════════════════════

# Approximate energy per operation (7nm process, picojoules)
energy_pj = {
    'INT1 XNOR': 0.02,
    'INT4 MAC': 0.05,
    'INT8 MAC': 0.2,
    'FP8 FMA': 0.4,
    'BF16 FMA': 0.8,
    'FP16 FMA': 1.0,
    'FP32 FMA': 3.7,
    'DRAM read (64B)': 12.5,
}

print('Energy Cost per Operation (7nm process)')
print('=' * 55)
print(f'{"Operation":<18} {"Energy (pJ)":>12} {"Relative to INT8":>18}')
print('-' * 50)
ref = energy_pj['INT8 MAC']
for op, pj in energy_pj.items():
    print(f'{op:<18} {pj:>12.2f} {pj/ref:>18.1f}×')

print(f'\n⚡ Key insight: DRAM access costs {energy_pj["DRAM read (64B)"]/energy_pj["INT8 MAC"]:.0f}× more '
      f'energy than an INT8 MAC!')
print(f'→ Most inference energy is spent moving data, not computing')
print(f'→ Smaller formats reduce BOTH compute AND memory energy')

# Estimate energy for one forward pass of a 70B model
n_params = 70e9
n_ops = 2 * n_params  # ~2 FLOPs per parameter for a forward pass

print(f'\nEstimated energy for one 70B forward pass (single token):')
print(f'{"Format":<10} {"Compute (J)":>12} {"Memory (J)":>12} {"Total (J)":>12} {"Relative":>10}')
print('-' * 60)
fp32_total_j = None  # first row is FP32; use it as the baseline for the Relative column
for name, compute_pj, bpw in [
    ('FP32', 3.7, 4), ('BF16', 0.8, 2), ('INT8', 0.2, 1), ('INT4', 0.05, 0.5)]:
    compute_j = n_ops * compute_pj * 1e-12
    memory_bytes = n_params * bpw
    memory_j = (memory_bytes / 64) * energy_pj['DRAM read (64B)'] * 1e-12
    total_j = compute_j + memory_j
    if fp32_total_j is None:
        fp32_total_j = total_j
    print(f'{name:<10} {compute_j:>12.1f} {memory_j:>12.1f} {total_j:>12.1f} {total_j/fp32_total_j:>10.2f}×')

11. Training Stability & Precision

Precision failures manifest as training failures: loss spikes, NaN crashes, or silent divergence. Understanding the numerical mechanisms enables prevention.

Code cell 45

# ══════════════════════════════════════════════════════════════════
# 11.1 Stochastic Rounding — Unbiased Low-Precision Updates
# ══════════════════════════════════════════════════════════════════

def round_to_nearest(x: float, resolution: float) -> float:
    """Deterministic round-to-nearest."""
    return round(x / resolution) * resolution

def stochastic_round(x: float, resolution: float) -> float:
    """Stochastic rounding: E[SR(x)] = x."""
    lower = np.floor(x / resolution) * resolution
    upper = lower + resolution
    prob_up = (x - lower) / resolution
    return upper if np.random.random() < prob_up else lower

# Simulate gradient accumulation over many steps
resolution = 0.0078125  # BF16 ULP near 1.0 (= 2^(-7))
gradient = 0.001  # small gradient (< resolution/2)
n_steps = 10000

# Deterministic accumulation
weight_det = 1.0
for _ in range(n_steps):
    weight_det = round_to_nearest(weight_det + gradient, resolution)

# Stochastic accumulation
np.random.seed(42)
weight_stoch = 1.0
for _ in range(n_steps):
    weight_stoch = stochastic_round(weight_stoch + gradient, resolution)

expected = 1.0 + n_steps * gradient

print('Stochastic Rounding vs Deterministic (RNE)')
print('=' * 60)
print(f'BF16 resolution (ULP near 1.0): {resolution}')
print(f'Gradient per step: {gradient}')
print(f'Gradient / resolution = {gradient/resolution:.3f} (< 0.5 → always rounds to same value!)')
print(f'Number of steps: {n_steps:,}')
print(f'\nExpected final weight: {expected:.1f}')
print(f'Deterministic (RNE):   {weight_det:.4f}  (error: {abs(weight_det - expected):.4f})')
print(f'Stochastic rounding:   {weight_stoch:.4f}  (error: {abs(weight_stoch - expected):.4f})')
print(f'\n→ Deterministic: gradient is ALWAYS rounded to zero → weight NEVER updates!')
print(f'→ Stochastic: gradient contributes probabilistically → correct on average')
print(f'\n→ This is why FP8 training on H100 uses stochastic rounding — without it,')
print(f'  many gradient updates would be silently lost.')

# Statistical verification: E[SR(x)] = x
print(f'\n--- Bias verification ---')
np.random.seed(0)
x_test = 1.003  # between two BF16-representable values
sr_results = [stochastic_round(x_test, resolution) for _ in range(100000)]
print(f'  x = {x_test}')
print(f'  E[SR(x)] over 100K trials: {np.mean(sr_results):.6f}  (should be ≈ {x_test})')
print(f'  RNE(x) = {round_to_nearest(x_test, resolution):.6f}  (biased!)')

Code cell 46

# ══════════════════════════════════════════════════════════════════
# 11.2 Adam Optimizer Numerical Edge Cases
# ══════════════════════════════════════════════════════════════════

print('Adam Optimizer — Numerical Failure Modes')
print('=' * 65)

# Failure mode 1: epsilon too small
print('\n--- Failure Mode 1: ε too small ---')
eta = 1e-4  # learning rate
m_t = 1e-4  # momentum (typical small gradient)
for eps in [1e-8, 1e-6, 1e-4]:
    for v_t in [1e-8, 1e-4, 1.0]:
        update = eta * m_t / (np.sqrt(v_t) + eps)
        print(f'  ε={eps:.0e}, v_t={v_t:.0e} → update = {update:.6f}'
              f'{"  ⚠ HUGE!" if update > 0.01 else ""}')

print(f'\n  When v_t ≈ 0 and ε = 1e-8: update ≈ η·m_t/ε = {eta * m_t / 1e-8:.0f}')
print(f'  → This can cause catastrophic weight jumps!')
print(f'  → Use ε = 1e-6 or 1e-4 for BF16 training')

# Failure mode 2: bias correction amplification
print(f'\n--- Failure Mode 2: Bias correction early in training ---')
beta2 = 0.999
print(f'  Bias correction factor 1/(1 - β₂^t) for β₂ = {beta2}:')
for t in [1, 5, 10, 50, 100, 1000, 10000]:
    correction = 1 / (1 - beta2**t)
    print(f'    t={t:>5}: 1/(1-β₂^t) = {correction:>10.1f}×'
          f'{"  ← 1000× amplification!" if correction > 100 else ""}')

print(f'\n  → At t=1, v_t is amplified 1000×! This can cause overflow in low precision.')
print(f'  → Learning rate warmup helps by keeping η small during this high-amplification phase.')

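A compact sketch (toy numbers of my own choosing) ties the two failure modes together: if Adam's second moment v is stored in FP16, the very first raw estimate (1 - β₂)·g² underflows to zero, bias correction cannot rescue it, and the update collapses to η·m̂/ε, exactly the catastrophic case in the table above; a larger ε bounds the damage, and keeping v in FP32 avoids it entirely.

import numpy as np

g, lr, beta1, beta2, t = 1e-4, 1e-4, 0.9, 0.999, 1

for v_dtype, eps in [(np.float32, 1e-8), (np.float16, 1e-8), (np.float16, 1e-4)]:
    m = (1 - beta1) * g                     # first moment after one step
    v = v_dtype((1 - beta2) * g**2)         # second moment, stored in the chosen dtype
    m_hat = m / (1 - beta1**t)
    v_hat = float(v) / (1 - beta2**t)       # bias correction: 1000× amplification at t=1
    update = lr * m_hat / (np.sqrt(v_hat) + eps)
    print(f'v stored in {np.dtype(v_dtype).name}, eps={eps:.0e}: v={float(v):.1e}  first update={update:.2e}')
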
Code cell 47

# ══════════════════════════════════════════════════════════════════
# 11.3 Attention Logit Growth — The Slow Training Killer
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    d_k = 128  # head dimension

    print('Attention Logit Growth Simulation')
    print('=' * 65)
    print(f'Head dimension d_k = {d_k}')
    print(f'Standard scaling: QK^T / √d_k = QK^T / {np.sqrt(d_k):.1f}')

    # Simulate Q, K with increasing norms (as happens during training)
    print(f'\n{"||Q||":>8} {"Max logit":>10} {"Softmax max":>13} {"Grad magnitude":>15} {"Risk":>12}')
    print('-' * 65)

    for qk_scale in [1.0, 2.0, 5.0, 10.0, 15.0, 20.0, 25.0]:
        torch.manual_seed(42)
        seq_len = 64
        Q = torch.randn(1, seq_len, d_k) * qk_scale
        K = torch.randn(1, seq_len, d_k) * qk_scale

        # Attention logits
        logits = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
        max_logit = logits.max().item()

        # Softmax
        attn = F.softmax(logits, dim=-1)
        max_attn = attn.max().item()

        # Gradient magnitude (proxy: entropy of attention distribution)
        entropy = -(attn * torch.log(attn + 1e-10)).sum(-1).mean().item()
        grad_proxy = entropy / np.log(seq_len)  # normalised entropy

        # Risk assessment
        if max_logit > 88:
            risk = '💥 OVERFLOW'
        elif max_attn > 0.999:
            risk = '⚠ Near-1hot'
        elif max_attn > 0.99:
            risk = '⚠ Spiky'
        else:
            risk = '✓ Safe'

        print(f'{qk_scale:>8.1f} {max_logit:>10.1f} {max_attn:>13.6f} {grad_proxy:>15.6f} {risk:>12}')

    print(f'\n→ As Q, K norms grow during training, attention becomes "spiky"')
    print(f'→ Spiky attention → vanishing gradients → learning stops for that head')
    print(f'→ Eventually max logit > 88 → exp() overflow → NaN → training crash')
    print(f'\nMitigations:')
    print(f'  1. QK-Norm: normalise Q and K before computing attention')
    print(f'  2. Logit capping: clamp logits to [-50, 50] before softmax')
    print(f'  3. Gradient clipping: clip global norm to prevent explosive updates')
else:
    print('PyTorch required for attention logit demo')

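Of the mitigations listed in the cell above, QK-Norm is the one that changes the forward pass. Here is a minimal sketch of the idea (my own illustration; real implementations typically use a learned per-head scale instead of the fixed value below): L2-normalise Q and K along the head dimension so the logits are bounded by the chosen scale rather than by growing weight norms.

if HAS_TORCH:
    def qk_norm_logits(Q, K, scale=10.0):
        """QK-Norm: L2-normalise Q and K along the head dim, then apply a fixed (or learned) scale."""
        Qn = F.normalize(Q, dim=-1)                   # unit-norm queries
        Kn = F.normalize(K, dim=-1)                   # unit-norm keys
        return scale * Qn @ Kn.transpose(-2, -1)      # |logit| ≤ scale regardless of ||Q||, ||K||

    torch.manual_seed(42)
    Q = torch.randn(1, 64, 128) * 25.0                # the norm that caused trouble in the table above
    K = torch.randn(1, 64, 128) * 25.0
    logits = qk_norm_logits(Q, K)
    attn = F.softmax(logits, dim=-1)
    print(f'Max logit with QK-Norm: {logits.max().item():.2f}  (bounded by the scale)')
    print(f'Max attention weight:   {attn.max().item():.4f}  (no longer one-hot)')
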
12. Practical Guide & Real-World Applications

Code cell 49

# ══════════════════════════════════════════════════════════════════
# 12.1 Real PyTorch Quantisation Demo
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    print('PyTorch Quantisation — Production Code')
    print('=' * 55)

    # Build a small model
    model = nn.Sequential(
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )

    # Size before quantisation
    orig_size = sum(p.numel() * p.element_size() for p in model.parameters())

    # Dynamic quantisation (real PyTorch API)
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {nn.Linear},
        dtype=torch.qint8
    )

    # Compare outputs
    x = torch.randn(1, 512)
    with torch.no_grad():
        y_orig = model(x)
        y_quant = quantized_model(x)

    diff = (y_orig - y_quant).abs().mean().item()
    print(f'Original model: {orig_size:,} bytes')
    print(f'Output difference (FP32 vs INT8): {diff:.6f}')

    # Show quantised model structure
    print(f'\nQuantised model layers:')
    for name, module in quantized_model.named_modules():
        if name:
            print(f'  {name}: {module.__class__.__name__}')

    # Show production LLM quantisation code
    print(f'\n--- LLM Quantisation in Practice ---')
    print('''
    # Method 1: bitsandbytes (easiest)
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3-8b",
        load_in_4bit=True,                    # ← NF4 quantisation
        bnb_4bit_compute_dtype=torch.bfloat16 # ← BF16 compute
    )

    # Method 2: GPTQ (highest quality INT4)
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/Llama-3-8B-GPTQ",
        device_map="auto"
    )

    # Method 3: AWQ (fastest INT4 inference)
    from awq import AutoAWQForCausalLM
    model = AutoAWQForCausalLM.from_quantized(
        "TheBloke/Llama-3-8B-AWQ"
    )

    # Method 4: QLoRA fine-tuning
    from peft import prepare_model_for_kbit_training, LoraConfig
    model = prepare_model_for_kbit_training(model)
    lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
    ''')
else:
    print('PyTorch required for quantisation demo')

Code cell 50

# ══════════════════════════════════════════════════════════════════
# 12.2 KV Cache Memory Calculator
# ══════════════════════════════════════════════════════════════════

def kv_cache_memory(n_layers: int, n_kv_heads: int, d_head: int,
                    seq_len: int, batch_size: int, bytes_per_element: float) -> float:
    """Calculate KV cache memory in GB."""
    # 2 for K and V, × layers × heads × head_dim × seq_len × batch × bytes
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch_size * bytes_per_element / 1e9

print('KV Cache Memory Analysis')
print('=' * 70)

# Common LLM configurations
models = {
    'LLaMA-3 8B':  {'n_layers': 32, 'n_kv_heads': 8,  'd_head': 128},
    'LLaMA-3 70B': {'n_layers': 80, 'n_kv_heads': 8,  'd_head': 128},
    'GPT-4 (est)': {'n_layers': 120, 'n_kv_heads': 16, 'd_head': 128},
}

for model_name, config in models.items():
    print(f'\n--- {model_name} ---')
    print(f'  {config["n_layers"]}L × {config["n_kv_heads"]} KV heads × {config["d_head"]}d')
    print(f'  {"Seq Len":>10} {"BF16 (GB)":>10} {"INT8 (GB)":>10} {"INT4 (GB)":>10}')
    print(f'  {"":>10} {"-"*10} {"-"*10} {"-"*10}')
    for seq_len in [2048, 8192, 32768, 131072]:
        bf16 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=2)
        int8 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=1)
        int4 = kv_cache_memory(**config, seq_len=seq_len, batch_size=1, bytes_per_element=0.5)
        print(f'  {seq_len:>10,} {bf16:>10.2f} {int8:>10.2f} {int4:>10.2f}')

print(f'\n→ 70B model with 128K context: KV cache alone is ~43 GB in BF16!')
print(f'→ INT4 KV cache reduces this to ~11 GB — critical for long-context serving')
print(f'→ Per-channel INT8 KV quantisation is nearly lossless and should always be used')

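The last point above (per-channel INT8 KV quantisation being nearly lossless) is easy to sanity-check. A minimal NumPy sketch on synthetic data (my own illustration): quantise a toy key cache with one scale per head-dimension channel and measure the relative reconstruction error.

import numpy as np

rng = np.random.default_rng(0)
K_cache = rng.normal(size=(4096, 128)).astype(np.float32)   # (seq_len, d_head), toy key cache
K_cache[:, 7] *= 20.0                                        # one outlier channel, as often seen in practice

scale = np.abs(K_cache).max(axis=0, keepdims=True) / 127.0   # one scale per head-dim channel
K_int8 = np.clip(np.round(K_cache / scale), -127, 127).astype(np.int8)
K_deq = K_int8.astype(np.float32) * scale

rel_err = np.linalg.norm(K_cache - K_deq) / np.linalg.norm(K_cache)
print(f'Per-channel INT8 relative reconstruction error: {rel_err:.4%}')
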
Code cell 51

# ══════════════════════════════════════════════════════════════════
# 12.3 Per-Layer Quantisation Sensitivity Analysis
# ══════════════════════════════════════════════════════════════════

if HAS_TORCH:
    # Build a mini transformer-like model to test per-layer sensitivity
    torch.manual_seed(42)

    class MiniTransformerBlock(nn.Module):
        def __init__(self, d=256):
            super().__init__()
            self.q_proj = nn.Linear(d, d)
            self.k_proj = nn.Linear(d, d)
            self.v_proj = nn.Linear(d, d)
            self.out_proj = nn.Linear(d, d)
            self.ffn_gate = nn.Linear(d, d * 4)
            self.ffn_down = nn.Linear(d * 4, d)
            self.norm = nn.LayerNorm(d)

        def forward(self, x):
            # Simplified attention (no actual attention pattern)
            q = self.q_proj(x)
            k = self.k_proj(x)
            v = self.v_proj(x)
            attn_out = self.out_proj(v)  # simplified
            x = self.norm(x + attn_out)
            # FFN
            ffn = F.silu(self.ffn_gate(x))
            ffn = self.ffn_down(ffn)
            return self.norm(x + ffn)

    model = MiniTransformerBlock(256)
    x = torch.randn(1, 32, 256)

    with torch.no_grad():
        y_baseline = model(x)

    # Test sensitivity: quantise each layer individually
    print('Per-Layer Quantisation Sensitivity (INT4)')
    print('=' * 60)
    print(f'{"Layer":<15} {"Output Δ (L2)":>14} {"Relative Δ":>12} {"Sensitivity":>12}')
    print('-' * 55)

    for layer_name in ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'ffn_gate', 'ffn_down']:
        # Clone model and quantise only this layer
        test_model = MiniTransformerBlock(256)
        test_model.load_state_dict(model.state_dict())

        layer = getattr(test_model, layer_name)
        with torch.no_grad():
            w = layer.weight.data.numpy()
            q, s = quantize_int4_symmetric(w.flatten())
            w_deq = dequantize(q, s).reshape(w.shape)
            layer.weight.data = torch.tensor(w_deq, dtype=torch.float32)  # cast back to FP32 to match the rest of the model

        with torch.no_grad():
            y_quant = test_model(x)

        delta = torch.norm(y_baseline - y_quant).item()
        relative = delta / torch.norm(y_baseline).item()
        sensitivity = '🔴 HIGH' if relative > 0.1 else ('🟡 MED' if relative > 0.01 else '🟢 LOW')
        print(f'{layer_name:<15} {delta:>14.6f} {relative:>12.4%} {sensitivity:>12}')

    print(f'\n→ ffn_down is the most sensitive layer (directly affects residual stream)')
    print(f'→ Strategy: keep ffn_down in INT8, quantise others to INT4')
else:
    print('PyTorch required for per-layer sensitivity demo')

Summary: Key Takeaways

| Concept | Why It Matters | Action |
| --- | --- | --- |
| IEEE 754 layout | Debug NaN, precision bugs | Inspect bits with struct.pack |
| Machine epsilon | Minimum useful learning rate | Check np.finfo() for your dtype |
| BF16 > FP16 | Same range as FP32, no overflow | Always use BF16 for training |
| Softmax stability | Prevents NaN in every attention layer | Always subtract max first |
| Log-sum-exp | Cross-entropy loss without overflow | Use torch.logsumexp() |
| RMSNorm > LayerNorm | No catastrophic cancellation | Preferred in 2024+ architectures |
| Per-group quantisation | Much better than per-tensor | Use group_size=64 or 128 |
| NF4 (QLoRA) | Optimal for Gaussian-distributed weights | load_in_4bit=True |
| SQNR: 6 dB/bit | Each bit doubles precision | Choose bit width by quality need |
| Hadamard rotation | Suppresses outliers for better quant | Used in QuIP, QuaRot |
| FP32 master weights | Prevent training divergence | NEVER skip this |
| Stochastic rounding | Enables sub-8-bit training | Used in FP8 on H100 |
| Memory bandwidth | The true inference bottleneck | Smaller = faster (linear!) |
| Error propagation | Quantisation error grows through layers | Keep first/last layers in higher precision |
| KV cache quantisation | Essential for long-context serving | INT8 per-channel is nearly lossless |

Best Practices

  1. Training: BF16 forward/backward + FP32 master weights + FP32 Adam states
  2. Inference (balanced): INT8 weights with SmoothQuant — nearly lossless
  3. Inference (compressed): INT4 with GPTQ or AWQ + per-group quantisation
  4. Fine-tuning (budget): QLoRA with NF4 base + BF16 LoRA adapters
  5. Long context: INT8 KV cache per-channel quantisation
  6. Never compare floats with == → use np.isclose() or torch.allclose()
  7. Never compute log(sum(exp(x))) → use torch.logsumexp()

Practice Questions

  1. IEEE 754: Encode -13.625 in FP32 binary. Show sign, exponent, mantissa.
  2. Precision: Why does 1.0 + 1e-8 == 1.0 in FP32 but not FP64?
  3. BF16 vs FP16: A gradient of 70000.0 — which format handles it without overflow?
  4. Softmax: Given logits [89.0, 89.1, 89.2], what happens with naive softmax in FP32?
  5. SQNR: Calculate the theoretical SQNR for 3-bit uniform quantisation.
  6. Memory: How much VRAM for a 13B model in INT4 inference? In BF16 training?
  7. Stochastic rounding: If gradient=0.003 and BF16 ULP=0.0078, what does RNE do? SR?
  8. KV Cache: Calculate the KV cache size for a 32-layer, 32-head, d=128 model at 32K context.
  9. Hadamard: Why does rotating weights before quantisation reduce error?
  10. Error propagation: Why are the first and last layers of an LLM more sensitive to quantisation?

Next Steps

  • exercises.ipynb: Extended practice problems with detailed solutions
  • notes.md: Complete 17-section mathematical reference
  • Continue to: 02-Sets-and-Logic