Theory Notebook — Math for LLMs

Mixture of Experts and Routing

Math for LLMs / Mixture of Experts and Routing

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Mixture of Experts and Routing: Theory Notebook

This notebook turns MoE routing into concrete arrays: router probabilities, top-k dispatch, capacity overflow, the auxiliary load-balancing loss, active parameter counts, all-to-all traffic, and collapse diagnostics.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn's theme when it is installed; otherwise fall back to
# matplotlib's "seaborn-v0_8-whitegrid" style sheet.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Shared figure defaults for every plot in this notebook.
_rc_overrides = {
    # figure size and resolution
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
    # text sizes
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    # line and axis styling
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
}
mpl.rcParams.update(_rc_overrides)
np.random.seed(42)  # seeds NumPy's legacy global RNG
print("Plot setup complete.")

1. Router probabilities and top-k experts

Code cell 4

def softmax(z, axis=-1):
    """Return the softmax of ``z`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large
    logits do not overflow; the result is unchanged because softmax is
    shift-invariant.

    Args:
        z: array-like of logits; converted to a float ndarray.
        axis: axis to normalize over (default: last axis).

    Returns:
        Float ndarray with the same shape as ``z`` whose slices along
        ``axis`` sum to 1.
    """
    arr = np.asarray(z, dtype=float)
    shifted = arr - arr.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

# Router logits: one row per token, one column per expert (3 tokens x 4 experts).
router_logits = np.array(
    [
        [2.0, 0.5, -0.3, 1.2],
        [0.1, 2.2, 1.8, -0.4],
        [1.0, 1.1, 1.2, 1.3],
    ]
)
probs = softmax(router_logits, axis=1)  # normalize each row over the experts
k = 2
# argsort is ascending: reverse each row, then keep the first k columns so that
# topk[t] lists token t's k most probable experts, best first.
topk = np.argsort(probs, axis=1)[:, ::-1][:, :k]
print("router probabilities:\n", np.round(probs, 3))
print("top-k experts:\n", topk)

2. Top-k gated combination

Code cell 6

# Build a sparse (tokens, experts) gate matrix from the router's top-k choices
# and use it to mix per-expert outputs. The token and expert counts come from
# `probs`, so this cell stays aligned with the router cell if its logits change.
rng = np.random.default_rng(4)
tokens, experts = probs.shape
d = 5
# Token inputs. Not used by the toy combine below, but drawn first so the RNG
# stream (and therefore expert_outputs) matches the original cell.
x = rng.normal(size=(tokens, d))
expert_outputs = rng.normal(size=(tokens, experts, d))
# Gather each token's selected probabilities and renormalize them to sum to 1.
selected_probs = np.take_along_axis(probs, topk, axis=1)
gates = selected_probs / selected_probs.sum(axis=1, keepdims=True)
# Scatter the gates into a dense matrix; unselected experts keep weight 0.
weights = np.zeros((tokens, experts))
np.put_along_axis(weights, topk, gates, axis=1)
y = np.einsum("te,ted->td", weights, expert_outputs)
print("routing weights:\n", np.round(weights, 3))
print("combined output shape:", y.shape)

3. Parameter accounting

Code cell 8

# Parameter counts for one MoE FFN layer: 8 experts, top-2 routing, d_model = 4096.
d_model = 4096
d_ff = d_model * 4
experts, k = 8, 2
dense_ffn = 2 * d_model * d_ff  # two d_model x d_ff weight matrices, no biases
total_expert = dense_ffn * experts  # every expert's weights are stored
active_expert = dense_ffn * k  # only k experts' weights are used per token
router = d_model * experts  # d_model weights per expert logit, no bias
for label, count in (
    ("dense FFN params:", dense_ffn),
    ("total expert params:", total_expert),
    ("active expert params per token:", active_expert),
    ("router params:", router),
):
    print(label, f"{count:,}")

4. Capacity and overflow

Code cell 10

# Toy top-1 assignment of 10 tokens to 4 experts.
assignments = np.array([0, 0, 0, 1, 1, 2, 3, 3, 3, 3])
experts = 4
tokens = assignments.size
capacity_factor = 1.25
# Each expert accepts at most ceil(capacity_factor * tokens / experts) tokens;
# anything above that is counted as dropped.
capacity = int(np.ceil(capacity_factor * tokens / experts))
loads = np.bincount(assignments, minlength=experts)
overflow = np.clip(loads - capacity, 0, None)
print("capacity per expert:", capacity)
print("loads:", loads)
print("overflow:", overflow)
print("drop rate:", overflow.sum() / tokens)

5. Expert histogram

Code cell 12

# Bar chart of per-expert token counts (`loads`), with the per-expert
# `capacity` drawn as a dashed reference line.
fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(np.arange(experts), loads)
ax.axhline(capacity, color="red", linestyle="--", label="capacity")
ax.set(title="Expert token loads", xlabel="expert", ylabel="tokens")
ax.legend()
fig.tight_layout()
plt.show()
# max(1, ...) guards against division by zero when some expert got no tokens.
worst_to_best = loads.max() / max(1, loads.min())
print("max/min load ratio:", worst_to_best)

6. Switch-style auxiliary loss proxy

Code cell 14

# Switch-style load-balancing proxy: E * sum_i f_i * P_i, where f_i is the
# fraction of tokens whose top-1 expert is i and P_i is the mean router
# probability for expert i. Perfectly uniform routing (f_i = P_i = 1/E) gives 1.
router_probs = probs
n_tokens, n_experts = router_probs.shape
top1 = router_probs.argmax(axis=1)
f = np.bincount(top1, minlength=n_experts) / n_tokens
P = router_probs.mean(axis=0)
aux = n_experts * np.sum(f * P)
print("fraction routed f:", np.round(f, 3))
print("mean probability P:", np.round(P, 3))
print("aux loss proxy:", round(aux, 4))

7. Router entropy

Code cell 16

# Per-token Shannon entropy (in nats) of the router distribution. The upper
# bound, log(E), is reached by uniform routing over E experts. E is read from
# `probs` rather than hard-coded, so the reference stays correct if the expert
# count changes.
n_router_experts = probs.shape[1]
# The 1e-12 offset keeps log finite when a probability underflows to 0.
entropy = -np.sum(probs * np.log(probs + 1e-12), axis=1)
for t, h in enumerate(entropy):
    print(f"token {t}: entropy={h:.3f}")
print(f"max entropy for {n_router_experts} experts:", np.log(n_router_experts))

8. Z-loss proxy

Code cell 18

# Router z-loss proxy: mean over tokens of logsumexp(logits)^2, which grows with
# the magnitude of the router logits. Each row's max is subtracted before
# exponentiating (the log-sum-exp trick). Otherwise np.exp overflows to inf for
# large logits, which is exactly the regime this loss is meant to flag.
row_max = router_logits.max(axis=1)
lse = row_max + np.log(np.exp(router_logits - row_max[:, None]).sum(axis=1))
z_loss = np.mean(lse**2)
print("log-sum-exp values:", np.round(lse, 3))
print("z-loss proxy:", round(z_loss, 4))

9. All-to-all traffic toy

Code cell 20

# Toy expert-parallel layout: experts 0-1 are hosted on rank 0, experts 2-3 on rank 1.
expert_to_rank = np.array([0, 0, 1, 1])
token_origin_rank = np.array([0, 0, 0, 1, 1, 1])  # rank that holds each token
token_expert = np.array([0, 2, 3, 1, 2, 3])  # expert each token is routed to
# Look up the rank hosting each token's expert. A token whose destination
# differs from its origin rank has to be sent across ranks.
dest_rank = np.take(expert_to_rank, token_expert)
cross = np.not_equal(token_origin_rank, dest_rank)
print("destination ranks:", dest_rank)
print("cross-rank tokens:", cross.sum(), "of", cross.size)

10. Router collapse simulation

Code cell 22

# Draw 512 expert assignments over 8 experts from a uniform router and from a
# heavily skewed ("collapsed") one, then compare the resulting load spread.
rng = np.random.default_rng(0)
uniform_p = np.full(8, 0.125)
skewed_p = np.array([0.55, 0.25, 0.10, 0.04, 0.03, 0.02, 0.005, 0.005])
balanced = rng.choice(8, size=512, p=uniform_p)
collapsed = rng.choice(8, size=512, p=skewed_p)
for label, draws in (("balanced", balanced), ("collapsed", collapsed)):
    loads = np.bincount(draws, minlength=8)
    # max(1, ...) avoids dividing by zero if an expert received no tokens.
    print(label, "loads", loads, "max/min", loads.max() / max(1, loads.min()))

11. Top-1, top-2, and top-4 active compute

Code cell 24

# Active vs. stored FFN compute for several top-k settings, in units of one
# dense FFN. Stored expert capacity does not depend on k, so it is computed once
# outside the loop.
experts = 16
dense_ffn_flops = 1.0
total_capacity = experts * dense_ffn_flops
for k in [1, 2, 4]:
    active = k * dense_ffn_flops  # each token runs k experts
    print(f"top-{k}: active FFN compute={active:.1f}x dense, total expert capacity={total_capacity:.1f}x")

12. Final MoE checklist

Code cell 26

# Diagnostics worth logging for any MoE training run.
checks = [
    "log total parameters and active parameters separately",
    "plot expert load histogram",
    "track drop rate and capacity factor",
    "track router entropy and z-loss",
    "measure all-to-all traffic and expert compute time",
    "compare top-1, top-2, and dense baselines",
]
print("\n".join(f"{i}. {check}" for i, check in enumerate(checks, 1)))