Theory Notebook
Theory Notebook
Converted from theory.ipynb for web reading.
Mixture of Experts and Routing: Theory Notebook
This notebook turns MoE routing into concrete arrays: router probabilities, top-k dispatch, capacity overflow, auxiliary balancing, active parameter counts, all-to-all traffic, and collapse diagnostics.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Figure defaults shared by every plot in this notebook. They are applied after
# any seaborn theme below, so these values take precedence over the theme's.
_RC_DEFAULTS = {
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
}

try:
    # Prefer seaborn's whitegrid theme with a colorblind-safe palette.
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    # Without seaborn, fall back to the matplotlib style of the same look.
    # NOTE(review): this style name only exists in newer matplotlib releases.
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update(_RC_DEFAULTS)
# Seed NumPy's legacy global RNG so any global-RNG draws are reproducible.
np.random.seed(42)
print("Plot setup complete.")
1. Router probabilities and top-k experts
Code cell 4
def softmax(z, axis=-1):
    """Return the softmax of ``z`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating. Softmax is
    invariant to that shift, and it keeps ``exp`` from overflowing on large
    logits.

    Args:
        z: array-like of logits (converted to a float ndarray).
        axis: axis along which the probabilities are normalized.

    Returns:
        Float ndarray with the same shape as ``z``; each slice along ``axis``
        sums to 1.
    """
    logits = np.asarray(z, dtype=float)
    shifted = logits - logits.max(axis=axis, keepdims=True)
    unnormalized = np.exp(shifted)
    normalizer = np.sum(unnormalized, axis=axis, keepdims=True)
    return unnormalized / normalizer
# Router logits for three tokens over four experts (row t = token t).
router_logits = np.array([
    [2.0, 0.5, -0.3, 1.2],
    [0.1, 2.2, 1.8, -0.4],
    [1.0, 1.1, 1.2, 1.3],
])
probs = softmax(router_logits, axis=1)

# Rank experts per token from most to least probable, then keep the first k.
k = 2
ranked = np.flip(np.argsort(probs, axis=1), axis=1)
topk = ranked[:, :k]

print("router probabilities:\n", np.round(probs, 3))
print("top-k experts:\n", topk)
2. Top-k gated combination
Code cell 6
rng = np.random.default_rng(4)
tokens, experts, d = 3, 4, 5
# x is not consumed by the combination below, but drawing it still advances
# rng before expert_outputs is sampled.
x = rng.normal(size=(tokens, d))
expert_outputs = rng.normal(size=(tokens, experts, d))

# Gather each token's router probabilities for its top-k experts, renormalize
# them to sum to 1, and scatter them into a zero (tokens, experts) matrix.
gate = np.take_along_axis(probs, topk, axis=1)
gate = gate / gate.sum(axis=1, keepdims=True)
weights = np.zeros((tokens, experts))
np.put_along_axis(weights, topk, gate, axis=1)

# Output per token = gate-weighted sum of expert outputs over the expert axis.
y = np.einsum("te,ted->td", weights, expert_outputs)
print("routing weights:\n", np.round(weights, 3))
print("combined output shape:", y.shape)
3. Parameter accounting
Code cell 8
# Model width; the FFN hidden width is 4x the model width.
d_model = 4096
d_ff = d_model * 4
experts, k = 8, 2

# The factor 2 counts two d_model x d_ff weight matrices per FFN (presumably
# the up- and down-projections); biases are not counted.
dense_ffn = 2 * d_ff * d_model
total_expert = dense_ffn * experts   # every expert's weights are stored
active_expert = dense_ffn * k        # only k experts run for each token
router = experts * d_model           # one d_model-wide weight row per expert logit

for label, count in [
    ("dense FFN params:", dense_ffn),
    ("total expert params:", total_expert),
    ("active expert params per token:", active_expert),
    ("router params:", router),
]:
    print(label, f"{count:,}")
4. Capacity and overflow
Code cell 10
# Expert index assigned to each of ten tokens (one expert per token).
assignments = np.array([0, 0, 0, 1, 1, 2, 3, 3, 3, 3])
tokens, experts = assignments.size, 4
capacity_factor = 1.25

# Per-expert buffer size: the even share tokens/experts, inflated by the
# capacity factor and rounded up to a whole token.
capacity = int(np.ceil(capacity_factor * tokens / experts))

loads = np.bincount(assignments, minlength=experts)
# Tokens beyond an expert's capacity are the ones that would be dropped.
overflow = np.clip(loads - capacity, 0, None)

print("capacity per expert:", capacity)
print("loads:", loads)
print("overflow:", overflow)
print("drop rate:", overflow.sum() / tokens)
5. Expert histogram
Code cell 12
# Bar chart of tokens per expert with the capacity limit drawn as a dashed line.
fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(np.arange(experts), loads)
ax.axhline(capacity, color="red", linestyle="--", label="capacity")
ax.set_title("Expert token loads")
ax.set_xlabel("expert")
ax.set_ylabel("tokens")
ax.legend()
fig.tight_layout()
plt.show()

# An expert that received zero tokens is the most extreme imbalance. Report an
# infinite ratio for it instead of silently clamping the minimum load to 1,
# which would hide a dead expert behind a finite number.
min_load = loads.min()
load_ratio = loads.max() / min_load if min_load > 0 else np.inf
print("max/min load ratio:", load_ratio)
6. Switch-style auxiliary loss proxy
Code cell 14
router_probs = probs
num_tokens, num_experts = router_probs.shape

# f_i: fraction of tokens whose highest-probability expert is i.
top1 = router_probs.argmax(axis=1)
f = np.bincount(top1, minlength=num_experts) / num_tokens

# P_i: mean router probability assigned to expert i across tokens.
P = np.mean(router_probs, axis=0)

# N * sum_i f_i * P_i; equals 1 when f and P are both uniform (1/N each).
aux = num_experts * np.sum(f * P)

print("fraction routed f:", np.round(f, 3))
print("mean probability P:", np.round(P, 3))
print("aux loss proxy:", round(aux, 4))
7. Router entropy
Code cell 16
# Number of experts, taken from the routing distribution instead of hard-coded.
n_router_experts = probs.shape[1]

# Shannon entropy (natural log, nats) of each token's routing distribution.
# The 1e-12 keeps log finite if a probability underflows to exactly 0.
entropy = -np.sum(probs * np.log(probs + 1e-12), axis=1)
for t, h in enumerate(entropy):
    print(f"token {t}: entropy={h:.3f}")

# A uniform router over N experts attains the maximum entropy log(N).
print(f"max entropy for {n_router_experts} experts:", np.log(n_router_experts))
8. Z-loss proxy
Code cell 18
# Router z-loss proxy: mean squared log-sum-exp of each token's router logits.
# Use the stable form lse(z) = max(z) + log(sum(exp(z - max(z)))), the same
# max shift used by softmax. The naive log(sum(exp(z))) overflows to inf once
# logits grow large, which is exactly the regime z-loss is meant to detect.
row_max = router_logits.max(axis=1)
lse = row_max + np.log(np.exp(router_logits - row_max[:, None]).sum(axis=1))
z_loss = np.mean(lse**2)
print("log-sum-exp values:", np.round(lse, 3))
print("z-loss proxy:", round(z_loss, 4))
9. All-to-all traffic toy
Code cell 20
# Placement: experts 0-1 live on rank 0, experts 2-3 on rank 1.
expert_to_rank = np.array([0, 0, 1, 1])
token_origin_rank = np.array([0, 0, 0, 1, 1, 1])
token_expert = np.array([0, 2, 3, 1, 2, 3])

# Each token is sent to whichever rank hosts its assigned expert; tokens whose
# destination differs from their origin must cross the interconnect.
dest_rank = np.take(expert_to_rank, token_expert)
cross = np.not_equal(token_origin_rank, dest_rank)

print("destination ranks:", dest_rank)
print("cross-rank tokens:", np.count_nonzero(cross), "of", cross.size)
10. Router collapse simulation
Code cell 22
n_experts = 8
rng = np.random.default_rng(0)
# Uniform routing versus a skewed router that sends most tokens to a few experts.
balanced = rng.choice(n_experts, size=512, p=np.ones(n_experts) / n_experts)
collapsed = rng.choice(n_experts, size=512, p=np.array([0.55, 0.25, 0.10, 0.04, 0.03, 0.02, 0.005, 0.005]))
for name, arr in [("balanced", balanced), ("collapsed", collapsed)]:
    loads = np.bincount(arr, minlength=n_experts)
    dead_experts = int(np.count_nonzero(loads == 0))
    min_load = loads.min()
    # A dead expert (zero tokens) is maximal imbalance: report inf rather than
    # clamping the minimum to 1, which would hide the collapse behind a finite ratio.
    load_ratio = loads.max() / min_load if min_load > 0 else np.inf
    print(name, "loads", loads, "max/min", load_ratio, "dead experts", dead_experts)
11. Top-1, top-2, and top-4 active compute
Code cell 24
experts = 16
dense_ffn_flops = 1.0
# Total stored expert capacity does not depend on k, so compute it once;
# per-token active FFN compute scales linearly with k.
total_capacity = experts * dense_ffn_flops
for k in (1, 2, 4):
    active = dense_ffn_flops * k
    print(f"top-{k}: active FFN compute={active:.1f}x dense, total expert capacity={total_capacity:.1f}x")
12. Final MoE checklist
Code cell 26
# Diagnostics worth logging for any MoE training run.
checks = [
    "log total parameters and active parameters separately",
    "plot expert load histogram",
    "track drop rate and capacity factor",
    "track router entropy and z-loss",
    "measure all-to-all traffic and expert compute time",
    "compare top-1, top-2, and dense baselines",
]
print("\n".join(f"{idx}. {item}" for idx, item in enumerate(checks, start=1)))