Transformer Architecture: Theory Notebook
Converted from theory.ipynb for web reading.
This notebook makes the transformer architecture concrete: attention matrices, masks, heads, MLPs, normalization, positional signals, KV cache memory, and diagnostics.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Scaled dot-product attention
Code cell 4
def softmax(x, axis=-1):
    x = np.asarray(x, dtype=float)
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)
Q = np.array([[1.0, 0.0], [0.0, 1.0]])
K = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
V = np.array([[10.0, 0.0], [5.0, 5.0], [0.0, 10.0]])
scores = Q @ K.T / np.sqrt(Q.shape[-1])
A = softmax(scores, axis=-1)
O = A @ V
print("scores:\n", np.round(scores, 3))
print("weights:\n", np.round(A, 3))
print("output:\n", np.round(O, 3))
2. Causal mask
Code cell 6
T = 5
scores = np.arange(T*T, dtype=float).reshape(T, T)
mask = np.triu(np.ones((T, T), dtype=bool), k=1)
masked_scores = scores.copy()
masked_scores[mask] = -1e9
print("causal mask:\n", mask.astype(int))
print("masked row 2:", masked_scores[2])
3. Head splitting shapes
Code cell 8
B, T, d_model, heads = 2, 4, 12, 3
d_head = d_model // heads
H = np.zeros((B, T, d_model))
split = H.reshape(B, T, heads, d_head).transpose(0, 2, 1, 3)
merged = split.transpose(0, 2, 1, 3).reshape(B, T, d_model)
print("split shape:", split.shape)
print("merged shape:", merged.shape)
4. Attention parameter count
Code cell 10
d_model = 768
qkv_o_params = 4 * d_model * d_model
print("Q,K,V,O projection params without bias:", qkv_o_params)
5. MLP parameter count
Code cell 12
d_model = 768
d_ff = 4 * d_model
mlp_params = 2 * d_model * d_ff
attn_params = 4 * d_model * d_model
print("MLP params:", mlp_params)
print("attention projection params:", attn_params)
print("MLP/attention ratio:", mlp_params / attn_params)
6. LayerNorm and RMSNorm
Code cell 14
x = np.array([1.0, 2.0, 4.0, 8.0])
eps = 1e-5
layer_norm = (x - x.mean()) / np.sqrt(x.var() + eps)
rms_norm = x / np.sqrt(np.mean(x**2) + eps)
print("LayerNorm:", np.round(layer_norm, 4))
print("RMSNorm:", np.round(rms_norm, 4))
7. Pre-LN versus Post-LN equations
Code cell 16
x = np.array([1.0, -0.5, 0.25])
F = lambda z: 0.2 * z
ln = lambda z: (z - z.mean()) / np.sqrt(z.var() + 1e-5)
pre_ln = x + F(ln(x))
post_ln = ln(x + F(x))
print("Pre-LN output:", np.round(pre_ln, 4))
print("Post-LN output:", np.round(post_ln, 4))
8. Sinusoidal positional encoding
Code cell 18
T, d = 32, 16
pos = np.arange(T)[:, None]
i = np.arange(d)[None, :]
angles = pos / (10000 ** (2 * (i // 2) / d))
P = np.zeros((T, d))
P[:, 0::2] = np.sin(angles[:, 0::2])
P[:, 1::2] = np.cos(angles[:, 1::2])
plt.imshow(P, aspect="auto", cmap="coolwarm")
plt.title("Sinusoidal positional encoding")
plt.xlabel("dimension")
plt.ylabel("position")
plt.colorbar()
plt.tight_layout()
plt.show()
print("P shape:", P.shape)
9. Attention entropy
Code cell 20
def softmax(x, axis=-1):
    x = np.asarray(x, dtype=float)
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)
rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 6))
A = softmax(scores, axis=-1)
entropy = -np.sum(A * np.log(A + 1e-12), axis=-1)
print("attention entropy per query:", np.round(entropy, 3))
print("max entropy for 6 keys:", np.log(6))
10. KV cache memory
Code cell 22
B, L, T, H_kv, d_h, bytes_per = 4, 32, 4096, 8, 128, 2
gb = 2 * B * L * T * H_kv * d_h * bytes_per / 1e9
print("KV cache GB:", gb)
11. Weight tying
Code cell 24
vocab, d_model = 50000, 768
untied = 2 * vocab * d_model
tied = vocab * d_model
print("untied embedding+head params:", untied)
print("tied params:", tied)
print("savings:", untied - tied)
12. Mask leakage test
Code cell 26
causal = np.tril(np.ones((5, 5), dtype=int))
future_visible = np.any(np.triu(causal, k=1) != 0)
print("future visible:", future_visible)
assert not future_visible
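The checklist in section 14 asks for padding masks to be tested separately; the same style of test works, checking that every query puts zero weight on padded keys (toy example with the last two positions padded):
pad = np.array([False, False, False, True, True])   # last two tokens are padding
pad_scores = np.zeros((5, 5))
pad_scores[:, pad] = -1e9
A_pad = softmax(pad_scores, axis=-1)
print("max weight on padded keys:", float(A_pad[:, pad].max()))
assert A_pad[:, pad].max() < 1e-6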
13. Residual update ratio
Code cell 28
rng = np.random.default_rng(3)
x = rng.normal(size=(10, 32))
update = 0.1 * rng.normal(size=(10, 32))
ratio = np.linalg.norm(update) / np.linalg.norm(x)
print("update/residual norm ratio:", ratio)
14. Transformer checklist
Code cell 30
checks = [
    "Q,K,V and score tensor shapes are written down",
    "padding and causal masks are tested separately",
    "head dimension divides model dimension",
    "MLP and attention parameter counts are understood",
    "normalization placement matches the intended architecture",
    "KV cache memory is computed for serving settings",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")