Theory Notebook — Math for LLMs

Efficient Attention and Inference

Math for LLMs / Efficient Attention and Inference

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Efficient Attention and Inference: Theory Notebook

This notebook makes serving math concrete: prefill versus decode, KV cache memory, MHA/MQA/GQA savings, FlashAttention memory intuition (naive score-matrix memory and online softmax), roofline bandwidth limits, paged-cache waste, continuous batching, speculative decoding, and latency budgets.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Use seaborn's theme when the package is importable; otherwise fall back to
# matplotlib's bundled "seaborn-v0_8-whitegrid" style sheet. HAS_SNS records
# which of the two paths ran so later cells can branch on it.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Notebook-wide figure defaults, grouped by concern and merged into one
# rcParams update.
_FIGURE_RC = {
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
}
_TEXT_RC = {
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
}
_LINE_RC = {
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
}
mpl.rcParams.update({**_FIGURE_RC, **_TEXT_RC, **_LINE_RC})

# Seed NumPy's legacy global RNG so any np.random.* calls are reproducible.
np.random.seed(42)
print("Plot setup complete.")

1. Prefill versus decode cost

Code cell 4

# Context lengths swept in 256-token steps. The grid ends exactly at 8192, so
# the "8192-token" values printed below really are the T = 8192 entries
# (np.arange(128, 8193, 256) stopped at 8064).
T = np.arange(256, 8192 + 1, 256)
d = 4096
# Prefill attention proxy: all T queries score against all T keys, so the
# work scales as T^2 * d.
prefill_ops = T**2 * d
# Decode proxy for a single step: one new query scores against T cached keys.
decode_ops_per_token = T * d

fig, ax = plt.subplots(figsize=(8, 4))
# The two curves are divided by different constants so they can share one
# axis. The legend names each divisor so their heights are not read as
# directly comparable.
ax.plot(T, prefill_ops / 1e9, label="prefill attention proxy (/1e9)")
ax.plot(T, decode_ops_per_token / 1e6, label="decode per-token proxy (/1e6)")
ax.set_title("Prefill grows quadratically; decode step grows linearly")
ax.set_xlabel("context length T")
ax.set_ylabel("scaled operation proxy")
ax.legend()
fig.tight_layout()
plt.show()
print("8192-token prefill proxy:", prefill_ops[-1])
print("8192-token decode proxy:", decode_ops_per_token[-1])

2. KV cache memory calculator

Code cell 6

def kv_cache_gb(batch, layers, tokens, kv_heads, head_dim, bytes_per=2):
    """Return the KV-cache footprint in decimal gigabytes (1 GB = 1e9 bytes).

    Args:
        batch: number of sequences held in the cache.
        layers: number of transformer layers, each with its own cache.
        tokens: cached positions per sequence.
        kv_heads: number of key/value heads per layer.
        head_dim: width of each head vector.
        bytes_per: bytes per stored element (default 2, e.g. fp16/bf16).

    The leading factor of 2 counts the key tensor and the value tensor.
    """
    cached_elements = 2 * batch * layers * tokens * kv_heads * head_dim
    return cached_elements * bytes_per / 1e9

# Serving configuration: 8 sequences, 32 layers, 4096 cached tokens,
# 32 query heads of width 128.
batch, layers, tokens, q_heads, head_dim = 8, 32, 4096, 32, 128
# kv_heads == q_heads is MHA, 1 is MQA, anything in between is GQA.
for kv_heads in [32, 8, 1]:
    # Each KV head is shared by q_heads // kv_heads query heads, which is only
    # well defined when the division is exact.
    if q_heads % kv_heads:
        raise ValueError(f"q_heads={q_heads} is not divisible by kv_heads={kv_heads}")
    group_size = q_heads // kv_heads
    gb = kv_cache_gb(batch, layers, tokens, kv_heads, head_dim)
    print(f"KV heads={kv_heads:>2}: {gb:.2f} GB (queries per KV head={group_size})")

3. MHA, GQA, and MQA savings

Code cell 8

q_heads = 32
kv_options = np.array([32, 16, 8, 4, 1])
# Cache size is proportional to the number of KV heads, so the reduction
# relative to MHA (one KV head per query head) is q_heads / H_kv.
savings = q_heads / kv_options
print(
    *(f"H_kv={kv:>2}: cache reduction versus MHA = {s:.1f}x" for kv, s in zip(kv_options, savings)),
    sep="\n",
)

4. Naive attention-score memory

Code cell 10

B, H = 4, 32
lengths = np.array([1024, 2048, 4096, 8192, 16384])
# A naive kernel materializes one 2-byte (bf16) score per
# (batch, head, query, key) entry: B * H * T^2 * 2 bytes.
score_bytes = B * H * lengths**2 * 2
score_gb = score_bytes / 1e9
for T, gb in zip(lengths, score_gb):
    print(f"T={T:>5}: score matrix bf16 memory={gb:8.2f} GB")

5. Online softmax equivalence

Code cell 12

scores = np.array([1.2, -0.5, 3.0, 2.2, -1.0, 0.7])
values = np.arange(len(scores), dtype=float)

# Reference: ordinary max-subtracted softmax over the whole score vector,
# then a probability-weighted sum of the values.
full = np.exp(scores - scores.max())
full_out = np.sum((full / full.sum()) * values)

# Streaming version: visit one (score, value) pair at a time, keeping only a
# running max, a running normalizer, and a running weighted sum. When the max
# grows, everything accumulated so far is rescaled by exp(old_max - new_max).
running_max, running_norm, running_weighted = -np.inf, 0.0, 0.0
for score, value in zip(scores, values):
    new_max = max(running_max, score)
    rescale = np.exp(running_max - new_max)
    contribution = np.exp(score - new_max)
    running_norm = running_norm * rescale + contribution
    running_weighted = running_weighted * rescale + contribution * value
    running_max = new_max
online_out = running_weighted / running_norm

print("full softmax output:", full_out)
print("online softmax output:", online_out)
print("difference:", abs(full_out - online_out))

6. Roofline bandwidth intuition

Code cell 14

peak_flops = 312e12  # compute ceiling in FLOP/s
bandwidth = 2e12     # memory ceiling in bytes/s
# Ridge point: the arithmetic intensity at which the bandwidth slope meets the
# compute ceiling. Left of it a kernel is bandwidth-bound; right of it,
# compute-bound.
ridge = peak_flops / bandwidth
intensity = np.logspace(-1, 4, 80)
achieved = np.minimum(bandwidth * intensity, peak_flops)

fig, ax = plt.subplots(figsize=(7, 4))
ax.loglog(intensity, achieved / 1e12)
ax.axvline(ridge, color="red", linestyle="--", label=f"ridge={ridge:.0f} FLOP/byte")
ax.set(
    title="Roofline model",
    xlabel="arithmetic intensity FLOP/byte",
    ylabel="achieved TFLOP/s",
)
ax.legend()
fig.tight_layout()
plt.show()
print("ridge point:", ridge)

7. Paged cache waste

Code cell 16

lengths = np.array([10, 17, 33, 64, 65, 100])
block = 16
# Each sequence is rounded up to whole fixed-size blocks. The unused tail of
# its last block is the (internal-fragmentation) waste.
blocks_needed = np.ceil(lengths / block)
allocated = blocks_needed * block
waste = allocated - lengths
for seq_len, alloc, slack in zip(lengths, allocated, waste):
    print(f"length={seq_len:>3}, allocated={int(alloc):>3}, waste={int(slack):>2}")
print("waste fraction:", waste.sum() / allocated.sum())

8. Continuous batching toy schedule

Code cell 18

# Static-batch baseline: four requests start together, and each needs the given
# number of decode steps. Finished requests leave their slots empty until the
# whole batch drains. Those idle slot-steps are exactly what continuous
# batching recovers by admitting queued requests into freed slots.
remaining = np.array([5, 2, 7, 1])
batch_slots = len(remaining)
active_counts = []
step = 0
while (remaining > 0).any():
    active = remaining > 0
    # Store a plain int so the list prints as [4, 3, ...] rather than
    # [np.int64(4), ...] under NumPy 2.
    active_counts.append(int(active.sum()))
    remaining[active] -= 1
    step += 1
idle_slot_steps = batch_slots * step - sum(active_counts)
print("active batch size by decode step:", active_counts)
print("average active batch:", np.mean(active_counts))
print("decode steps:", step)
print("idle slot-steps a continuous batcher could backfill:", idle_slot_steps)

9. Speculative decoding expected speedup

Code cell 20

def speculative_speedup(accept_rate, draft_tokens, draft_cost_fraction):
    """Return the expected speedup of speculative decoding over plain decoding.

    Model: each draft token is accepted independently with probability
    ``accept_rate``, and verification stops at the first rejection. The target
    pass always yields one extra token (the correction or the bonus token).
    The expected number of tokens per round is therefore the truncated
    geometric sum 1 + a + a^2 + ... + a^draft_tokens, not 1 + draft_tokens * a,
    which would wrongly count draft tokens after a rejection. One round costs a
    single target pass plus ``draft_tokens`` draft passes, each costing
    ``draft_cost_fraction`` of a target pass.
    """
    expected_tokens = sum(accept_rate**i for i in range(draft_tokens + 1))
    cost = 1 + draft_cost_fraction * draft_tokens
    return expected_tokens / cost


draft_tokens = 4
accept_rates = np.linspace(0.2, 0.95, 8)
draft_cost_fraction = 0.15
for a in accept_rates:
    speedup = speculative_speedup(a, draft_tokens, draft_cost_fraction)
    print(f"accept={a:.2f}: expected speedup={speedup:.2f}x")

10. Latency budget

Code cell 22

queue_ms = 12
prefill_ms = 180
output_tokens = 120
tpot_ms = 22
sample_ms = 1
# Time to first token: queueing, prefill, and sampling the first token.
ttft = queue_ms + prefill_ms + sample_ms
# The first output token is already emitted at TTFT, so only the remaining
# output_tokens - 1 tokens each pay one time-per-output-token (TPOT).
total = ttft + max(output_tokens - 1, 0) * tpot_ms
print("TTFT ms:", ttft)
print("total latency ms:", total)
print("tokens/sec after first token:", 1000 / tpot_ms)

11. Cached decode correctness check

Code cell 24

# Tiny single-head causal attention, computed two ways:
#   full   - recompute attention for every position from scratch;
#   cached - decode token by token, appending each key/value to a cache.
# If the cache logic is right, the two must agree to floating-point precision.
# (Comparing a vector against itself plus injected noise could not catch a
# cache bug.)
rng = np.random.default_rng(5)
seq_len, d_head = 6, 4
x = rng.normal(size=(seq_len, d_head))
w_q, w_k, w_v = (rng.normal(size=(d_head, d_head)) for _ in range(3))


def _softmax(z):
    """Return a max-subtracted softmax of ``z`` along its last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Full recomputation: project all positions at once and mask out future keys
# so position i attends only to keys 0..i.
q_all, k_all, v_all = x @ w_q, x @ w_k, x @ w_v
scores_all = q_all @ k_all.T / np.sqrt(d_head)
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
full = _softmax(np.where(future, -np.inf, scores_all)) @ v_all

# Cached decode: each step appends the new key/value and attends with only the
# newest query against everything cached so far.
k_cache, v_cache, cached_rows = [], [], []
for t in range(seq_len):
    k_cache.append(x[t] @ w_k)
    v_cache.append(x[t] @ w_v)
    keys, vals = np.stack(k_cache), np.stack(v_cache)
    step_scores = keys @ (x[t] @ w_q) / np.sqrt(d_head)
    cached_rows.append(_softmax(step_scores) @ vals)
cached = np.stack(cached_rows)

max_diff = np.max(np.abs(full - cached))
print("max difference:", max_diff)
print("pass:", max_diff < 1e-5)

12. Inference debugging checklist

Code cell 26

# Ordered checklist for diagnosing inference performance and correctness.
checks = [
    "measure prefill and decode separately",
    "compute KV cache memory from actual batch, context, layers, and KV heads",
    "verify cached decode matches full recomputation on a tiny case",
    "track p50, p95, and p99 latency, not only mean",
    "test quality after quantization or speculative acceleration",
    "attribute bottlenecks to weights, KV cache, kernels, scheduler, or network",
]
print("\n".join(f"{i}. {check}" for i, check in enumerate(checks, start=1)))