Exercises: Attention Mechanism Math

Converted from exercises.ipynb for web reading.

There are 10 exercises: 1-3 cover core attention and entropy diagnostics, 4-7 cover multi-head shapes, masking, caching, and positional bias, and 8-10 cover serving workloads and mask design.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3


COLORS = {
    "primary":   "#0077BB",
    "secondary": "#EE7733",
    "tertiary":  "#009988",
    "error":     "#CC3311",
    "neutral":   "#555555",
    "highlight": "#EE3377",
}

def header(title):
    print("\n" + "=" * 72)
    print(title)
    print("=" * 72)

def check_true(condition, name):
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    assert ok, name

def check_close(value, target, tol=1e-8, name="value"):
    value = float(value)
    target = float(target)
    ok = abs(value - target) <= tol
    print(f"{'PASS' if ok else 'FAIL'} - {name}: got {value:.6f}, expected {target:.6f}")
    assert ok, name

def stable_softmax(scores, axis=-1):
    """Softmax with max subtraction for numerical stability."""
    scores = np.asarray(scores, dtype=float)
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention for 2-D Q, K, V; mask is additive (0 or -1e9)."""
    dk = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(dk)
    if mask is not None:
        scores = scores + mask
    weights = stable_softmax(scores, axis=-1)
    return weights @ V, weights, scores

def causal_mask(T):
    """Additive mask with -1e9 above the diagonal, blocking future keys."""
    mask = np.zeros((T, T))
    mask[np.triu_indices(T, k=1)] = -1e9
    return mask

def attention_entropy(weights):
    """Shannon entropy (nats) of each attention row."""
    w = np.clip(weights, 1e-12, 1.0)
    return -(w * np.log(w)).sum(axis=-1)

def split_heads(X, n_heads):
    """(B, T, D) -> (B, n_heads, T, D // n_heads)."""
    B, T, D = X.shape
    assert D % n_heads == 0
    return X.reshape(B, T, n_heads, D // n_heads).transpose(0, 2, 1, 3)

def combine_heads(H):
    """(B, n_heads, T, Dh) -> (B, T, n_heads * Dh); inverse of split_heads."""
    B, heads, T, Dh = H.shape
    return H.transpose(0, 2, 1, 3).reshape(B, T, heads * Dh)

def alibi_bias(T, slope=-0.5):
    """Linear distance penalty slope * (i - j) applied to keys in the past."""
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    return slope * np.maximum(i - j, 0)

print("Attention helpers ready.")

Exercise 1: Scaled attention (*)

Compute weights and outputs for a two-token example. State the shapes, compute the result, and explain the LLM consequence.

Code cell 5

# Your Solution - Exercise 1
answer = None
print("Your answer placeholder:", answer)

Code cell 6

# Solution - Exercise 1
header("Exercise 1: Scaled attention")
Q = K = np.eye(2)
V = np.array([[1.0, 2.0], [3.0, 4.0]])
Y, W, S = attention(Q, K, V)
print("Weights:", np.round(W, 4))
check_true(np.allclose(W.sum(axis=1), 1.0), "rows normalize")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 2: Causal mask (*)

Verify future positions receive zero mass. State the shapes, compute the result, and explain the LLM consequence.

Code cell 8

# Your Solution - Exercise 2
answer = None
print("Your answer placeholder:", answer)

Code cell 9

# Solution - Exercise 2
header("Exercise 2: Causal mask")
T = 3
Y, W, S = attention(np.eye(T), np.eye(T), np.eye(T), mask=causal_mask(T))
future = W[np.triu_indices(T, k=1)].sum()
print("Future mass:", future)
check_close(future, 0.0, tol=1e-8, name="causal future mass")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 3: Attention entropy (*)

Compare sharp and diffuse rows. State the shapes, compute the result, and explain the LLM consequence.

Code cell 11

# Your Solution - Exercise 3
answer = None
print("Your answer placeholder:", answer)

Code cell 12

# Solution - Exercise 3
header("Exercise 3: Attention entropy")
h1 = attention_entropy(np.array([[0.9, 0.1]]))[0]
h2 = attention_entropy(np.array([[0.5, 0.5]]))[0]
print("Entropies:", h1, h2)
check_true(h2 > h1, "balanced row has higher entropy")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 4: Multi-head shapes (**)

Split and combine head dimensions. State the shapes, compute the result, and explain the LLM consequence.

Code cell 14

# Your Solution - Exercise 4
answer = None
print("Your answer placeholder:", answer)

Code cell 15

# Solution - Exercise 4
header("Exercise 4: Multi-head shapes")
X = np.zeros((1, 6, 16))
H = split_heads(X, 4)
print("Head shape:", H.shape)
check_true(H.shape == (1, 4, 6, 4), "B heads T Dh")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 5: Padding mask (**)

Mask a padded key before softmax. State the shapes, compute the result, and explain the LLM consequence.

Code cell 17

# Your Solution - Exercise 5
answer = None
print("Your answer placeholder:", answer)

Code cell 18

# Solution - Exercise 5
header("Exercise 5: Padding mask")
w = stable_softmax(np.array([[1.0, 2.0, 99.0]]) + np.array([[0.0, 0.0, -1e9]]))
print("Weights:", w)
check_close(w[0, 2], 0.0, tol=1e-8, name="masked pad weight")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 6: KV cache bytes (**)

Compute cache memory for a small decoder. State the shapes, compute the result, and explain the LLM consequence.

Code cell 20

# Your Solution - Exercise 6
answer = None
print("Your answer placeholder:", answer)

Code cell 21

# Solution - Exercise 6
header("Exercise 6: KV cache bytes")
layers, heads, T, dh, bytes_per = 2, 4, 10, 8, 2
size = layers * 2 * heads * T * dh * bytes_per
print("KV bytes:", size)
check_true(size == 2560, "cache formula counts K and V")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 7: ALiBi bias (**)

Build a linear distance penalty. State the shapes, compute the result, and explain the LLM consequence.

Code cell 23

# Your Solution - Exercise 7
answer = None
print("Your answer placeholder:", answer)

Code cell 24

# Solution - Exercise 7
header("Exercise 7: ALiBi bias")
bias = alibi_bias(3, slope=-1.0)
print("Bias:", bias)
check_close(bias[2, 0], -2.0, name="distance two penalty")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 8: Prefill versus decode (***)

Compare score counts for prompt and one-token decode. State the shapes, compute the result, and explain the LLM consequence.

Code cell 26

# Your Solution - Exercise 8
answer = None
print("Your answer placeholder:", answer)

Code cell 27

# Solution - Exercise 8
header("Exercise 8: Prefill versus decode")
T = 100
prefill_scores = T * T
decode_scores = T
print("Prefill:", prefill_scores, "Decode one token:", decode_scores)
check_true(prefill_scores > decode_scores, "prefill and decode workloads differ")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 9: FlashAttention (***)

Explain exact tiling without changing the formula. State the shapes, compute the result, and explain the LLM consequence.

Code cell 29

# Your Solution - Exercise 9
answer = None
print("Your answer placeholder:", answer)

Code cell 30

# Solution - Exercise 9
header("Exercise 9: FlashAttention")
T = 64
tiles = 8
covered = tiles * (T // tiles) * T
print("Covered score pairs:", covered)
check_true(covered == T * T, "tiling covers exact same pair set")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")

Exercise 10: Prompt boundary mask (***)

Design a visibility matrix for protected tokens. State the shapes, compute the result, and explain the LLM consequence.

Code cell 32

# Your Solution - Exercise 10
answer = None
print("Your answer placeholder:", answer)

Code cell 33

# Solution - Exercise 10
header("Exercise 10: Prompt boundary mask")
visible = np.tril(np.ones((4, 4)))
visible[3, 1] = 0.0
print("Visibility:", visible)
check_true(visible[3, 1] == 0.0 and visible[3, 3] == 1.0, "protected boundary can block a specific key")
print("\nTakeaway: attention is safest when shapes, masks, softmax rows, and value mixing are tested explicitly.")