Converted from exercises.ipynb for web reading.

Training at Scale: Exercises

These ten exercises train the accounting skills that matter before a large LLM run: optimizer steps, clipping, schedules, memory, parallelism, FLOPs, MFU, and launch checks.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Exercise 1: AdamW scalar update

Compute one AdamW update (the first step, t = 1, with both moments initialized to zero) for a scalar parameter.

Code cell 4

# Your Solution
theta = 2.0
grad = 0.5
print("Starter: update m, v, bias-correct them, then apply AdamW.")

Code cell 5

# Solution
theta = 2.0
grad = 0.5
beta1, beta2 = 0.9, 0.999
lr, eps, wd = 1e-3, 1e-8, 0.01
# First step (t = 1): the moments start at zero, so the EMA updates reduce to
m = (1 - beta1) * grad
v = (1 - beta2) * grad**2
# Bias correction at t = 1 divides by (1 - beta) exactly once
m_hat = m / (1 - beta1)
v_hat = v / (1 - beta2)
# Adam step plus decoupled weight decay (applied to theta, not to the gradient)
theta_next = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * wd * theta
print("theta_next:", theta_next)

Exercise 2: Clip a gradient

Clip the vector [6, 8] to a maximum norm of 5.

Code cell 7

# Your Solution
g = np.array([6.0, 8.0])
print("Starter: multiply by min(1, 5 / norm(g)).")

Code cell 8

# Solution
g = np.array([6.0, 8.0])
scale = min(1.0, 5.0 / np.linalg.norm(g))
clipped = g * scale
print("clipped:", clipped)
print("norm:", np.linalg.norm(clipped))
assert np.isclose(np.linalg.norm(clipped), 5.0)
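
Worked check: the norm of [6, 8] is sqrt(36 + 64) = 10, so the scale is min(1, 5/10) = 0.5 and the clipped vector is [3, 4], whose norm is exactly 5.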

Exercise 3: Warmup schedule value

Find the learning rate at step 50 under a 100-step linear warmup to a peak of 3e-4.

Code cell 10

# Your Solution
step = 50
warmup = 100
peak = 3e-4
print("Starter: during warmup, lr = peak * (step + 1) / warmup.")

Code cell 11

# Solution
step = 50
warmup = 100
peak = 3e-4
lr = peak * (step + 1) / warmup
print("lr:", lr)
assert lr < peak
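
After warmup, large runs typically decay the learning rate toward a floor. A minimal sketch of a warmup-plus-cosine schedule (the cosine decay phase and the 1000-step horizon are illustrative assumptions; the exercise only specifies the warmup):

import numpy as np

def lr_at(step, warmup=100, total=1000, peak=3e-4, floor=0.0):
    # Linear warmup from near zero up to the peak rate.
    if step < warmup:
        return peak * (step + 1) / warmup
    # Cosine decay from peak down to the floor over the remaining steps.
    progress = (step - warmup) / max(1, total - warmup)
    return floor + 0.5 * (peak - floor) * (1 + np.cos(np.pi * progress))

print(lr_at(50))   # 1.53e-04, matches the solution above
print(lr_at(999))  # near the floor at the end of training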

Exercise 4: Effective token batch

Compute the tokens consumed per optimizer step from the micro-batch size, sequence length, data-parallel degree, and gradient-accumulation steps.

Code cell 13

# Your Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
print("Starter: multiply micro_batch, seq_len, dp, and accum.")

Code cell 14

# Solution
micro_batch = 4
seq_len = 2048
dp = 32
accum = 8
tokens = micro_batch * seq_len * dp * accum
print("tokens per optimizer step:", tokens)

Exercise 5: Memory estimate

Estimate replicated Adam training state for 1B parameters with bf16 weights/grads and fp32 moments.

Code cell 16

# Your Solution
P = 1e9
print("Starter: weights=2P, grads=2P, moments=8P bytes.")

Code cell 17

# Solution
P = 1e9
total_bytes = 2*P + 2*P + 8*P  # bf16 weights + bf16 grads + fp32 m and v
print("GB:", total_bytes / 1e9)
assert np.isclose(total_bytes / 1e9, 12.0)
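
The 12 GB assumes every data-parallel rank holds a full replica. A simplified sketch of how ZeRO-style sharding changes the per-GPU figure (this accounting ignores activations, fp32 master weights, and communication buffers; the data-parallel degree of 32 is borrowed from Exercise 4):

P = 1e9
dp = 32
weights, grads, moments = 2 * P, 2 * P, 8 * P      # bytes, as above
replicated = (weights + grads + moments) / 1e9
zero1 = (weights + grads + moments / dp) / 1e9     # shard optimizer states
zero2 = (weights + (grads + moments) / dp) / 1e9   # ... and gradients
zero3 = (weights + grads + moments) / dp / 1e9     # ... and weights too
print(f"replicated {replicated:.2f} GB, ZeRO-1 {zero1:.2f} GB, "
      f"ZeRO-2 {zero2:.2f} GB, ZeRO-3 {zero3:.2f} GB")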

Exercise 6: Pipeline bubble

Compute bubble fraction for 4 stages and 12 micro-batches.

Code cell 19

# Your Solution
P = 4
M = 12
print("Starter: bubble = (P - 1) / (M + P - 1).")

Code cell 20

# Solution
P = 4
M = 12
bubble = (P - 1) / (M + P - 1)
print("bubble:", bubble)

Exercise 7: Tensor-parallel shard

Split a 4096 x 16384 weight matrix across 4 column-parallel ranks.

Code cell 22

# Your Solution
in_dim = 4096
out_dim = 16384
tp = 4
print("Starter: each rank owns out_dim / tp columns.")

Code cell 23

# Solution
in_dim = 4096
out_dim = 16384
tp = 4
shape = (in_dim, out_dim // tp)
print("per-rank shard shape:", shape)
assert shape == (4096, 4096)
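
A quick numerical check, with small illustrative dimensions, that column-parallel shards compose back into the full product: each rank computes x @ W_i on its slice of columns, and concatenating the per-rank outputs matches x @ W.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
W = rng.standard_normal((8, 16))
shards = np.split(W, 4, axis=1)        # 4 column-parallel ranks
partials = [x @ Wi for Wi in shards]   # each rank's local matmul
assert np.allclose(np.concatenate(partials, axis=1), x @ W)
print("column-parallel shards reproduce the full matmul")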

Exercise 8: Training FLOPs

Use C = 6ND for a 7B model trained on 300B tokens.

Code cell 25

# Your Solution
N = 7e9
D = 300e9
print("Starter: C = 6 * N * D.")

Code cell 26

# Solution
N = 7e9
D = 300e9
C = 6 * N * D
print("FLOPs:", f"{C:.3e}")

Exercise 9: MFU

Compute model FLOPs utilization (MFU) from useful FLOPs/sec and the hardware peak.

Code cell 28

# Your Solution
useful = 60e15
peak = 160e15
print("Starter: MFU = useful / peak.")

Code cell 29

# Solution
useful = 60e15
peak = 160e15
mfu = useful / peak
print("MFU:", mfu)
assert 0 <= mfu <= 1
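
Worked check: 60/160 = 0.375. Combining this with Exercise 8 gives a rough wall-clock estimate, assuming the useful-throughput figure describes the whole cluster for the entire run:

C = 6 * 7e9 * 300e9   # FLOP budget from Exercise 8
useful = 60e15        # useful FLOPs/sec from this exercise
seconds = C / useful
print(f"{seconds:.3e} s = {seconds / 86400:.2f} days")   # about 2.4 days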

Exercise 10: Launch checklist

Write four checks before scaling a training run.

Code cell 31

# Your Solution
print("Starter: include loss, resume, memory, and batch checks.")

Code cell 32

# Solution
checks = [
    "small run reduces validation loss",
    "resume restores optimizer, scheduler, RNG, and dataloader state",
    "memory estimate includes activations and optimizer states",
    "effective global token batch is documented",
]
for check in checks:
    print("-", check)
assert len(checks) == 4

Closing Reflection

At scale, arithmetic mistakes become infrastructure failures. Keep units explicit, test small, and make the loss curve prove that the system is learning.