Theory Notebook
Converted from
theory.ipynb for web reading.
Learning Rate Schedules - Theory Notebook
Executable derivations and diagnostics for Chapter 8.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Plot styling: use seaborn's theme when seaborn is installed, otherwise fall
# back to matplotlib's bundled "seaborn-v0_8-whitegrid" style sheet.
# HAS_SNS records which branch was taken.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Notebook-wide matplotlib defaults (figure size/DPI, font sizes, legend,
# line width, hidden top/right spines, and savefig options).
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})

# Seed NumPy's legacy global random state so np.random.* calls are repeatable.
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
def header(title):
    """Print ``title`` framed by two 78-character ``=`` rules.

    A blank line is emitted before the first rule so consecutive demo
    outputs are visually separated.

    Args:
        title: Text printed between the two rules.
    """
    rule = "=" * 78
    print("\n" + rule)
    print(title)
    print(rule)
def check_close(name, value, target, tol=1e-8):
    """Report whether ``value`` is within ``tol`` of ``target``; raise if not.

    Prints a ``PASS``/``FAIL`` line showing both numbers to eight decimals.

    Args:
        name: Label used in the printed line and as the exception message.
        value: Computed quantity; anything ``float()`` accepts.
        target: Expected quantity; anything ``float()`` accepts.
        tol: Maximum allowed absolute difference (inclusive).

    Raises:
        AssertionError: With ``name`` as its message when
            ``abs(value - target) > tol``.
    """
    # Coerce once and reuse: the report then formats exactly the numbers that
    # were compared, and inputs that convert to float but do not support the
    # ".8f" format spec no longer crash while building the message.
    value_f = float(value)
    target_f = float(target)
    ok = abs(value_f - target_f) <= tol
    print(f"{'PASS' if ok else 'FAIL'} - {name}: value={value_f:.8f}, target={target_f:.8f}")
    if not ok:
        raise AssertionError(name)
def check_true(name, condition):
    """Report whether ``condition`` is truthy; raise if it is not.

    Prints ``PASS - <name>`` or ``FAIL - <name>``.

    Args:
        name: Label used in the printed line and as the exception message.
        condition: Value evaluated with ``bool()``.

    Raises:
        AssertionError: With ``name`` as its message when ``condition`` is
            falsy.
    """
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    if not ok:
        raise AssertionError(name)
# Named hex colors looked up by the plotting cells (e.g. COLORS["primary"]).
COLORS = dict(
    primary="#0077BB",
    secondary="#EE7733",
    tertiary="#009988",
    error="#CC3311",
    neutral="#555555",
    highlight="#EE3377",
)
print("Helper functions ready.")
Demo 1: Schedule Function
This cell checks a small numerical fact connected to schedule function.
Code cell 5
header("Demo 1: Schedule Function")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 1, 1
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: schedule function diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes schedule function visible before training a large model.")
Demo 2: Constant Learning Rate
This cell checks a small numerical fact connected to constant learning rate.
Code cell 7
header("Demo 2: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 5.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 2
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: constant learning rate should be checked through both loss change and update magnitude.")
Demo 3: Step Decay
This cell checks a small numerical fact connected to step decay.
Code cell 9
header("Demo 3: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 2)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how step decay appears in practice.")
Demo 4: Exponential Decay
This cell checks a small numerical fact connected to exponential decay.
Code cell 11
header("Demo 4: closed-form verification")

values = np.array([4.0, 5.0, 7.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing exponential decay.")
Demo 5: Polynomial Decay
This cell checks a small numerical fact connected to polynomial decay.
Code cell 13
header("Demo 5: Polynomial Decay")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 5, 5
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: polynomial decay diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes polynomial decay visible before training a large model.")
Demo 6: Linear Warmup
This cell checks a small numerical fact connected to linear warmup.
Code cell 15
header("Demo 6: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 4.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 3
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: linear warmup should be checked through both loss change and update magnitude.")
Demo 7: Warmup Ratio
This cell checks a small numerical fact connected to warmup ratio.
Code cell 17
header("Demo 7: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 6)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how warmup ratio appears in practice.")
Demo 8: Cosine Annealing
This cell checks a small numerical fact connected to cosine annealing.
Code cell 19
header("Demo 8: closed-form verification")

values = np.array([8.0, 9.0, 11.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing cosine annealing.")
Demo 9: Cosine With Restarts
This cell checks a small numerical fact connected to cosine with restarts.
Code cell 21
header("Demo 9: Cosine With Restarts")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 4, 9
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: cosine with restarts diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes cosine with restarts visible before training a large model.")
Demo 10: Cyclic Learning Rate
This cell checks a small numerical fact connected to cyclic learning rate.
Code cell 23
header("Demo 10: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 8.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 1
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: cyclic learning rate should be checked through both loss change and update magnitude.")
Demo 11: One-Cycle Policy
This cell checks a small numerical fact connected to one-cycle policy.
Code cell 25
header("Demo 11: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 10)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how one-cycle policy appears in practice.")
Demo 12: Linear Decay
This cell checks a small numerical fact connected to linear decay.
Code cell 27
header("Demo 12: closed-form verification")

values = np.array([12.0, 13.0, 15.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing linear decay.")
Demo 13: Inverse-Square-Root Decay
This cell checks a small numerical fact connected to inverse-square-root decay.
Code cell 29
header("Demo 13: Inverse-Square-Root Decay")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 3, 13
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: inverse-square-root decay diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes inverse-square-root decay visible before training a large model.")
Demo 14: WSD Schedule
This cell checks a small numerical fact connected to WSD schedule.
Code cell 31
header("Demo 14: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 7.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 2
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: WSD schedule should be checked through both loss change and update magnitude.")
Demo 15: Cooldown
This cell checks a small numerical fact connected to cooldown.
Code cell 33
header("Demo 15: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 14)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how cooldown appears in practice.")
Demo 16: Learning-Rate Rewinding
This cell checks a small numerical fact connected to learning-rate rewinding.
Code cell 35
header("Demo 16: closed-form verification")

values = np.array([16.0, 17.0, 19.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing learning-rate rewinding.")
Demo 17: Batch-Size Scaling
This cell checks a small numerical fact connected to batch-size scaling.
Code cell 37
header("Demo 17: Batch-Size Scaling")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 2, 17
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: batch-size scaling diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes batch-size scaling visible before training a large model.")
Demo 18: Gradient Accumulation Coupling
This cell checks a small numerical fact connected to gradient accumulation coupling.
Code cell 39
header("Demo 18: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 6.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 3
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: gradient accumulation coupling should be checked through both loss change and update magnitude.")
Demo 19: Token-Budget Scheduling
This cell checks a small numerical fact connected to token-budget scheduling.
Code cell 41
header("Demo 19: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 18)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how token-budget scheduling appears in practice.")
Demo 20: Optimizer-State Interaction
This cell checks a small numerical fact connected to optimizer-state interaction.
Code cell 43
header("Demo 20: closed-form verification")

values = np.array([20.0, 21.0, 23.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing optimizer-state interaction.")
Demo 21: LLM Pretraining Schedule Design
This cell checks a small numerical fact connected to LLM pretraining schedule design.
Code cell 45
header("Demo 21: Llm Pretraining Schedule Design")

# Toy 1-D objective sampled at 200 evenly spaced points on [-3, 3]:
# a quadratic term with curvature `a` plus a 0.1-amplitude sine ripple.
a, freq = 1, 21
x = np.linspace(-3, 3, 200)
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
# Analytic derivative of `loss` with respect to x.
grad = a * x + 0.1 * freq * np.cos(freq * x)

# Draw the objective and its derivative on one set of axes.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, label="loss", color=COLORS["primary"])
ax.plot(x, grad, label="gradient", color=COLORS["secondary"], linestyle="--")
ax.set(
    title="Learning Rate Schedules: LLM pretraining schedule design diagnostic",
    xlabel="Parameter $\\theta$",
    ylabel="Value",
)
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been shown

# Sanity check: no NaN/inf anywhere on the sampled loss curve.
check_true("finite loss curve", np.isfinite(loss).all())
print("Takeaway: plotting the objective and gradient makes LLM pretraining schedule design visible before training a large model.")
Demo 22: Fine-Tuning Schedule Design
This cell checks a small numerical fact connected to fine-tuning schedule design.
Code cell 47
header("Demo 22: update-size computation")

# Quadratic loss 0.5 * theta^T H theta with a diagonal matrix H.
H = np.diag([1.0, 3.0, 5.0])
theta = np.array([1.5, -0.5, 0.25], dtype=float)

# One gradient-descent step; for this quadratic the gradient is H @ theta.
eta = 0.1 / 1
grad = H @ theta
step = -eta * grad
new_theta = theta + step

# Loss before and after the step. The explicit grouping spells out the
# left-to-right evaluation order of the mixed * / @ chain.
old_loss = ((0.5 * theta) @ H) @ theta
new_loss = ((0.5 * new_theta) @ H) @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))

# Step length relative to the parameter norm; the 1e-12 floor avoids a
# division by zero.
theta_norm = max(np.linalg.norm(theta), 1e-12)
print("relative_update =", round(float(np.linalg.norm(step) / theta_norm), 6))

check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: fine-tuning schedule design should be checked through both loss change and update magnitude.")
Demo 23: Schedule Function
This cell checks a small numerical fact connected to schedule function.
Code cell 49
header("Demo 23: stochastic estimate")

# 256 standard-normal samples in 3 dimensions from a seeded generator.
rng = np.random.default_rng(42 + 22)
samples = rng.normal(size=(256, 3))  # loc=0.0 and scale=1.0 are the defaults
theta = np.array([0.2, -0.4, 0.6])

# Gradient of the mean of 0.5 * (sample @ theta)**2 over every sample ...
projections = samples @ theta
full_grad = (samples.T @ projections) / samples.shape[0]

# ... and the same quantity estimated from the first 32 samples only.
batch = samples[:32]
batch_grad = (batch.T @ (batch @ theta)) / batch.shape[0]

# Euclidean distance between the minibatch and full-data gradients.
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how schedule function appears in practice.")
Demo 24: Constant Learning Rate
This cell checks a small numerical fact connected to constant learning rate.
Code cell 51
header("Demo 24: closed-form verification")

values = np.array([24.0, 25.0, 27.0])
mean_value = float(np.mean(values))

# Sum of squared deviations from the mean, computed two ways.
centered = values - mean_value
energy = float(centered.dot(centered))  # vectorized inner product
squared_devs = [(v - mean_value) ** 2 for v in values]  # element by element
manual = float(sum(squared_devs))

print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing constant learning rate.")