Theory Notebook — Math for LLMs

Training at Scale

Math for LLMs / Training at Scale

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Training at Scale: Theory Notebook

This notebook turns scale-training formulas into small numerical checks. The examples are toy-sized, but the units and failure modes are the same ones used in real LLM training plans.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn's theme when it is installed; otherwise fall back to matplotlib's
# bundled seaborn-like whitegrid style. HAS_SNS records which branch ran.
try:
    import seaborn as sns
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
else:
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True

# Shared figure defaults applied to every plot that follows.
_PLOT_RC = {
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
}
mpl.rcParams.update(_PLOT_RC)

# Seed NumPy's legacy global RNG so any cell relying on it is reproducible.
np.random.seed(42)
print("Plot setup complete.")

1. AdamW update by hand

Code cell 4

# One AdamW step for a single scalar parameter, starting from zero moments.
theta, grad = 1.5, 0.30
m, v = 0.0, 0.0
beta1, beta2 = 0.9, 0.999
lr, eps, weight_decay = 1e-3, 1e-8, 0.1
t = 1  # 1-based step counter used by the bias corrections

# Exponential moving averages of the gradient and of its square.
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad**2

# Bias correction removes the pull toward the zero initialisation of m and v.
m_hat = m / (1 - beta1**t)
v_hat = v / (1 - beta2**t)

# Decoupled weight decay: the decay term is subtracted separately from the
# adaptive step rather than being folded into the gradient.
adam_step = lr * m_hat / (np.sqrt(v_hat) + eps)
decay_step = lr * weight_decay * theta
theta_next = theta - adam_step - decay_step

print("m_hat:", m_hat)
print("v_hat:", v_hat)
print("theta_next:", theta_next)

2. Gradient clipping

Code cell 6

# Global-norm clipping: rescale g only when its L2 norm exceeds the threshold.
g = np.array([3.0, 4.0, 12.0])
clip = 5.0
norm = np.linalg.norm(g)
scale = clip / norm if norm > clip else 1.0
g_clipped = scale * g

print("original norm:", norm)
print("scale:", scale)
print("clipped norm:", np.linalg.norm(g_clipped))

3. Warmup plus cosine decay

Code cell 8

steps = np.arange(0, 1000)
warmup = 100
lr_max = 3e-4
lr_min = 3e-5

# Linear warmup reaches lr_max at step warmup - 1; cosine decay then starts
# at lr_max on step `warmup` and lands exactly on lr_min at the final step.
linear = lr_max * (steps + 1) / warmup
decay_len = len(steps) - warmup - 1
progress = np.clip((steps - warmup) / decay_len, 0.0, None)
cosine = lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * progress))
lr = np.where(steps < warmup, linear, cosine)

# Plot the schedule and report its key values.
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(steps, lr)
ax.set_title("Warmup plus cosine decay")
ax.set_xlabel("training step")
ax.set_ylabel("learning rate")
fig.tight_layout()
plt.show()

summary = (lr[0], lr.max(), lr[-1])
print("first, peak, last:", *summary)

4. Effective batch size

Code cell 10

micro_batch = 2
sequence_length = 4096
data_parallel = 64
grad_accum = 8

# Sequences per optimizer step: per-rank micro-batch x data-parallel ranks x accumulation steps.
sequences_per_step = micro_batch * data_parallel * grad_accum
tokens_per_step = sequences_per_step * sequence_length

print("effective sequences per optimizer step:", sequences_per_step)
print("tokens per optimizer step:", tokens_per_step)

5. Memory accounting

Code cell 12

P = 7e9
bytes_bf16 = 2
bytes_fp32 = 4

# Mixed-precision AdamW training state per replica, before activations.
# The optimizer keeps an fp32 master copy of the weights next to the bf16 copy
# used in forward/backward. Including it gives the standard 2 + 2 + 4 + 8 = 16
# bytes per parameter.
weights = P * bytes_bf16            # bf16 working weights
grads = P * bytes_bf16              # bf16 gradients
master_weights = P * bytes_fp32     # fp32 master weights updated by the optimizer
adam_moments = 2 * P * bytes_fp32   # fp32 Adam first and second moments (m, v)
total = weights + grads + master_weights + adam_moments

print("weights GB:", weights / 1e9)
print("grads GB:", grads / 1e9)
print("fp32 master weights GB:", master_weights / 1e9)
print("Adam moments GB:", adam_moments / 1e9)
print("training state GB before activations:", total / 1e9)
print("bytes per parameter:", total / P)

6. ZeRO-style sharding stages

Code cell 14

world = 8

# Per-replica training state in GB for a 7e9-parameter model under
# mixed-precision AdamW (16 bytes per parameter in total).
weights = 14.0  # bf16 parameters: 7e9 x 2 bytes
grads = 14.0    # bf16 gradients:  7e9 x 2 bytes
optim = 84.0    # fp32 master weights (28) + fp32 Adam m and v (2 x 28)

# Each ZeRO stage shards one more component across the `world` ranks.
stage0 = weights + grads + optim                          # fully replicated
stage1 = weights + grads + optim / world                  # shard optimizer state
stage2 = weights + grads / world + optim / world          # ... plus gradients
stage3 = weights / world + grads / world + optim / world  # ... plus parameters

for name, gb in [("replicated", stage0), ("stage1", stage1), ("stage2", stage2), ("stage3", stage3)]:
    print(f"{name:>10s}: {gb:.2f} GB per rank before activations")

7. Activation checkpointing tradeoff

Code cell 16

layers = np.arange(1, 49)  # model depths 1..48
activation_per_layer_gb = 0.35

# Keeping every layer's activations grows linearly with depth. Segment
# checkpointing is modelled as growing like sqrt(depth); this captures the
# growth rate only, and constant factors are ignored.
saved_without = activation_per_layer_gb * layers
saved_with_segments = activation_per_layer_gb * np.sqrt(layers)

# Compare the two activation-memory curves on one set of axes.
fig, ax = plt.subplots(figsize=(8, 4))
curves = {
    "store every layer": saved_without,
    "checkpoint segments": saved_with_segments,
}
for label, values in curves.items():
    ax.plot(layers, values, label=label)
ax.set_title("Activation memory scaling intuition")
ax.set_xlabel("layers")
ax.set_ylabel("approx saved activation GB")
ax.legend()
fig.tight_layout()
plt.show()

print("48-layer rough memory no checkpoint:", saved_without[-1])
print("48-layer rough memory checkpointed:", saved_with_segments[-1])

8. Pipeline bubble

Code cell 18

stages = 8
micro_batches = np.arange(1, 65)

# GPipe-style bubble: (stages - 1) idle slots out of (micro_batches + stages - 1) total slots.
idle_slots = stages - 1
bubble = idle_slots / (micro_batches + idle_slots)

# Plot the bubble fraction against the number of micro-batches.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(micro_batches, bubble)
ax.set_title("Pipeline bubble fraction")
ax.set_xlabel("micro-batches")
ax.set_ylabel("bubble fraction")
fig.tight_layout()
plt.show()

# Read sample points straight off the computed curve. micro_batches starts
# at 1, so a count of m lives at index m - 1.
for m in [1, 4, 16, 64]:
    print(f"micro_batches={m:>2}: bubble={bubble[m - 1]:.3f}")

9. All-reduce cost model

Code cell 20

payload_gb = np.array([0.5, 1.0, 2.0, 4.0])  # data each rank contributes, in GB
ranks = 8                    # participants in the all-reduce
bandwidth_gb_per_s = 200.0   # per-rank link bandwidth in gigaBYTES per second
latency_ms = 0.05            # fixed cost of one communication step (alpha)

# Ring all-reduce = reduce-scatter + all-gather. That is 2 * (ranks - 1) steps,
# each sending payload / ranks, so every rank moves 2 * (ranks - 1) / ranks * payload.
# A plain latency + payload / bandwidth model is a single point-to-point send
# and undercounts both terms.
ring_steps = 2 * (ranks - 1)
traffic_gb = ring_steps / ranks * payload_gb
time_ms = ring_steps * latency_ms + traffic_gb / bandwidth_gb_per_s * 1000

for payload, ms in zip(payload_gb, time_ms):
    print(f"payload={payload:.1f} GB -> approx {ms:.2f} ms")

10. FLOPs and MFU

Code cell 22

# Training compute from the usual 6 * N * D estimate
# (roughly 2 FLOPs per parameter per token forward, 4 backward).
params = 7e9
tokens = 300e9
train_flops = 6 * params * tokens

# Wall-clock budget for the run.
days = 14
seconds = days * 24 * 3600

# MFU = achieved model FLOP/s divided by the aggregate peak FLOP/s.
# NOTE(review): 512 x 312e12 presumably means 512 accelerators at a 312 TFLOP/s
# dense peak each; confirm against the target hardware.
hardware_peak = 512 * 312e12
useful_flops_per_sec = train_flops / seconds
mfu = useful_flops_per_sec / hardware_peak

print("training FLOPs:", f"{train_flops:.3e}")
print("useful FLOPs/sec:", f"{useful_flops_per_sec:.3e}")
print("MFU:", round(mfu, 3))

11. Compute-optimal tradeoff surface

Code cell 24

N = np.logspace(8, 11, 40)   # parameter counts, 1e8 .. 1e11
D = np.logspace(9, 12, 40)   # training tokens, 1e9 .. 1e12
NN, DD = np.meshgrid(N, D)   # rows follow D, columns follow N
compute = 6 * NN * DD

# Toy loss: a constant floor plus separate power-law penalties in N and D.
# The coefficients are illustrative, not fitted.
param_term = 0.8 * (NN / 1e9) ** -0.08
data_term = 0.6 * (DD / 1e10) ** -0.10
loss_proxy = 1.7 + param_term + data_term

# Filled loss contours with iso-compute lines overlaid (line contours render
# above filled ones by default, so drawing the fill first is purely for clarity).
log_n, log_d = np.log10(NN), np.log10(DD)
fig, ax = plt.subplots(figsize=(7, 5))
im = ax.contourf(log_n, log_d, loss_proxy, levels=16)
cs = ax.contour(log_n, log_d, np.log10(compute), levels=8, colors="gray", alpha=0.5)
fig.colorbar(im, ax=ax, label="proxy loss")
ax.clabel(cs, inline=True, fontsize=8)
ax.set_title("Toy loss surface with compute contours")
ax.set_xlabel("log10 parameters")
ax.set_ylabel("log10 tokens")
fig.tight_layout()
plt.show()
print("This is a teaching proxy, not an empirical scaling-law fit.")

12. Token packing utilization

Code cell 26

lengths = np.array([128, 512, 900, 64, 256, 700, 1024, 300])
block = 1024
real_tokens = lengths.sum()

# Naive batching: each sequence gets its own block and the remainder is padding.
padded_tokens = lengths.size * block
naive_util = real_tokens / padded_tokens

# Ideal packing: concatenate everything and cut into full blocks. This assumes
# sequences may be split across block boundaries.
packed_blocks = int(-(-real_tokens // block))  # ceil(real_tokens / block)
packed_util = real_tokens / (packed_blocks * block)

print("naive utilization:", round(naive_util, 3))
print("packed utilization:", round(packed_util, 3))

13. Checkpoint overhead

Code cell 28

step_time = 2.0          # time per training step (presumably seconds)
checkpoint_time = 180.0  # time to write one checkpoint, same unit
intervals = np.array([100, 500, 1000, 5000])

# Share of each "train `interval` steps, then save" cycle spent saving.
compute_time = intervals * step_time
overhead_fraction = checkpoint_time / (compute_time + checkpoint_time)

for interval, overhead in zip(intervals, overhead_fraction):
    print(f"checkpoint every {interval:>4} steps -> overhead {100*overhead:.2f}%")

14. Loss spike detection

Code cell 30

# Synthetic loss curve: a linear downward trend plus Gaussian noise, with one
# injected spike at step 120.
rng = np.random.default_rng(7)
loss = 3.0 - 0.002 * np.arange(200) + rng.normal(0, 0.015, 200)
loss[120] += 0.45
window = 20

# Trailing median over the previous `window` steps plus the current step.
median = np.empty(len(loss))
for i in range(len(loss)):
    start = max(0, i - window)
    median[i] = np.median(loss[start:i + 1])

# A step is flagged when it sits more than 0.20 above its trailing median.
is_spike = loss > median + 0.20
spikes = np.where(is_spike)[0]
print("detected spike steps:", spikes.tolist())

# Raw loss and its trailing median, with flagged steps marked in red.
fig, ax = plt.subplots(figsize=(9, 4))
for series, label in ((loss, "loss"), (median, "rolling median")):
    ax.plot(series, label=label)
ax.scatter(spikes, loss[spikes], color="red", zorder=3, label="spike")
ax.set_title("Toy loss spike detector")
ax.set_xlabel("step")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()

15. Final launch checklist

Code cell 32

# Pre-launch checks, printed as a 1-based numbered list.
checks = [
    "small run learns and resumes exactly",
    "labels are shifted and masks are tested",
    "memory estimate includes optimizer states and activations",
    "global batch and token budget are explicit",
    "parallelism product equals world size",
    "validation loss and throughput are logged separately",
    "checkpoint includes model, optimizer, scheduler, RNG, and dataloader state",
]
print("\n".join(f"{number}. {item}" for number, item in enumerate(checks, start=1)))