Theory Notebook — Math for LLMs

Attention Mechanism Math

Math for LLMs / Attention Mechanism Math

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Attention Mechanism Math

This notebook is the executable companion to notes.md. It checks scaled dot-product attention, masks, entropy, multi-head shapes, KV-cache cost, ALiBi bias, and efficient-attention intuition.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn theming when available; otherwise use matplotlib's bundled
# seaborn-like style sheet so plots look similar either way.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Global matplotlib defaults: figure geometry, font sizes, legend, spines,
# and save settings. Each key is independent, set one at a time.
mpl.rcParams["figure.figsize"] = (10, 6)
mpl.rcParams["figure.dpi"] = 120
mpl.rcParams["font.size"] = 13
mpl.rcParams["axes.titlesize"] = 15
mpl.rcParams["axes.labelsize"] = 13
mpl.rcParams["xtick.labelsize"] = 11
mpl.rcParams["ytick.labelsize"] = 11
mpl.rcParams["legend.fontsize"] = 11
mpl.rcParams["legend.framealpha"] = 0.85
mpl.rcParams["lines.linewidth"] = 2.0
mpl.rcParams["axes.spines.top"] = False
mpl.rcParams["axes.spines.right"] = False
mpl.rcParams["savefig.bbox"] = "tight"
mpl.rcParams["savefig.dpi"] = 150
# Legacy global seeding keeps later np.random.* calls reproducible.
np.random.seed(42)
print("Plot setup complete.")

Code cell 3


# Named hex colors shared by the plotting demos below (looked up as
# COLORS["primary"], etc.). Presumably chosen to be color-blind friendly,
# matching the "colorblind" seaborn palette requested in the setup cell.
COLORS = {
    "primary":   "#0077BB",
    "secondary": "#EE7733",
    "tertiary":  "#009988",
    "error":     "#CC3311",
    "neutral":   "#555555",
    "highlight": "#EE3377",
}

def header(title):
    """Print a 72-character banner around *title* for a demo section."""
    bar = "=" * 72
    print(f"\n{bar}\n{title}\n{bar}")

def check_true(condition, name):
    """Print PASS/FAIL for a boolean check, then assert it holds."""
    passed = bool(condition)
    status = "PASS" if passed else "FAIL"
    print(f"{status} - {name}")
    assert passed, name

def check_close(value, target, tol=1e-8, name="value"):
    """Print PASS/FAIL for an absolute-tolerance float comparison, then assert it."""
    got = float(value)
    want = float(target)
    passed = abs(got - want) <= tol
    status = "PASS" if passed else "FAIL"
    print(f"{status} - {name}: got {got:.6f}, expected {want:.6f}")
    assert passed, name

def stable_softmax(scores, axis=-1):
    """Softmax along *axis*, max-shifted so large scores cannot overflow exp."""
    arr = np.asarray(scores, dtype=float)
    arr = arr - arr.max(axis=axis, keepdims=True)
    weights = np.exp(arr)
    return weights / weights.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention.

    Args:
        Q, K, V: arrays whose last two axes are (tokens, features). 2-D inputs
            behave exactly as before; leading batch/head axes now also work.
        mask: optional additive mask broadcast onto the score matrix
            (e.g. -1e9 entries to forbid positions).

    Returns:
        (output, weights, scores): the attended values, the softmax weights,
        and the pre-softmax (masked) scores.
    """
    dk = Q.shape[-1]
    # swapaxes instead of .T so batched (..., T, d) inputs transpose only the
    # last two axes; for 2-D inputs this is identical to K.T.
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(dk)
    if mask is not None:
        scores = scores + mask
    weights = stable_softmax(scores, axis=-1)
    return weights @ V, weights, scores

def causal_mask(T):
    """Additive causal mask: 0 on/below the diagonal, -1e9 above it (future keys)."""
    future = np.triu(np.ones((T, T), dtype=bool), k=1)
    return np.where(future, -1e9, 0.0)

def attention_entropy(weights):
    """Shannon entropy (nats) of each attention row; clips weights to avoid log(0)."""
    safe = np.clip(weights, 1e-12, 1.0)
    return -np.sum(safe * np.log(safe), axis=-1)

def split_heads(X, n_heads):
    """Reshape (B, T, D) into per-head layout (B, n_heads, T, D // n_heads).

    Raises:
        ValueError: if the model width D is not divisible by n_heads.
        (Previously an assert, which disappears under `python -O`.)
    """
    B, T, D = X.shape
    if D % n_heads != 0:
        raise ValueError(f"model width {D} is not divisible by n_heads={n_heads}")
    return X.reshape(B, T, n_heads, D // n_heads).transpose(0, 2, 1, 3)

def combine_heads(H):
    """Inverse of split_heads: merge (B, heads, T, Dh) back into (B, T, heads * Dh)."""
    batch, n_heads, seq, head_dim = H.shape
    interleaved = H.transpose(0, 2, 1, 3)
    return interleaved.reshape(batch, seq, n_heads * head_dim)

def alibi_bias(T, slope=-0.5):
    """ALiBi-style additive bias: slope * (query - key distance) for past keys, 0 otherwise."""
    pos = np.arange(T)
    dist = pos[:, None] - pos[None, :]
    return slope * np.clip(dist, 0, None)

# Signal that the helper-definition cell ran to completion.
print("Attention helpers ready.")

Demo 1: Attention as soft retrieval

This demo turns one attention concept into a small checked computation.

Code cell 5

header("Demo 1: Attention as soft retrieval - scaled attention")
# Identity queries/keys: each query softly retrieves mostly its own value row.
queries = np.array([[1.0, 0.0], [0.0, 1.0]])
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
values = np.array([[10.0, 0.0], [0.0, 5.0]])
out, wts, _ = attention(queries, keys, values)
print("Weights:", np.round(wts, 4))
print("Output:", np.round(out, 4))
check_true(np.allclose(wts.sum(axis=1), 1.0), "attention rows sum to one")

Demo 2: Queries keys and values as roles

This demo turns one attention concept into a small checked computation.

Code cell 7

header("Demo 2: Queries keys and values as roles - causal mask")
# With a causal mask, no weight may land above the diagonal (future keys).
seq_len = 4
eye = np.eye(seq_len)
vals = np.arange(seq_len * 2, dtype=float).reshape(seq_len, 2)
_, wts, _ = attention(eye, eye, vals, mask=causal_mask(seq_len))
print("Weights:", np.round(wts, 3))
leaked = wts[np.triu_indices(seq_len, k=1)].sum()
check_close(leaked, 0.0, tol=1e-8, name="future attention mass")

Demo 3: Why scaling is needed

This demo turns one attention concept into a small checked computation.

Code cell 9

header("Demo 3: Why scaling is needed - entropy")
# A near-one-hot row should carry less entropy than a uniform row.
peaked = np.array([[0.98, 0.01, 0.01]])
uniform = np.array([[1/3, 1/3, 1/3]])
h_peaked = attention_entropy(peaked)[0]
h_uniform = attention_entropy(uniform)[0]
print("Sharp entropy:", round(float(h_peaked), 4))
print("Diffuse entropy:", round(float(h_uniform), 4))
check_true(h_uniform > h_peaked, "diffuse attention has higher entropy")

Demo 4: Why masks are needed

This demo turns one attention concept into a small checked computation.

Code cell 11

header("Demo 4: Why masks are needed - multi-head shapes")
# Splitting into heads and recombining must round-trip the (B, T, D) shape.
hidden = np.zeros((2, 5, 12))
per_head = split_heads(hidden, n_heads=3)
restored = combine_heads(per_head)
print("Split shape:", per_head.shape)
print("Combined shape:", restored.shape)
check_true(restored.shape == hidden.shape, "combine restores model width")

Demo 5: Why attention replaced recurrence in LLMs

This demo turns one attention concept into a small checked computation.

Code cell 13

header("Demo 5: Why attention replaced recurrence in LLMs - score scaling")
# Dividing by sqrt(d) shrinks the spread of raw dot products.
gen = np.random.default_rng(1)
q_mat = gen.normal(size=(16, 64))
k_mat = gen.normal(size=(16, 64))
unscaled = q_mat @ k_mat.T
tempered = unscaled / np.sqrt(64)
print("Raw std:", round(float(unscaled.std()), 4))
print("Scaled std:", round(float(tempered.std()), 4))
check_true(tempered.std() < unscaled.std(), "scaling controls dot-product variance")

Demo 6: Input hidden-state matrix

This demo turns one attention concept into a small checked computation.

Code cell 15

header("Demo 6: Input hidden-state matrix - padding mask")
# A -1e9 additive mask entry drives that key's softmax weight to ~0.
logits = np.array([[2.0, 1.0, 5.0]])
pad_mask = np.array([[0.0, 0.0, -1e9]])
probs = stable_softmax(logits + pad_mask, axis=-1)
print("Weights:", np.round(probs, 6))
check_close(probs[0, 2], 0.0, tol=1e-8, name="padded key receives zero weight")

Demo 7: Linear Q K V projections

This demo turns one attention concept into a small checked computation.

Code cell 17

header("Demo 7: Linear Q K V projections - KV cache size")
# Cache holds K and V (factor 2) per layer/head/position at 2 bytes each.
n_layers, n_heads, ctx_len, head_dim = 24, 16, 2048, 64
elem_bytes = 2
total_bytes = n_layers * 2 * n_heads * ctx_len * head_dim * elem_bytes
print("KV cache MiB:", round(total_bytes / 2**20, 2))
check_true(total_bytes > 0, "KV cache memory is positive and grows with T")

Demo 8: Scaled dot-product attention

This demo turns one attention concept into a small checked computation.

Code cell 19

header("Demo 8: Scaled dot-product attention - ALiBi bias")
penalty = alibi_bias(5, slope=-0.25)
print("Bias:", np.round(penalty, 2))
# Distance 4 at slope -0.25 gives exactly -1.0.
check_close(penalty[4, 0], -1.0, name="linear distance penalty")

Demo 9: Attention weights

This demo turns one attention concept into a small checked computation.

Code cell 21

header("Demo 9: Attention weights - quadratic cost")
# Doubling the context doubles both score-matrix dimensions: 4x the entries.
base_len, doubled_len = 1024, 2048
cost_ratio = (doubled_len / base_len) ** 2
print("Score matrix cost ratio:", cost_ratio)
check_close(cost_ratio, 4.0, name="doubling context quadruples score matrix")
figure, axis = plt.subplots()
seq_lens = np.array([512, 1024, 2048, 4096])
axis.plot(seq_lens, seq_lens**2 / 1e6, color=COLORS["primary"], label="score entries")
axis.set_title("Quadratic attention score growth")
axis.set_xlabel("Sequence length")
axis.set_ylabel("Score entries (millions)")
axis.legend()
figure.tight_layout()
plt.show()
plt.close(figure)

Demo 10: Causal and padding masks

This demo turns one attention concept into a small checked computation.

Code cell 23

header("Demo 10: Causal and padding masks - FlashAttention invariant")
# Tiling the score matrix covers exactly the same query/key pairs.
ctx_len = 128
dense_entries = ctx_len * ctx_len
tiled_entries = 4 * (ctx_len // 4) * ctx_len
print("Full score entries conceptually:", dense_entries)
print("Tiled computation covers entries:", tiled_entries)
check_true(tiled_entries == dense_entries, "tiling can compute exact same attention pairs")
print("FlashAttention changes memory movement, not the attention formula.")

Demo 11: Softmax normalization

This demo turns one attention concept into a small checked computation.

Code cell 25

header("Demo 11: Softmax normalization - scaled attention")
# Identity queries/keys: each query softly retrieves mostly its own value row.
queries = np.array([[1.0, 0.0], [0.0, 1.0]])
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
values = np.array([[10.0, 0.0], [0.0, 5.0]])
out, wts, _ = attention(queries, keys, values)
print("Weights:", np.round(wts, 4))
print("Output:", np.round(out, 4))
check_true(np.allclose(wts.sum(axis=1), 1.0), "attention rows sum to one")

Demo 12: Weighted value aggregation

This demo turns one attention concept into a small checked computation.

Code cell 27

header("Demo 12: Weighted value aggregation - causal mask")
# With a causal mask, no weight may land above the diagonal (future keys).
seq_len = 4
eye = np.eye(seq_len)
vals = np.arange(seq_len * 2, dtype=float).reshape(seq_len, 2)
_, wts, _ = attention(eye, eye, vals, mask=causal_mask(seq_len))
print("Weights:", np.round(wts, 3))
leaked = wts[np.triu_indices(seq_len, k=1)].sum()
check_close(leaked, 0.0, tol=1e-8, name="future attention mass")

Demo 13: Attention entropy

This demo turns one attention concept into a small checked computation.

Code cell 29

header("Demo 13: Attention entropy - entropy")
# A near-one-hot row should carry less entropy than a uniform row.
peaked = np.array([[0.98, 0.01, 0.01]])
uniform = np.array([[1/3, 1/3, 1/3]])
h_peaked = attention_entropy(peaked)[0]
h_uniform = attention_entropy(uniform)[0]
print("Sharp entropy:", round(float(h_peaked), 4))
print("Diffuse entropy:", round(float(h_uniform), 4))
check_true(h_uniform > h_peaked, "diffuse attention has higher entropy")

Demo 14: Temperature and score scale

This demo turns one attention concept into a small checked computation.

Code cell 31

header("Demo 14: Temperature and score scale - multi-head shapes")
# Splitting into heads and recombining must round-trip the (B, T, D) shape.
hidden = np.zeros((2, 5, 12))
per_head = split_heads(hidden, n_heads=3)
restored = combine_heads(per_head)
print("Split shape:", per_head.shape)
print("Combined shape:", restored.shape)
check_true(restored.shape == hidden.shape, "combine restores model width")

Demo 15: Numerical stability

This demo turns one attention concept into a small checked computation.

Code cell 33

header("Demo 15: Numerical stability - score scaling")
# Dividing by sqrt(d) shrinks the spread of raw dot products.
gen = np.random.default_rng(1)
q_mat = gen.normal(size=(16, 64))
k_mat = gen.normal(size=(16, 64))
unscaled = q_mat @ k_mat.T
tempered = unscaled / np.sqrt(64)
print("Raw std:", round(float(unscaled.std()), 4))
print("Scaled std:", round(float(tempered.std()), 4))
check_true(tempered.std() < unscaled.std(), "scaling controls dot-product variance")

Demo 16: Head dimensions

This demo turns one attention concept into a small checked computation.

Code cell 35

header("Demo 16: Head dimensions - padding mask")
# A -1e9 additive mask entry drives that key's softmax weight to ~0.
logits = np.array([[2.0, 1.0, 5.0]])
pad_mask = np.array([[0.0, 0.0, -1e9]])
probs = stable_softmax(logits + pad_mask, axis=-1)
print("Weights:", np.round(probs, 6))
check_close(probs[0, 2], 0.0, tol=1e-8, name="padded key receives zero weight")

Demo 17: Parallel heads

This demo turns one attention concept into a small checked computation.

Code cell 37

header("Demo 17: Parallel heads - KV cache size")
# Cache holds K and V (factor 2) per layer/head/position at 2 bytes each.
n_layers, n_heads, ctx_len, head_dim = 24, 16, 2048, 64
elem_bytes = 2
total_bytes = n_layers * 2 * n_heads * ctx_len * head_dim * elem_bytes
print("KV cache MiB:", round(total_bytes / 2**20, 2))
check_true(total_bytes > 0, "KV cache memory is positive and grows with T")

Demo 18: Concatenation and output projection

This demo turns one attention concept into a small checked computation.

Code cell 39

header("Demo 18: Concatenation and output projection - ALiBi bias")
penalty = alibi_bias(5, slope=-0.25)
print("Bias:", np.round(penalty, 2))
# Distance 4 at slope -0.25 gives exactly -1.0.
check_close(penalty[4, 0], -1.0, name="linear distance penalty")

Demo 19: Head specialization and redundancy

This demo turns one attention concept into a small checked computation.

Code cell 41

header("Demo 19: Head specialization and redundancy - quadratic cost")
# Doubling the context doubles both score-matrix dimensions: 4x the entries.
base_len, doubled_len = 1024, 2048
cost_ratio = (doubled_len / base_len) ** 2
print("Score matrix cost ratio:", cost_ratio)
check_close(cost_ratio, 4.0, name="doubling context quadruples score matrix")
figure, axis = plt.subplots()
seq_lens = np.array([512, 1024, 2048, 4096])
axis.plot(seq_lens, seq_lens**2 / 1e6, color=COLORS["primary"], label="score entries")
axis.set_title("Quadratic attention score growth")
axis.set_xlabel("Sequence length")
axis.set_ylabel("Score entries (millions)")
axis.legend()
figure.tight_layout()
plt.show()
plt.close(figure)

Demo 20: Grouped query and multi-query attention

This demo turns one attention concept into a small checked computation.

Code cell 43

header("Demo 20: Grouped query and multi-query attention - FlashAttention invariant")
# Tiling the score matrix covers exactly the same query/key pairs.
ctx_len = 128
dense_entries = ctx_len * ctx_len
tiled_entries = 4 * (ctx_len // 4) * ctx_len
print("Full score entries conceptually:", dense_entries)
print("Tiled computation covers entries:", tiled_entries)
check_true(tiled_entries == dense_entries, "tiling can compute exact same attention pairs")
print("FlashAttention changes memory movement, not the attention formula.")

Demo 21: Autoregressive causal attention

This demo turns one attention concept into a small checked computation.

Code cell 45

header("Demo 21: Autoregressive causal attention - scaled attention")
# Identity queries/keys: each query softly retrieves mostly its own value row.
queries = np.array([[1.0, 0.0], [0.0, 1.0]])
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
values = np.array([[10.0, 0.0], [0.0, 5.0]])
out, wts, _ = attention(queries, keys, values)
print("Weights:", np.round(wts, 4))
print("Output:", np.round(out, 4))
check_true(np.allclose(wts.sum(axis=1), 1.0), "attention rows sum to one")

Demo 22: KV cache

This demo turns one attention concept into a small checked computation.

Code cell 47

header("Demo 22: KV cache - causal mask")
# With a causal mask, no weight may land above the diagonal (future keys).
seq_len = 4
eye = np.eye(seq_len)
vals = np.arange(seq_len * 2, dtype=float).reshape(seq_len, 2)
_, wts, _ = attention(eye, eye, vals, mask=causal_mask(seq_len))
print("Weights:", np.round(wts, 3))
leaked = wts[np.triu_indices(seq_len, k=1)].sum()
check_close(leaked, 0.0, tol=1e-8, name="future attention mass")

Demo 23: Prefill versus decode

This demo turns one attention concept into a small checked computation.

Code cell 49

header("Demo 23: Prefill versus decode - entropy")
# A near-one-hot row should carry less entropy than a uniform row.
peaked = np.array([[0.98, 0.01, 0.01]])
uniform = np.array([[1/3, 1/3, 1/3]])
h_peaked = attention_entropy(peaked)[0]
h_uniform = attention_entropy(uniform)[0]
print("Sharp entropy:", round(float(h_peaked), 4))
print("Diffuse entropy:", round(float(h_uniform), 4))
check_true(h_uniform > h_peaked, "diffuse attention has higher entropy")

Demo 24: Attention with positional encodings

This demo turns one attention concept into a small checked computation.

Code cell 51

header("Demo 24: Attention with positional encodings - multi-head shapes")
# Splitting into heads and recombining must round-trip the (B, T, D) shape.
hidden = np.zeros((2, 5, 12))
per_head = split_heads(hidden, n_heads=3)
restored = combine_heads(per_head)
print("Split shape:", per_head.shape)
print("Combined shape:", restored.shape)
check_true(restored.shape == hidden.shape, "combine restores model width")

Demo 25: Cross-attention preview

This demo turns one attention concept into a small checked computation.

Code cell 53

header("Demo 25: Cross-attention preview - score scaling")
# Dividing by sqrt(d) shrinks the spread of raw dot products.
gen = np.random.default_rng(1)
q_mat = gen.normal(size=(16, 64))
k_mat = gen.normal(size=(16, 64))
unscaled = q_mat @ k_mat.T
tempered = unscaled / np.sqrt(64)
print("Raw std:", round(float(unscaled.std()), 4))
print("Scaled std:", round(float(tempered.std()), 4))
check_true(tempered.std() < unscaled.std(), "scaling controls dot-product variance")