Stochastic Optimization - Theory Notebook

Converted from theory.ipynb for web reading.

Executable derivations and diagnostics for Chapter 8.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

def header(title):
    print("\n" + "=" * 78)
    print(title)
    print("=" * 78)

def check_close(name, value, target, tol=1e-8):
    ok = abs(float(value) - float(target)) <= tol
    print(f"{'PASS' if ok else 'FAIL'} - {name}: value={value:.8f}, target={target:.8f}")
    if not ok:
        raise AssertionError(name)

def check_true(name, condition):
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    if not ok:
        raise AssertionError(name)

COLORS = {
    "primary":   "#0077BB",
    "secondary": "#EE7733",
    "tertiary":  "#009988",
    "error":     "#CC3311",
    "neutral":   "#555555",
    "highlight": "#EE3377",
}
print("Helper functions ready.")

Demo 1: Stochastic Objective

This cell checks a small numerical fact connected to stochastic objective.

Code cell 5

header("Demo 1: Stochastic Objective")
x = np.linspace(-3, 3, 200)
a = 1
freq = 1
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: stochastic objective diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes stochastic objective visible before training a large model.")

Demo 2: Empirical Risk

This cell checks a small numerical fact connected to empirical risk.

Code cell 7

header("Demo 2: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 5.0])
grad = H @ theta
eta = 0.1 / (2)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: empirical risk should be checked through both loss change and update magnitude.")

Demo 3: Population Risk

This cell checks a small numerical fact connected to population risk.

Code cell 9

header("Demo 3: stochastic estimate")
rng = np.random.default_rng(42 + 2)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how population risk appears in practice.")

Demo 4: Unbiased Gradient Oracle

This cell checks a small numerical fact connected to unbiased gradient oracle.

Code cell 11

header("Demo 4: closed-form verification")
values = np.array([4.0, 5.0, 7.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing unbiased gradient oracle.")

Demo 5: Gradient Variance

This cell checks a small numerical fact connected to gradient variance.

Code cell 13

header("Demo 5: Gradient Variance")
x = np.linspace(-3, 3, 200)
a = 5
freq = 5
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: gradient variance diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes gradient variance visible before training a large model.")

Demo 6: Minibatch Estimator

This cell checks a small numerical fact connected to minibatch estimator.

Code cell 15

header("Demo 6: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 4.0])
grad = H @ theta
eta = 0.1 / (3)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: minibatch estimator should be checked through both loss change and update magnitude.")

Demo 7: Batch-Size Scaling

This cell checks a small numerical fact connected to batch-size scaling.

Code cell 17

header("Demo 7: stochastic estimate")
rng = np.random.default_rng(42 + 6)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how batch-size scaling appears in practice.")

Demo 8: Critical Batch Size

This cell checks a small numerical fact connected to critical batch size.

Code cell 19

header("Demo 8: closed-form verification")
values = np.array([8.0, 9.0, 11.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing critical batch size.")

Demo 9: Robbins-Monro Schedule

This cell checks a small numerical fact connected to Robbins-Monro schedule.

Code cell 21

header("Demo 9: Robbins-Monro Schedule")
x = np.linspace(-3, 3, 200)
a = 4
freq = 9
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: Robbins-Monro schedule diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes Robbins-Monro schedule visible before training a large model.")

Demo 10: SGD Convergence

This cell checks a small numerical fact connected to SGD convergence.

Code cell 23

header("Demo 10: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 8.0])
grad = H @ theta
eta = 0.1 / (1)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: SGD convergence should be checked through both loss change and update magnitude.")

Demo 11: Strongly Convex SGD

This cell checks a small numerical fact connected to strongly convex SGD.

Code cell 25

header("Demo 11: stochastic estimate")
rng = np.random.default_rng(42 + 10)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how strongly convex SGD appears in practice.")

Demo 12: Nonconvex SGD

This cell checks a small numerical fact connected to nonconvex SGD.

Code cell 27

header("Demo 12: closed-form verification")
values = np.array([12.0, 13.0, 15.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing nonconvex SGD.")

Demo 13: Gradient Noise Scale

This cell checks a small numerical fact connected to gradient noise scale.

Code cell 29

header("Demo 13: Gradient Noise Scale")
x = np.linspace(-3, 3, 200)
a = 3
freq = 13
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: gradient noise scale diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes gradient noise scale visible before training a large model.")

Demo 14: SVRG

This cell checks a small numerical fact connected to SVRG.

Code cell 31

header("Demo 14: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 7.0])
grad = H @ theta
eta = 0.1 / (2)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: SVRG should be checked through both loss change and update magnitude.")

Demo 15: SAGA

This cell checks a small numerical fact connected to SAGA.

Code cell 33

header("Demo 15: stochastic estimate")
rng = np.random.default_rng(42 + 14)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how SAGA appears in practice.")

Demo 16: Control Variates

This cell checks a small numerical fact connected to control variates.

Code cell 35

header("Demo 16: closed-form verification")
values = np.array([16.0, 17.0, 19.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing control variates.")

Demo 17: Polyak Averaging

This cell checks a small numerical fact connected to Polyak averaging.

Code cell 37

header("Demo 17: Polyak Averaging")
x = np.linspace(-3, 3, 200)
a = 2
freq = 17
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: Polyak averaging diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes Polyak averaging visible before training a large model.")

Demo 18: Distributed SGD

This cell checks a small numerical fact connected to distributed SGD.

Code cell 39

header("Demo 18: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 6.0])
grad = H @ theta
eta = 0.1 / (3)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: distributed SGD should be checked through both loss change and update magnitude.")

Demo 19: Gradient Accumulation

This cell checks a small numerical fact connected to gradient accumulation.

Code cell 41

header("Demo 19: stochastic estimate")
rng = np.random.default_rng(42 + 18)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how gradient accumulation appears in practice.")

Demo 20: Local SGD

This cell checks a small numerical fact connected to local SGD.

Code cell 43

header("Demo 20: closed-form verification")
values = np.array([20.0, 21.0, 23.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing local SGD.")

Demo 21: Federated Averaging

This cell checks a small numerical fact connected to federated averaging.

Code cell 45

header("Demo 21: Federated Averaging")
x = np.linspace(-3, 3, 200)
a = 1
freq = 21
loss = 0.5 * a * x**2 + 0.1 * np.sin(freq * x)
grad = a * x + 0.1 * freq * np.cos(freq * x)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, loss, color=COLORS["primary"], label="loss")
ax.plot(x, grad, color=COLORS["secondary"], linestyle="--", label="gradient")
ax.set_title("Stochastic Optimization: federated averaging diagnostic")
ax.set_xlabel("Parameter $\\theta$")
ax.set_ylabel("Value")
ax.legend(loc="best")
fig.tight_layout()
plt.show()
plt.close(fig)
check_true("finite loss curve", np.all(np.isfinite(loss)))
print("Takeaway: plotting the objective and gradient makes federated averaging visible before training a large model.")

Demo 22: Communication Compression

This cell checks a small numerical fact connected to communication compression.

Code cell 47

header("Demo 22: update-size computation")
theta = np.array([1.5, -0.5, 0.25], dtype=float)
H = np.diag([1.0, 3.0, 5.0])
grad = H @ theta
eta = 0.1 / (1)
step = -eta * grad
new_theta = theta + step
old_loss = 0.5 * theta @ H @ theta
new_loss = 0.5 * new_theta @ H @ new_theta
print("old_loss =", round(float(old_loss), 6))
print("new_loss =", round(float(new_loss), 6))
print("relative_update =", round(float(np.linalg.norm(step) / max(np.linalg.norm(theta), 1e-12)), 6))
check_true("descent on this quadratic", new_loss < old_loss)
print("Takeaway: communication compression should be checked through both loss change and update magnitude.")

Demo 23: LLM Pretraining Noise

This cell checks a small numerical fact connected to LLM pretraining noise.

Code cell 49

header("Demo 23: stochastic estimate")
rng = np.random.default_rng(42 + 22)
samples = rng.normal(loc=0.0, scale=1.0, size=(256, 3))
theta = np.array([0.2, -0.4, 0.6])
full_grad = samples.T @ (samples @ theta) / len(samples)
batch = samples[:32]
batch_grad = batch.T @ (batch @ theta) / len(batch)
gap = np.linalg.norm(batch_grad - full_grad)
print("full_grad =", np.round(full_grad, 5))
print("batch_grad =", np.round(batch_grad, 5))
print("gradient_gap =", round(float(gap), 6))
check_true("finite stochastic estimate", np.isfinite(gap))
print("Takeaway: even when the section is not stochastic, minibatch estimates affect how LLM pretraining noise appears in practice.")

Demo 24: Stochastic Objective

This cell checks a small numerical fact connected to stochastic objective.

Code cell 51

header("Demo 24: closed-form verification")
values = np.array([24.0, 25.0, 27.0])
mean_value = float(values.mean())
centered = values - mean_value
energy = float(np.dot(centered, centered))
manual = float(sum((v - mean_value) ** 2 for v in values))
print("values =", values)
print("centered_energy =", round(energy, 6))
check_close("manual equals vectorized computation", energy, manual)
print("Takeaway: small closed-form checks prevent conceptual drift when implementing stochastic objective.")