Theory Notebook
Converted from theory.ipynb for web reading.
Attention Mechanism Math
This notebook is the executable companion to notes.md. It checks scaled dot-product attention, masks, entropy, multi-head shapes, KV-cache cost, ALiBi bias, and efficient-attention intuition.
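For reference, the quantity every demo below exercises is standard scaled dot-product attention:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the key dimension. The attention() helper in Code cell 3 implements this formula, with an optional additive mask on the scores.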
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
COLORS = {
    "primary": "#0077BB",
    "secondary": "#EE7733",
    "tertiary": "#009988",
    "error": "#CC3311",
    "neutral": "#555555",
    "highlight": "#EE3377",
}

def header(title):
    print("\n" + "=" * 72)
    print(title)
    print("=" * 72)

def check_true(condition, name):
    ok = bool(condition)
    print(f"{'PASS' if ok else 'FAIL'} - {name}")
    assert ok, name

def check_close(value, target, tol=1e-8, name="value"):
    value = float(value)
    target = float(target)
    ok = abs(value - target) <= tol
    print(f"{'PASS' if ok else 'FAIL'} - {name}: got {value:.6f}, expected {target:.6f}")
    assert ok, name

def stable_softmax(scores, axis=-1):
    scores = np.asarray(scores, dtype=float)
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def attention(Q, K, V, mask=None):
    dk = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(dk)
    if mask is not None:
        scores = scores + mask
    weights = stable_softmax(scores, axis=-1)
    return weights @ V, weights, scores

def causal_mask(T):
    mask = np.zeros((T, T))
    mask[np.triu_indices(T, k=1)] = -1e9
    return mask

def attention_entropy(weights):
    w = np.clip(weights, 1e-12, 1.0)
    return -(w * np.log(w)).sum(axis=-1)

def split_heads(X, n_heads):
    B, T, D = X.shape
    assert D % n_heads == 0
    return X.reshape(B, T, n_heads, D // n_heads).transpose(0, 2, 1, 3)

def combine_heads(H):
    B, heads, T, Dh = H.shape
    return H.transpose(0, 2, 1, 3).reshape(B, T, heads * Dh)

def alibi_bias(T, slope=-0.5):
    i = np.arange(T)[:, None]
    j = np.arange(T)[None, :]
    return slope * np.maximum(i - j, 0)

print("Attention helpers ready.")
Demo 1: Attention as soft retrieval
Identity queries and keys turn attention into an explicit soft lookup: each query retrieves a convex mix of value rows, and the weights in each row sum to one.
Code cell 5
header("Demo 1: Attention as soft retrieval - scaled attention")
Q = np.array([[1.0, 0.0], [0.0, 1.0]])
K = np.array([[1.0, 0.0], [0.0, 1.0]])
V = np.array([[10.0, 0.0], [0.0, 5.0]])
Y, W, S = attention(Q, K, V)
print("Weights:", np.round(W, 4))
print("Output:", np.round(Y, 4))
check_true(np.allclose(W.sum(axis=1), 1.0), "attention rows sum to one")
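As a quick extra check (not part of the original cell): sharpening the scores shows soft retrieval approaching a hard lookup. The scaling factor of 50 is an illustrative choice.

Y_sharp, W_sharp, _ = attention(50.0 * Q, K, V)
print("Sharpened weights:", np.round(W_sharp, 6))
check_true(np.allclose(Y_sharp, V, atol=1e-3), "sharp attention approaches exact row lookup")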
Demo 2: Why masks are needed
A causal mask adds -1e9 above the diagonal so softmax assigns zero weight to future positions; the check confirms no attention mass lands there.
Code cell 7
header("Demo 2: Why masks are needed - causal mask")
T = 4
Q = K = np.eye(T)
V = np.arange(T * 2, dtype=float).reshape(T, 2)
Y, W, S = attention(Q, K, V, mask=causal_mask(T))
print("Weights:", np.round(W, 3))
future_mass = W[np.triu_indices(T, k=1)].sum()
check_close(future_mass, 0.0, tol=1e-8, name="future attention mass")
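A follow-up sketch, reusing the Demo 2 variables: each causally masked row equals plain attention over only the prefix of keys, which is the property incremental decoding relies on.

for t in range(T):
    y_t, _, _ = attention(Q[t:t+1], K[:t+1], V[:t+1])
    check_true(np.allclose(y_t, Y[t]), f"masked row {t} matches prefix-only attention")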
Demo 3: Attention entropy
Entropy quantifies how spread out an attention row is: a near-one-hot row scores close to zero, a uniform row scores log(n).
Code cell 9
header("Demo 3: Attention entropy - sharp versus diffuse weights")
sharp = np.array([[0.98, 0.01, 0.01]])
diffuse = np.array([[1/3, 1/3, 1/3]])
Hs = attention_entropy(sharp)[0]
Hd = attention_entropy(diffuse)[0]
print("Sharp entropy:", round(float(Hs), 4))
print("Diffuse entropy:", round(float(Hd), 4))
check_true(Hd > Hs, "diffuse attention has higher entropy")
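A follow-up sketch: dividing scores by a temperature before softmax moves entropy the same way, tying entropy back to score scale. The logit values here are illustrative.

logits = np.array([[2.0, 1.0, 0.0]])
H_cold = attention_entropy(stable_softmax(logits / 0.5))[0]
H_hot = attention_entropy(stable_softmax(logits / 4.0))[0]
print("Entropy at temperature 0.5:", round(float(H_cold), 4))
print("Entropy at temperature 4.0:", round(float(H_hot), 4))
check_true(H_hot > H_cold, "higher temperature raises attention entropy")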
Demo 4: Parallel heads
split_heads and combine_heads must round-trip the (batch, time, width) shape exactly, so heads partition the model width without losing dimensions.
Code cell 11
header("Demo 4: Parallel heads - multi-head shapes")
X = np.zeros((2, 5, 12))
H = split_heads(X, n_heads=3)
X_back = combine_heads(H)
print("Split shape:", H.shape)
print("Combined shape:", X_back.shape)
check_true(X_back.shape == X.shape, "combine restores model width")
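A minimal multi-head pass as a follow-up sketch: run attention per head on random data, then recombine. The random tensors stand in for learned Q, K, V projections, which real layers would apply first.

rng = np.random.default_rng(0)
Xr = rng.normal(size=(2, 5, 12))
Hh = split_heads(Xr, n_heads=3)
per_head = np.stack([
    np.stack([attention(Hh[b, h], Hh[b, h], Hh[b, h])[0] for h in range(3)])
    for b in range(2)
])
check_true(combine_heads(per_head).shape == Xr.shape, "per-head attention preserves model width")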
Demo 5: Why scaling is needed
Raw dot products of 64-dimensional unit-variance vectors have standard deviation near sqrt(64) = 8; dividing by sqrt(d_k) brings scores back to unit scale so softmax stays in a useful range.
Code cell 13
header("Demo 5: Why scaling is needed - score scaling")
rng = np.random.default_rng(1)
Q = rng.normal(size=(16, 64))
K = rng.normal(size=(16, 64))
raw = Q @ K.T
scaled = raw / np.sqrt(64)
print("Raw std:", round(float(raw.std()), 4))
print("Scaled std:", round(float(scaled.std()), 4))
check_true(scaled.std() < raw.std(), "scaling controls dot-product variance")
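The worked numbers behind that check, assuming i.i.d. unit-variance entries: Var(q . k) = d_k, so the raw std should sit near sqrt(64) = 8 and the scaled std near 1. The tolerances below are generous to absorb sampling noise.

check_true(abs(raw.std() - np.sqrt(64)) < 2.0, "raw score std is near sqrt(d_k)")
check_true(abs(scaled.std() - 1.0) < 0.25, "scaled score std is near one")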
Demo 6: Causal and padding masks
Adding -1e9 to a padded key's score drives its softmax weight to zero, so padding never contributes to the output.
Code cell 15
header("Demo 6: Causal and padding masks - padding mask")
scores = np.array([[2.0, 1.0, 5.0]])
mask = np.array([[0.0, 0.0, -1e9]])
w = stable_softmax(scores + mask, axis=-1)
print("Weights:", np.round(w, 6))
check_close(w[0, 2], 0.0, tol=1e-8, name="padded key receives zero weight")
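A follow-up sketch: causal and padding masks compose by addition, so one attention call can apply both. The small setup here is illustrative, with key 2 playing the padded slot.

Tm = 3
pad = np.tile(np.array([0.0, 0.0, -1e9]), (Tm, 1))
Vm = np.arange(Tm * 2, dtype=float).reshape(Tm, 2)
_, Wm, _ = attention(np.eye(Tm), np.eye(Tm), Vm, mask=causal_mask(Tm) + pad)
check_close(Wm[:, 2].sum(), 0.0, tol=1e-8, name="padded key masked in every row")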
Demo 7: KV cache
During decoding, every layer stores one K and one V vector per head and position, so cache memory grows linearly with context length.
Code cell 17
header("Demo 7: KV cache - cache size")
layers, heads, T, dh = 24, 16, 2048, 64
bytes_per = 2
kv_bytes = layers * 2 * heads * T * dh * bytes_per
print("KV cache MiB:", round(kv_bytes / 2**20, 2))
check_true(kv_bytes > 0, "KV cache memory is positive and grows with T")
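A minimal decode-loop sketch (the single-head setup and sizes are illustrative assumptions, not from the notebook): each generated token appends one K row and one V row, and the next query attends over the whole cache.

rng = np.random.default_rng(7)
dh_demo = 4
K_cache = np.zeros((0, dh_demo))
V_cache = np.zeros((0, dh_demo))
for step in range(3):
    K_cache = np.vstack([K_cache, rng.normal(size=(1, dh_demo))])
    V_cache = np.vstack([V_cache, rng.normal(size=(1, dh_demo))])
    y_step, w_step, _ = attention(rng.normal(size=(1, dh_demo)), K_cache, V_cache)
    print(f"step {step}: cache length {len(K_cache)}, weights {np.round(w_step, 3)}")
check_true(len(K_cache) == 3, "cache grows by one row per decoded token")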
Demo 8: Attention with positional encodings
ALiBi adds a bias proportional to query-key distance, linearly penalizing attention to far-away positions.
Code cell 19
header("Demo 8: Attention with positional encodings - ALiBi bias")
bias = alibi_bias(5, slope=-0.25)
print("Bias:", np.round(bias, 2))
check_close(bias[4, 0], -1.0, name="linear distance penalty")
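A follow-up sketch: fed into softmax as an additive bias, ALiBi makes a query with otherwise equal scores prefer nearby keys.

Tb = 5
w_alibi = stable_softmax(np.ones((Tb, Tb)) + alibi_bias(Tb, slope=-0.25), axis=-1)
check_true(np.all(np.diff(w_alibi[-1]) > 0), "last query's weight rises toward recent keys")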
Demo 9: Quadratic attention cost
The score matrix holds T^2 entries, so doubling the context length quadruples the work and memory needed to build it.
Code cell 21
header("Demo 9: Quadratic attention cost - score matrix growth")
T1, T2 = 1024, 2048
ratio = (T2 / T1) ** 2
print("Score matrix cost ratio:", ratio)
check_close(ratio, 4.0, name="doubling context quadruples score matrix")
fig, ax = plt.subplots()
Ts = np.array([512, 1024, 2048, 4096])
ax.plot(Ts, Ts**2 / 1e6, color=COLORS["primary"], label="score entries")
ax.set_title("Quadratic attention score growth")
ax.set_xlabel("Sequence length")
ax.set_ylabel("Score entries (millions)")
ax.legend()
fig.tight_layout()
plt.show()
plt.close(fig)
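One efficient-attention comparison as a sketch (the window size is an illustrative assumption): limiting each query to a fixed window of keys makes the score count grow linearly in T instead of quadratically.

window = 256
full_entries = Ts.astype(float) ** 2
windowed_entries = Ts.astype(float) * window
print("Full vs windowed entries at T=4096:", full_entries[-1], windowed_entries[-1])
check_true(windowed_entries[-1] < full_entries[-1], "windowed attention grows linearly in T")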
Demo 10: FlashAttention invariant
Tiling covers exactly the same query-key pairs as materializing the full score matrix; FlashAttention changes memory movement, not the result.
Code cell 23
header("Demo 10: FlashAttention invariant - tiled exact attention")
T = 128
full_scores = T * T
tile_scores = 4 * (T // 4) * T
print("Full score entries conceptually:", full_scores)
print("Tiled computation covers entries:", tile_scores)
check_true(tile_scores == full_scores, "tiling can compute exact same attention pairs")
print("FlashAttention changes memory movement, not the attention formula.")