Exercises Notebook
Converted from exercises.ipynb for web reading.
Training at Scale: Exercises
These ten exercises train the accounting skills that matter before a large LLM run: optimizer steps, clipping, schedules, memory, parallelism, FLOPs, MFU, and launch checks.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Exercise 1: AdamW scalar update
Compute one AdamW update (the first step, t = 1) for a scalar parameter, starting from zero moment estimates.
Code cell 4
# Your Solution
theta = 2.0
grad = 0.5
print("Starter: update m, v, bias-correct them, then apply AdamW.")
Code cell 5
# Solution
# First step (t = 1) from zero moment estimates, so each moving average
# reduces to (1 - beta) * stat and bias correction divides that back out.
theta = 2.0
grad = 0.5
beta1, beta2 = 0.9, 0.999
lr, eps, wd = 1e-3, 1e-8, 0.01
m = (1 - beta1) * grad
v = (1 - beta2) * grad**2
m_hat = m / (1 - beta1)   # equals grad at t = 1
v_hat = v / (1 - beta2)   # equals grad**2 at t = 1
theta_next = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * wd * theta
print("theta_next:", theta_next)
Exercise 2: Clip a gradient
Clip a gradient vector to a maximum L2 norm of 5.
Code cell 7
# Your Solution
g = np.array([6.0, 8.0])
print("Starter: multiply by min(1, 5 / norm(g)).")
Code cell 8
# Solution
g = np.array([6.0, 8.0])
scale = min(1.0, 5.0 / np.linalg.norm(g))
clipped = g * scale
print("clipped:", clipped)
print("norm:", np.linalg.norm(clipped))
assert np.isclose(np.linalg.norm(clipped), 5.0)
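Production trainers usually clip by the global norm across all parameter tensors rather than tensor by tensor. A sketch of that variant, where the second example tensor is an assumption for illustration:

# Optional sketch: global-norm clipping over multiple tensors.
grads = [np.array([6.0, 8.0]), np.array([3.0, 4.0])]   # second tensor is illustrative
max_norm = 5.0
global_norm = np.sqrt(sum(np.sum(g**2) for g in grads))
scale = min(1.0, max_norm / global_norm)
clipped_all = [g * scale for g in grads]
print("global norm before:", global_norm)
print("global norm after:", np.sqrt(sum(np.sum(c**2) for c in clipped_all)))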
Exercise 3: Warmup schedule value
Find the learning rate at step 50 under linear warmup over 100 steps to a peak of 3e-4.
Code cell 10
# Your Solution
step = 50
warmup = 100
peak = 3e-4
print("Starter: during warmup, lr = peak * (step + 1) / warmup.")
Code cell 11
# Solution
step = 50
warmup = 100
peak = 3e-4
lr = peak * (step + 1) / warmup
print("lr:", lr)
assert lr < peak
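Warmup is usually paired with a decay phase. A sketch of linear warmup followed by cosine decay, where the total step count and learning-rate floor are assumed values, not part of the exercise:

# Optional sketch: linear warmup then cosine decay to a floor.
def lr_at(step, peak=3e-4, warmup=100, total=1000, floor=3e-5):
    if step < warmup:
        return peak * (step + 1) / warmup          # same rule as above
    progress = (step - warmup) / max(1, total - warmup)
    return floor + 0.5 * (peak - floor) * (1 + np.cos(np.pi * progress))

for s in [0, 50, 99, 100, 550, 999]:
    print(f"step {s}: lr = {lr_at(s):.2e}")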
Exercise 4: Effective token batch
Compute the effective tokens per optimizer step from micro-batch size, sequence length, data-parallel degree, and gradient-accumulation steps.
Code cell 13
# Your Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
print("Starter: multiply micro_batch, seq_len, dp, and accum.")
Code cell 14
# Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
tokens = micro_batch * seq_len * dp * accum
print("tokens per optimizer step:", tokens)
Exercise 5: Memory estimate
Estimate the replicated Adam training state for 1B parameters with bf16 weights and gradients and fp32 moments.
Code cell 16
# Your Solution
P = 1e9
print("Starter: weights=2P, grads=2P, moments=8P bytes.")
Code cell 17
# Solution
P = 1e9
total_bytes = 2*P + 2*P + 8*P  # bf16 weights + bf16 grads + fp32 m and v (4 bytes each)
print("GB:", total_bytes / 1e9)
assert np.isclose(total_bytes / 1e9, 12.0)
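The 12 GB figure assumes every data-parallel rank holds a full replica. A back-of-envelope sketch of ZeRO-style sharding across an assumed 32 data-parallel ranks (activations and buffers still excluded):

# Optional sketch: ZeRO-style partitioning of the 12 bytes/param.
dp = 32                              # assumed data-parallel degree
weights, grads, moments = 2 * P, 2 * P, 8 * P
variants = {
    "replicated": weights + grads + moments,
    "ZeRO-1 (shard optimizer state)": weights + grads + moments / dp,
    "ZeRO-2 (also shard grads)": weights + (grads + moments) / dp,
    "ZeRO-3 (also shard weights)": (weights + grads + moments) / dp,
}
for name, b in variants.items():
    print(f"{name}: {b / 1e9:.2f} GB per rank")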
Exercise 6: Pipeline bubble
Compute bubble fraction for 4 stages and 12 micro-batches.
Code cell 19
# Your Solution
P = 4
M = 12
print("Starter: bubble = (P - 1) / (M + P - 1).")
Code cell 20
# Solution
P = 4
M = 12
bubble = (P - 1) / (M + P - 1)  # (P - 1) fill/drain slots out of M + P - 1 total
print("bubble:", bubble)
Exercise 7: Tensor-parallel shard
Split a 4096 x 16384 weight matrix across 4 column-parallel ranks.
Code cell 22
# Your Solution
in_dim = 4096
out_dim = 16384
tp = 4
print("Starter: each rank owns out_dim / tp columns.")
Code cell 23
# Solution
in_dim = 4096
out_dim = 16384
tp = 4
shape = (in_dim, out_dim // tp)
print("per-rank shard shape:", shape)
assert shape == (4096, 4096)
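A tiny numpy check, with deliberately small dimensions assumed for illustration (not the 4096 x 16384 of the exercise), confirms that concatenating per-rank column outputs reproduces the full matmul:

# Optional sketch: column-parallel matmul equals the unsharded one.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))        # (batch, in_dim), illustrative sizes
W = rng.standard_normal((8, 12))       # (in_dim, out_dim)
shards = np.split(W, tp, axis=1)       # each rank holds out_dim // tp columns
partials = [x @ s for s in shards]     # local matmuls, no communication yet
y = np.concatenate(partials, axis=1)   # stands in for the all-gather
assert np.allclose(y, x @ W)
print("column-parallel result matches:", y.shape)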
Exercise 8: Training FLOPs
Use the approximation C = 6ND for a 7B model trained on 300B tokens.
Code cell 25
# Your Solution
N = 7e9
D = 300e9
print("Starter: C = 6 * N * D.")
Code cell 26
# Solution
N = 7e9
D = 300e9
C = 6 * N * D
print("FLOPs:", f"{C:.3e}")
Exercise 9: MFU
Compute MFU from useful FLOPs/sec and hardware peak.
Code cell 28
# Your Solution
useful = 60e15
peak = 160e15
print("Starter: MFU = useful / peak.")
Code cell 29
# Solution
useful = 60e15
peak = 160e15
mfu = useful / peak
print("MFU:", mfu)
assert 0 <= mfu <= 1
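In practice the useful-FLOPs numerator is rarely measured directly; it is usually derived from measured token throughput via the same 6N FLOPs-per-token rule as Exercise 8. A sketch with an assumed throughput figure:

# Optional sketch: MFU from token throughput via 6N FLOPs per token.
N = 7e9
tokens_per_sec = 1.5e6               # assumed measured cluster throughput
useful_from_tokens = 6 * N * tokens_per_sec
print("MFU from throughput:", useful_from_tokens / peak)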
Exercise 10: Launch checklist
Write four checks before scaling a training run.
Code cell 31
# Your Solution
print("Starter: include loss, resume, memory, and batch checks.")
Code cell 32
# Solution
checks = [
    "small run reduces validation loss",
    "resume restores optimizer, scheduler, RNG, and dataloader state",
    "memory estimate includes activations and optimizer states",
    "effective global token batch is documented",
]
for check in checks:
    print("-", check)
assert len(checks) == 4
Closing Reflection
At scale, arithmetic mistakes become infrastructure failures. Keep units explicit, test small, and make the loss curve prove that the system is learning.