Converted from exercises.ipynb for web reading.

Efficient Attention and Inference: Exercises

Ten exercises cover the serving math behind efficient LLM inference: KV cache, GQA, score memory, bandwidth, paging, batching, speculative decoding, and latency budgets.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Exercise 1: KV cache memory

Compute the KV cache size in GB for B=4 sequences, L=24 layers, T=2048 tokens, H_kv=8 KV heads, head dimension 128, and 2-byte (bf16) values.

Code cell 4

# Your Solution
B, L, T, H_kv, d_h, b = 4, 24, 2048, 8, 128, 2
print("Starter: M = 2 * B * L * T * H_kv * d_h * b bytes.")

Code cell 5

# Solution
B, L, T, H_kv, d_h, b = 4, 24, 2048, 8, 128, 2
gb = 2 * B * L * T * H_kv * d_h * b / 1e9  # factor 2: one K and one V tensor per layer
print("KV cache GB:", gb)
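
As a quick extension (a sketch, using the same illustrative shapes as above), sweeping the formula over context length shows the cache grows linearly in T:

```python
# Same model shapes as the exercise; only T varies.
B, L, H_kv, d_h, b = 4, 24, 8, 128, 2

def kv_cache_gb(T):
    # Factor 2 accounts for storing both K and V per layer.
    return 2 * B * L * T * H_kv * d_h * b / 1e9

for T in (1024, 2048, 4096, 8192):
    print(f"T={T:5d}: {kv_cache_gb(T):.3f} GB")
```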

Exercise 2: GQA cache savings

Compare MHA with 32 KV heads to GQA with 8 KV heads.

Code cell 7

# Your Solution
H_q = 32
H_kv = 8
print("Starter: savings = H_q / H_kv.")

Code cell 8

# Solution
H_q = 32
H_kv = 8
savings = H_q / H_kv
print("cache reduction:", savings)
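
The ratio can be turned into an absolute saving by reusing the Exercise 1 formula for both head counts (a sketch; the other shapes are the illustrative values from Exercise 1):

```python
B, L, T, d_h, b = 4, 24, 2048, 128, 2

def kv_gb(h_kv):
    return 2 * B * L * T * h_kv * d_h * b / 1e9

mha_gb = kv_gb(32)  # MHA: every query head has its own KV head
gqa_gb = kv_gb(8)   # GQA: 4 query heads share each KV head
print("MHA GB:", mha_gb)
print("GQA GB:", gqa_gb)
print("saved GB:", mha_gb - gqa_gb)
```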

Exercise 3: Decode operation proxy

Compute T*d, a proxy for the query-key dot-product work of a single decode attention step at context length T with model dimension d.

Code cell 10

# Your Solution
T = 4096
d = 4096
print("Starter: multiply T and d.")

Code cell 11

# Solution
T = 4096
d = 4096
ops = T * d
print("decode attention proxy:", ops)
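
A short extension (a sketch): the per-step cost grows as the cache fills, so generating n new tokens from a prompt of length T costs the sum of linearly growing steps:

```python
d = 4096
T0 = 4096       # prompt length
n_new = 1024    # tokens to generate
# Step i attends over T0 + i cached positions.
total_ops = sum((T0 + i) * d for i in range(n_new))
print("total decode attention proxy:", total_ops)
```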

Exercise 4: Score matrix memory

Compute bf16 attention-score memory for B=2,H=16,T=4096.

Code cell 13

# Your Solution
B, H, T, bytes_per = 2, 16, 4096, 2
print("Starter: B*H*T*T*bytes_per.")

Code cell 14

# Solution
B, H, T, bytes_per = 2, 16, 4096, 2
gb = B * H * T * T * bytes_per / 1e9
print("score memory GB:", gb)
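
Doubling T quadruples this footprint, which is one reason production kernels (FlashAttention-style tiling, for example) avoid materializing the full T-by-T matrix. A quick sweep makes the quadratic growth concrete:

```python
B, H, bytes_per = 2, 16, 2
for T in (1024, 2048, 4096, 8192):
    # Full score matrix is T x T per (batch, head) pair.
    gb = B * H * T * T * bytes_per / 1e9
    print(f"T={T:5d}: score memory = {gb:7.3f} GB")
```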

Exercise 5: Bandwidth runtime

Estimate time to read 40 GB at 2 TB/s.

Code cell 16

# Your Solution
gb = 40
bandwidth_gbps = 2000
print("Starter: seconds = gb / bandwidth_gbps.")

Code cell 17

# Solution
gb = 40
bandwidth_gbps = 2000
seconds = gb / bandwidth_gbps
print("milliseconds:", seconds * 1000)
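
If decode is bandwidth-bound, the same two numbers bound throughput: at most one token per full read of the weights (a sketch; assumes every step streams all 40 GB and ignores KV cache traffic):

```python
gb_per_step = 40       # bytes streamed per decode step, in GB
bandwidth_gbps = 2000  # GB/s
tokens_per_second = bandwidth_gbps / gb_per_step
print("tokens/s upper bound:", tokens_per_second)  # 50.0
```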

Exercise 6: Paged waste

Compute the internal fragmentation (allocated minus used tokens) of a paged KV cache with block size 16, for three request lengths.

Code cell 19

# Your Solution
lengths = np.array([9, 16, 17])
block = 16
print("Starter: allocated=ceil(length/block)*block.")

Code cell 20

# Solution
lengths = np.array([9, 16, 17])
block = 16
allocated = np.ceil(lengths / block) * block
waste = allocated - lengths
print("allocated:", allocated.astype(int))
print("waste:", waste.astype(int))
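
Block size trades internal fragmentation against bookkeeping: smaller blocks waste fewer slots but need more block-table entries. Sweeping block size over the same lengths (a sketch):

```python
import numpy as np

lengths = np.array([9, 16, 17])
for block in (8, 16, 32):
    allocated = np.ceil(lengths / block) * block
    waste_frac = (allocated - lengths).sum() / allocated.sum()
    print(f"block={block:2d}: waste fraction = {waste_frac:.3f}")
```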

Exercise 7: Continuous batch

Simulate continuous batching: find the average number of active requests per decode step as requests finish at different times.

Code cell 22

# Your Solution
remaining = np.array([3, 1, 2])
print("Starter: decrement active requests one step at a time.")

Code cell 23

# Solution
remaining = np.array([3, 1, 2])
active_counts = []
while (remaining > 0).any():
    active = remaining > 0
    active_counts.append(active.sum())
    remaining[active] -= 1
print("active counts:", active_counts)
print("average:", np.mean(active_counts))
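
For contrast, static batching holds every slot until the longest request finishes; comparing slot utilization under the two schemes with the same remaining-token counts (a sketch):

```python
import numpy as np

remaining = np.array([3, 1, 2])
steps = remaining.max()       # static batch runs until the longest request is done
useful = remaining.sum()      # slot-steps doing real work
static_util = useful / (len(remaining) * steps)
print("static utilization:", static_util)  # 6/9
print("continuous utilization: 1.0 (finished slots are refilled or dropped)")
```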

Exercise 8: Speculative speedup

Estimate speedup with 4 draft tokens, acceptance 0.75, draft cost fraction 0.10.

Code cell 25

# Your Solution
k = 4
a = 0.75
draft_cost = 0.10
print("Starter: expected tokens = 1 + k*a; cost = 1 + k*draft_cost.")

Code cell 26

# Solution
k = 4
a = 0.75
draft_cost = 0.10
speedup = (1 + k * a) / (1 + k * draft_cost)
print("speedup:", speedup)
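
The linear model 1 + k*a is optimistic: if acceptance is an independent per-token event and verification stops at the first rejection, expected tokens per iteration follow a geometric series instead, a common model in the speculative decoding literature (a sketch; same k, a, and draft cost):

```python
k, a, draft_cost = 4, 0.75, 0.10
# Sum_{i=0}^{k} a^i: the i-th draft token survives only if all prior tokens did.
expected_geo = (1 - a ** (k + 1)) / (1 - a)
speedup_geo = expected_geo / (1 + k * draft_cost)
print("expected tokens/iteration:", expected_geo)  # 3.05078125
print("speedup:", round(speedup_geo, 3))
```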

Exercise 9: Latency budget

Compute TTFT and total latency.

Code cell 28

# Your Solution
queue, prefill, sample = 10, 90, 1
output_tokens, tpot = 50, 20
print("Starter: TTFT=queue+prefill+sample; total=TTFT+output_tokens*tpot.")

Code cell 29

# Solution
queue, prefill, sample = 10, 90, 1
output_tokens, tpot = 50, 20
ttft = queue + prefill + sample
total = ttft + output_tokens * tpot
print("TTFT:", ttft)
print("total:", total)
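
Two derived figures follow from the same budget (units: milliseconds, as in the exercise): steady-state decode throughput and the fraction of end-to-end latency spent before the first token.

```python
queue, prefill, sample = 10, 90, 1
output_tokens, tpot = 50, 20
ttft = queue + prefill + sample
total = ttft + output_tokens * tpot
decode_tokens_per_s = 1000 / tpot  # 20 ms/token -> 50 tokens/s
ttft_share = ttft / total
print("decode tokens/s:", decode_tokens_per_s)
print("TTFT share of total latency:", round(ttft_share, 3))
```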

Exercise 10: Correctness checklist

Write four checks for an inference optimization.

Code cell 31

# Your Solution
print("Starter: include cache match, latency split, memory accounting, quality regression.")

Code cell 32

# Solution
checks = [
    "cached decode matches full recomputation",
    "prefill and decode latency are measured separately",
    "KV cache memory is computed for actual batch and context",
    "quality is checked after quantization or speculative decoding",
]
for check in checks:
    print("-", check)

Closing Reflection

Efficient inference work should always preserve the target behavior unless an approximation is intentional and measured.