Theory Notebook
Converted from theory.ipynb for web reading.
Efficient Attention and Inference: Theory Notebook
This notebook makes serving math concrete: prefill versus decode, KV cache memory, MHA/MQA/GQA savings, FlashAttention memory intuition, paged-cache waste, speculative decoding, and latency budgets.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    # Fall back to matplotlib's built-in port of the seaborn style.
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Prefill versus decode cost
Code cell 4
# Attention cost proxies for hidden size d: prefill touches all T x T score
# pairs, while one cached decode step attends over T cached tokens.
T = np.arange(128, 8193, 256)
d = 4096
prefill_ops = T**2 * d            # full-sequence attention, quadratic in T
decode_ops_per_token = T * d      # one new token against the cache, linear in T
fig, ax = plt.subplots(figsize=(8, 4))
# Different scale factors (1e9 vs 1e6) keep both curves visible on one axis.
ax.plot(T, prefill_ops / 1e9, label="prefill attention proxy")
ax.plot(T, decode_ops_per_token / 1e6, label="decode per-token proxy")
ax.set_title("Prefill grows quadratically; decode step grows linearly")
ax.set_xlabel("context length T")
ax.set_ylabel("scaled operation proxy")
ax.legend()
fig.tight_layout()
plt.show()
print("8192-token prefill proxy:", prefill_ops[-1])
print("8192-token decode proxy:", decode_ops_per_token[-1])
2. KV cache memory calculator
Code cell 6
def kv_cache_gb(batch, layers, tokens, kv_heads, head_dim, bytes_per=2):
    # Factor of 2 covers both K and V; bytes_per=2 assumes fp16/bf16.
    return 2 * batch * layers * tokens * kv_heads * head_dim * bytes_per / 1e9

batch, layers, tokens, q_heads, head_dim = 8, 32, 4096, 32, 128
for kv_heads in [32, 8, 1]:
    print(f"KV heads={kv_heads:>2}: {kv_cache_gb(batch, layers, tokens, kv_heads, head_dim):.2f} GB")
3. MHA, GQA, and MQA savings
Code cell 8
q_heads = 32
kv_options = np.array([32, 16, 8, 4, 1])   # 32 = MHA, 1 = MQA, in between = GQA
savings = q_heads / kv_options
for kv, s in zip(kv_options, savings):
    print(f"H_kv={kv:>2}: cache reduction versus MHA = {s:.1f}x")
4. Naive attention-score memory
Code cell 10
B, H = 4, 32
lengths = np.array([1024, 2048, 4096, 8192, 16384])
score_gb = B * H * lengths**2 * 2 / 1e9   # T x T scores per head in bf16 (2 bytes)
for T, gb in zip(lengths, score_gb):
    print(f"T={T:>5}: score matrix bf16 memory={gb:8.2f} GB")
5. Online softmax equivalence
Code cell 12
scores = np.array([1.2, -0.5, 3.0, 2.2, -1.0, 0.7])
values = np.arange(len(scores), dtype=float)
full = np.exp(scores - scores.max())
full_out = np.sum((full / full.sum()) * values)
# One-pass (online) softmax: maintain a running max m, normalizer l, and
# weighted sum, rescaling both whenever a new maximum appears.
m = -np.inf
l = 0.0
weighted = 0.0
for s, v in zip(scores, values):
    m_new = max(m, s)
    l = l * np.exp(m - m_new) + np.exp(s - m_new)
    weighted = weighted * np.exp(m - m_new) + np.exp(s - m_new) * v
    m = m_new
online_out = weighted / l
print("full softmax output:", full_out)
print("online softmax output:", online_out)
print("difference:", abs(full_out - online_out))
6. Roofline bandwidth intuition
Code cell 14
peak_flops = 312e12   # e.g., A100 bf16 dense peak, 312 TFLOP/s
bandwidth = 2e12      # e.g., A100-80GB HBM, ~2 TB/s
ridge = peak_flops / bandwidth   # intensity where compute and memory balance
intensity = np.logspace(-1, 4, 80)
achieved = np.minimum(peak_flops, bandwidth * intensity)
fig, ax = plt.subplots(figsize=(7, 4))
ax.loglog(intensity, achieved / 1e12)
ax.axvline(ridge, color="red", linestyle="--", label=f"ridge={ridge:.0f} FLOP/byte")
ax.set_title("Roofline model")
ax.set_xlabel("arithmetic intensity FLOP/byte")
ax.set_ylabel("achieved TFLOP/s")
ax.legend()
fig.tight_layout()
plt.show()
print("ridge point:", ridge)
7. Paged cache waste
Code cell 16
lengths = np.array([10, 17, 33, 64, 65, 100])   # live sequence lengths in tokens
block = 16                                      # cache block size in tokens
allocated = np.ceil(lengths / block) * block    # round up to whole blocks
waste = allocated - lengths                     # internal fragmentation
for l, a, w in zip(lengths, allocated, waste):
    print(f"length={l:>3}, allocated={int(a):>3}, waste={int(w):>2}")
print("waste fraction:", waste.sum() / allocated.sum())
8. Continuous batching toy schedule
Code cell 18
remaining = np.array([5, 2, 7, 1])   # tokens still to decode per request
active_counts = []
while (remaining > 0).any():
    active = remaining > 0           # finished requests leave the batch immediately
    active_counts.append(int(active.sum()))
    remaining[active] -= 1
print("active batch size by decode step:", active_counts)
print("average active batch:", np.mean(active_counts))
9. Speculative decoding expected speedup
Code cell 20
draft_tokens = 4                       # draft length gamma per verification pass
accept_rates = np.linspace(0.2, 0.95, 8)
draft_cost_fraction = 0.15             # draft pass cost relative to a target pass
for a in accept_rates:
    # Acceptance is sequential: only a prefix of the draft survives, so the
    # expected tokens per target pass is 1 + a + a^2 + ... + a^gamma.
    expected_tokens = (1 - a ** (draft_tokens + 1)) / (1 - a)
    cost = 1 + draft_cost_fraction * draft_tokens
    speedup = expected_tokens / cost
    print(f"accept={a:.2f}: expected speedup={speedup:.2f}x")
10. Latency budget
Code cell 22
queue_ms = 12        # waiting for a batch slot
prefill_ms = 180     # prompt processing before any output
output_tokens = 120
tpot_ms = 22         # time per output token during decode
sample_ms = 1
ttft = queue_ms + prefill_ms + sample_ms   # time to first token
# The first token is already counted in TTFT, so the remaining
# output_tokens - 1 tokens each cost one TPOT.
total = ttft + (output_tokens - 1) * tpot_ms
print("TTFT ms:", ttft)
print("total latency ms:", total)
print("tokens/sec after first token:", 1000 / tpot_ms)
11. Cached decode correctness check
Code cell 24
# Placeholder tolerance check: stands in for comparing cached-decode output
# against full recomputation (a concrete version follows below).
rng = np.random.default_rng(5)
full = rng.normal(size=(4,))
cached = full + rng.normal(scale=1e-6, size=(4,))
max_diff = np.max(np.abs(full - cached))
print("max difference:", max_diff)
print("pass:", max_diff < 1e-5)
12. Inference debugging checklist
Code cell 26
checks = [
    "measure prefill and decode separately",
    "compute KV cache memory from actual batch, context, layers, and KV heads",
    "verify cached decode matches full recomputation on a tiny case",
    "track p50, p95, and p99 latency, not only mean",
    "test quality after quantization or speculative acceleration",
    "attribute bottlenecks to weights, KV cache, kernels, scheduler, or network",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")