Theory Notebook
Converted from theory.ipynb for web reading.
Training at Scale: Theory Notebook
This notebook turns scale-training formulas into small numerical checks. The examples are toy-sized, but the units and failure modes are the same ones used in real LLM training plans.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn's theme when it is installed; otherwise fall back to the
# seaborn-look style sheet bundled with matplotlib, so figures look similar
# either way. HAS_SNS records which path was taken for later cells.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Shared figure defaults applied to every plot in this notebook.
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})

# Legacy global seed; NOTE(review): a later cell also uses np.random.default_rng.
np.random.seed(42)
print("Plot setup complete.")
1. AdamW update by hand
Code cell 4
theta = 1.5
grad = 0.30
m = 0.0
v = 0.0
beta1, beta2 = 0.9, 0.999
lr, eps, weight_decay = 1e-3, 1e-8, 0.1
t = 1

def adamw_step(theta, grad, m=0.0, v=0.0, t=1, beta1=0.9, beta2=0.999,
               lr=1e-3, eps=1e-8, weight_decay=0.1):
    """Apply one AdamW update to a scalar parameter.

    Generalizes the original inline one-step computation so it can be
    iterated or reused with different hyperparameters.

    Returns (theta_next, m, v, m_hat, v_hat): the updated parameter, the new
    raw first/second moments, and the bias-corrected moment estimates.
    """
    m = beta1 * m + (1 - beta1) * grad        # EMA of gradients
    v = beta2 * v + (1 - beta2) * grad**2     # EMA of squared gradients
    m_hat = m / (1 - beta1**t)                # bias correction; t is 1-based
    v_hat = v / (1 - beta2**t)
    # Decoupled weight decay: subtracted from theta directly, not mixed
    # into the gradient (the W in AdamW).
    theta_next = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * weight_decay * theta
    return theta_next, m, v, m_hat, v_hat

theta_next, m, v, m_hat, v_hat = adamw_step(theta, grad, m, v, t,
                                            beta1, beta2, lr, eps, weight_decay)
print("m_hat:", m_hat)
print("v_hat:", v_hat)
print("theta_next:", theta_next)
2. Gradient clipping
Code cell 6
# Global-norm gradient clipping: if the gradient's L2 norm exceeds `clip`,
# rescale the whole vector so its norm equals `clip`; otherwise leave it as-is.
g = np.array([3.0, 4.0, 12.0])
clip = 5.0

norm = np.linalg.norm(g)
scale = clip / norm if norm > clip else 1.0
g_clipped = scale * g

print("original norm:", norm)
print("scale:", scale)
print("clipped norm:", np.linalg.norm(g_clipped))
3. Warmup plus cosine decay
Code cell 8
# Learning-rate schedule: linear warmup for the first `warmup` steps, then
# cosine decay from lr_max down to lr_min over the remaining steps.
steps = np.arange(0, 1000)
warmup = 100
lr_max = 3e-4
lr_min = 3e-5

# Vectorized form of the per-step rule: both branches are computed for every
# step and np.where selects the applicable one.
ramp = lr_max * (steps + 1) / warmup
progress = (steps - warmup) / (len(steps) - warmup - 1)
cosine = lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * progress))
lr = np.where(steps < warmup, ramp, cosine)

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(steps, lr)
ax.set_title("Warmup plus cosine decay")
ax.set_xlabel("training step")
ax.set_ylabel("learning rate")
fig.tight_layout()
plt.show()

print("first, peak, last:", lr[0], lr.max(), lr[-1])
4. Effective batch size
Code cell 10
# One optimizer step consumes: micro_batch sequences per rank, across
# data_parallel ranks, accumulated over grad_accum micro-steps.
micro_batch = 2
sequence_length = 4096
data_parallel = 64
grad_accum = 8

sequences_per_step = micro_batch * data_parallel * grad_accum
tokens_per_step = sequences_per_step * sequence_length

print("effective sequences per optimizer step:", sequences_per_step)
print("tokens per optimizer step:", tokens_per_step)
5. Memory accounting
Code cell 12
# Static training-state memory for a 7B-parameter model: bf16 weights,
# bf16 gradients, and two fp32 Adam moment tensors. Activation memory
# comes on top of this total.
P = 7e9
bytes_bf16 = 2
bytes_fp32 = 4

weights = P * bytes_bf16
grads = P * bytes_bf16
adam_moments = 2 * P * bytes_fp32
total = weights + grads + adam_moments

for label, nbytes in [
    ("weights GB:", weights),
    ("grads GB:", grads),
    ("Adam moments GB:", adam_moments),
    ("training state GB before activations:", total),
]:
    print(label, nbytes / 1e9)
6. ZeRO-style sharding stages
Code cell 14
# ZeRO-style sharding: each successive stage partitions one more component
# of the training state (optimizer states, then gradients, then weights)
# across the `world` data-parallel ranks.
world = 8
weights = 14.0
grads = 14.0
optim = 56.0

stage0 = weights + grads + optim                   # fully replicated
stage1 = weights + grads + optim / world           # optimizer states sharded
stage2 = weights + grads / world + optim / world   # + gradients sharded
stage3 = weights / world + grads / world + optim / world  # + weights sharded

stage_table = [
    ("replicated", stage0),
    ("stage1", stage1),
    ("stage2", stage2),
    ("stage3", stage3),
]
for name, gb in stage_table:
    print(f"{name:>10s}: {gb:.2f} GB per rank before activations")
7. Activation checkpointing tradeoff
Code cell 16
# Intuition plot: keeping every layer's activations grows linearly with depth,
# while segment-wise checkpointing follows the classic sqrt(L) memory curve.
layers = np.arange(1, 49)
activation_per_layer_gb = 0.35

saved_without = activation_per_layer_gb * layers
saved_with_segments = activation_per_layer_gb * np.sqrt(layers)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(layers, saved_without, label="store every layer")
ax.plot(layers, saved_with_segments, label="checkpoint segments")
ax.set_title("Activation memory scaling intuition")
ax.set_xlabel("layers")
ax.set_ylabel("approx saved activation GB")
ax.legend()
fig.tight_layout()
plt.show()

print("48-layer rough memory no checkpoint:", saved_without[-1])
print("48-layer rough memory checkpointed:", saved_with_segments[-1])
8. Pipeline bubble
Code cell 18
# Bubble fraction for a p-stage pipeline fed m micro-batches:
# (p - 1) / (m + p - 1). More micro-batches amortize the fill/drain cost.
stages = 8
micro_batches = np.arange(1, 65)
bubble = (stages - 1) / (micro_batches + stages - 1)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(micro_batches, bubble)
ax.set_title("Pipeline bubble fraction")
ax.set_xlabel("micro-batches")
ax.set_ylabel("bubble fraction")
fig.tight_layout()
plt.show()

for m in (1, 4, 16, 64):
    print(f"micro_batches={m:>2}: bubble={(stages-1)/(m+stages-1):.3f}")
9. All-reduce cost model
Code cell 20
# Alpha-beta communication model: a fixed per-message latency plus payload
# size divided by link bandwidth, reported in milliseconds.
payload_gb = np.array([0.5, 1.0, 2.0, 4.0])
bandwidth_gbps = 200.0
latency_ms = 0.05

time_ms = latency_ms + (payload_gb / bandwidth_gbps) * 1000

for payload, ms in zip(payload_gb, time_ms):
    print(f"payload={payload:.1f} GB -> approx {ms:.2f} ms")
10. FLOPs and MFU
Code cell 22
# 6*N*D training-FLOPs estimate, converted to model FLOPs utilization (MFU)
# for a 14-day run on 512 accelerators with 312 TFLOP/s peak each.
params = 7e9
tokens = 300e9
train_flops = 6 * params * tokens

days = 14
seconds = days * 24 * 3600
useful_flops_per_sec = train_flops / seconds

hardware_peak = 512 * 312e12  # aggregate cluster peak
mfu = useful_flops_per_sec / hardware_peak

print("training FLOPs:", f"{train_flops:.3e}")
print("useful FLOPs/sec:", f"{useful_flops_per_sec:.3e}")
print("MFU:", round(mfu, 3))
11. Compute-optimal tradeoff surface
Code cell 24
# Toy tradeoff surface: a proxy loss that falls with both parameter count N
# and token count D, overlaid with iso-compute contours (C = 6*N*D) showing
# what a fixed compute budget can buy.
N = np.logspace(8, 11, 40)
D = np.logspace(9, 12, 40)
NN, DD = np.meshgrid(N, D)
compute = 6 * NN * DD  # standard 6*N*D training-FLOPs approximation
# NOTE(review): the offsets/exponents below are illustrative only, not a
# fitted scaling law — the cell's final print says as much.
loss_proxy = 1.7 + 0.8 * (NN / 1e9) ** -0.08 + 0.6 * (DD / 1e10) ** -0.10
fig, ax = plt.subplots(figsize=(7, 5))
# Gray line contours: iso-compute levels in log10 space.
cs = ax.contour(np.log10(NN), np.log10(DD), np.log10(compute), levels=8, colors="gray", alpha=0.5)
# Filled contours: the proxy loss surface itself.
im = ax.contourf(np.log10(NN), np.log10(DD), loss_proxy, levels=16)
fig.colorbar(im, ax=ax, label="proxy loss")
ax.clabel(cs, inline=True, fontsize=8)
ax.set_title("Toy loss surface with compute contours")
ax.set_xlabel("log10 parameters")
ax.set_ylabel("log10 tokens")
fig.tight_layout()
plt.show()
print("This is a teaching proxy, not an empirical scaling-law fit.")
12. Token packing utilization
Code cell 26
# Compare padding each sequence up to a full block against tightly packing
# the token stream into as few blocks as possible.
lengths = np.array([128, 512, 900, 64, 256, 700, 1024, 300])
block = 1024

real_tokens = lengths.sum()
padded_tokens = block * len(lengths)
naive_util = real_tokens / padded_tokens

packed_blocks = int(np.ceil(real_tokens / block))
packed_util = real_tokens / (packed_blocks * block)

print("naive utilization:", round(naive_util, 3))
print("packed utilization:", round(packed_util, 3))
13. Checkpoint overhead
Code cell 28
# Fraction of wall-clock time lost to checkpointing when training pauses for
# `checkpoint_time` seconds after every `interval` steps of `step_time` each.
step_time = 2.0
checkpoint_time = 180.0
intervals = np.array([100, 500, 1000, 5000])

compute_seconds = intervals * step_time
overhead_fraction = checkpoint_time / (compute_seconds + checkpoint_time)

for interval, overhead in zip(intervals, overhead_fraction):
    print(f"checkpoint every {interval:>4} steps -> overhead {100*overhead:.2f}%")
14. Loss spike detection
Code cell 30
# Synthetic run: slowly decreasing loss with Gaussian noise, one injected
# spike at step 120, and a trailing rolling-median detector that flags it.
rng = np.random.default_rng(7)
loss = 3.0 - 0.002 * np.arange(200) + rng.normal(0, 0.015, 200)
loss[120] += 0.45

window = 20
median = np.empty(len(loss))
for i in range(len(loss)):
    start = max(0, i - window)
    median[i] = np.median(loss[start:i + 1])  # trailing window includes step i

# Flag any step sitting well above its local median.
spikes = np.where(loss > median + 0.20)[0]
print("detected spike steps:", spikes.tolist())

fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(loss, label="loss")
ax.plot(median, label="rolling median")
ax.scatter(spikes, loss[spikes], color="red", zorder=3, label="spike")
ax.set_title("Toy loss spike detector")
ax.set_xlabel("step")
ax.set_ylabel("loss")
ax.legend()
fig.tight_layout()
plt.show()
15. Final launch checklist
Code cell 32
# Final pre-launch checklist: each item is a cheap sanity check that heads
# off an expensive failure mode once the real run starts.
checks = [
    "small run learns and resumes exactly",
    "labels are shifted and masks are tested",
    "memory estimate includes optimizer states and activations",
    "global batch and token budget are explicit",
    "parallelism product equals world size",
    "validation loss and throughput are logged separately",
    "checkpoint includes model, optimizer, scheduler, RNG, and dataloader state",
]

for i, check in enumerate(checks, start=1):
    print(f"{i}. {check}")