Exercises Notebook
Converted from exercises.ipynb for web reading.
Efficient Attention and Inference: Exercises
Ten exercises cover the serving math behind efficient LLM inference: KV cache, GQA, score memory, bandwidth, paging, batching, speculative decoding, and latency budgets.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Exercise 1: KV cache memory
Compute the KV cache memory in GB for batch size B=4, L=24 layers, context length T=2048, H_kv=8 KV heads, head dimension d_h=128, and b=2 bytes per element.
Code cell 4
# Your Solution
B, L, T, H_kv, d_h, b = 4, 24, 2048, 8, 128, 2
print("Starter: M = 2 * B * L * T * H_kv * d_h * b bytes.")
Code cell 5
# Solution
B, L, T, H_kv, d_h, b = 4, 24, 2048, 8, 128, 2
gb = 2 * B * L * T * H_kv * d_h * b / 1e9
print("KV cache GB:", gb)
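Because the formula is linear in every factor, a quick sweep over context length shows how fast the cache grows. A small sketch reusing the configuration above:

```python
# KV cache memory scales linearly with context length T
B, L, H_kv, d_h, b = 4, 24, 8, 128, 2
for T in (2048, 8192, 32768):
    gb = 2 * B * L * T * H_kv * d_h * b / 1e9
    print(f"T={T:>5}: {gb:.2f} GB")
```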
Exercise 2: GQA cache savings
Compare MHA with 32 KV heads to GQA with 8 KV heads.
Code cell 7
# Your Solution
H_q = 32
H_kv = 8
print("Starter: savings = H_q / H_kv.")
Code cell 8
# Solution
H_q = 32
H_kv = 8
savings = H_q / H_kv
print("cache reduction:", savings)
Exercise 3: Decode operation proxy
Compute T*d, a proxy for the elements touched by a single decode attention step, with T=4096 context tokens and model dimension d=4096.
Code cell 10
# Your Solution
T = 4096
d = 4096
print("Starter: multiply T and d.")
Code cell 11
# Solution
T = 4096
d = 4096
ops = T * d
print("decode attention proxy:", ops)
Exercise 4: Score matrix memory
Compute the memory needed to materialize the bf16 attention-score matrix for B=2, H=16, T=4096.
Code cell 13
# Your Solution
B, H, T, bytes_per = 2, 16, 4096, 2
print("Starter: B*H*T*T*bytes_per.")
Code cell 14
# Solution
B, H, T, bytes_per = 2, 16, 4096, 2
gb = B * H * T * T * bytes_per / 1e9
print("score memory GB:", gb)
Exercise 5: Bandwidth runtime
Estimate time to read 40 GB at 2 TB/s.
Code cell 16
# Your Solution
gb = 40
bandwidth_gbps = 2000
print("Starter: seconds = gb / bandwidth_gbps.")
Code cell 17
# Solution
gb = 40
bandwidth_gbps = 2000
seconds = gb / bandwidth_gbps
print("milliseconds:", seconds * 1000)
Exercise 6: Paged waste
Compute the tokens wasted by block-granular allocation (block size 16) for the given request lengths.
Code cell 19
# Your Solution
lengths = np.array([9, 16, 17])
block = 16
print("Starter: allocated=ceil(length/block)*block.")
Code cell 20
# Solution
lengths = np.array([9, 16, 17])
block = 16
allocated = np.ceil(lengths / block) * block
waste = allocated - lengths
print("allocated:", allocated.astype(int))
print("waste:", waste.astype(int))
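np.ceil on a float ratio works here, but integer ceiling division sidesteps float rounding entirely and never needs the astype cast. An equivalent sketch:

```python
import numpy as np

lengths = np.array([9, 16, 17])
block = 16
# integer ceiling division: -(-a // b) == ceil(a / b)
allocated = -(-lengths // block) * block
print("allocated:", allocated)
print("waste:", allocated - lengths)
```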
Exercise 7: Continuous batch
Find the average number of active requests per decode step under continuous batching.
Code cell 22
# Your Solution
remaining = np.array([3, 1, 2])
print("Starter: decrement active requests one step at a time.")
Code cell 23
# Solution
remaining = np.array([3, 1, 2])
active_counts = []
while (remaining > 0).any():
    active = remaining > 0
    active_counts.append(active.sum())
    remaining[active] -= 1
print("active counts:", active_counts)
print("average:", np.mean(active_counts))
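The loop also has a closed form: decode runs for max(remaining) steps, and each request is active for exactly its remaining length, so the average is total tokens divided by the longest request. A one-line check on the same lengths:

```python
import numpy as np

remaining = np.array([3, 1, 2])
# total token-steps divided by number of decode steps
avg = remaining.sum() / remaining.max()
print("average active:", avg)
```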
Exercise 8: Speculative speedup
Estimate speedup with 4 draft tokens, acceptance 0.75, draft cost fraction 0.10.
Code cell 25
# Your Solution
k = 4
a = 0.75
draft_cost = 0.10
print("Starter: expected tokens = 1 + k*a; cost = 1 + k*draft_cost.")
Code cell 26
# Solution
k = 4
a = 0.75
draft_cost = 0.10
speedup = (1 + k * a) / (1 + k * draft_cost)
print("speedup:", speedup)
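The same formula locates the break-even point: speedup exceeds 1 only when k*a > k*draft_cost, i.e. when the acceptance rate beats the draft cost fraction. A sweep keeping k and the draft cost from above:

```python
k, draft_cost = 4, 0.10
for a in (0.0, 0.1, 0.25, 0.5, 0.75, 1.0):
    speedup = (1 + k * a) / (1 + k * draft_cost)
    print(f"a={a:.2f}: {speedup:.2f}x")
```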
Exercise 9: Latency budget
Compute time to first token (TTFT) and total request latency in milliseconds.
Code cell 28
# Your Solution
queue, prefill, sample = 10, 90, 1
output_tokens, tpot = 50, 20
print("Starter: TTFT=queue+prefill+sample; total=TTFT+output_tokens*tpot.")
Code cell 29
# Solution
queue, prefill, sample = 10, 90, 1
output_tokens, tpot = 50, 20
ttft = queue + prefill + sample
total = ttft + output_tokens * tpot
print("TTFT:", ttft)
print("total:", total)
Exercise 10: Correctness checklist
Write four checks for an inference optimization.
Code cell 31
# Your Solution
print("Starter: include cache match, latency split, memory accounting, quality regression.")
Code cell 32
# Solution
checks = [
    "cached decode matches full recomputation",
    "prefill and decode latency are measured separately",
    "KV cache memory is computed for actual batch and context",
    "quality is checked after quantization or speculative decoding",
]
for check in checks:
    print("-", check)
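The first checklist item can be made concrete: for plain softmax attention, decoding with a KV cache must reproduce full recomputation exactly, because the cache holds the same keys and values the full pass would build. A minimal single-head numpy sketch of that comparison (the shapes and the attend helper are illustrative, not from the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 8
q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))
V = rng.normal(size=(T, d))

def attend(q_t, K_cache, V_cache):
    # softmax attention for one query over the cached keys/values
    scores = K_cache @ q_t / np.sqrt(d)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V_cache

# incremental decode: the cache grows by one key/value per step
cached_out = [attend(q[t], K[:t + 1], V[:t + 1]) for t in range(T)]
# full recomputation at the last position uses the same K, V
full_out = attend(q[-1], K, V)
print("match:", np.allclose(cached_out[-1], full_out))
```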
Closing Reflection
Efficient inference work should always preserve the target behavior unless an approximation is intentional and measured.