Language Model Probability Math: Theory Notebook

Converted from theory.ipynb for web reading.

This notebook is the executable companion to notes.md. It turns the probability formulas behind LLMs into small, inspectable computations: chain-rule scoring, stable softmax, cross-entropy gradients, perplexity, calibration, and decoding.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

1. A tiny autoregressive language model

A language model gives a next-token distribution for every prefix. In a real LLM, the prefix is encoded by a transformer. Here, a dictionary is enough to expose the probability bookkeeping.

Code cell 4

vocab = ["<BOS>", "the", "cat", "sat", "slept", "<EOS>"]

toy_next = {
    ("<BOS>",): {"the": 0.90, "cat": 0.05, "slept": 0.05},
    ("<BOS>", "the"): {"cat": 0.75, "sat": 0.10, "slept": 0.10, "<EOS>": 0.05},
    ("<BOS>", "the", "cat"): {"sat": 0.55, "slept": 0.35, "<EOS>": 0.10},
    ("<BOS>", "the", "cat", "sat"): {"<EOS>": 0.80, "slept": 0.20},
}

def next_prob(prefix, token):
    dist = toy_next.get(tuple(prefix), {"<EOS>": 1.0})
    return dist.get(token, 1e-12)

sentence = ["<BOS>", "the", "cat", "sat", "<EOS>"]
logp = 0.0
for i in range(1, len(sentence)):
    prefix = sentence[:i]
    token = sentence[i]
    prob = next_prob(prefix, token)
    logp += np.log(prob)
    print(f"{prefix} -> {token:>5s}: p={prob:.3f}, cumulative logp={logp:.3f}")

print("sequence probability =", float(np.exp(logp)))

2. Chain rule versus independence

The chain rule is exact. The approximation is not the product rule itself; the approximation is how the model estimates each conditional.

Code cell 6

joint_by_chain = 0.90 * 0.75 * 0.55 * 0.80
bad_independent = 0.40 * 0.30 * 0.10 * 0.20
print("chain-rule probability:", joint_by_chain)
print("independence-style product:", bad_independent)
print("ratio:", joint_by_chain / bad_independent)

3. N-gram smoothing as a probability ancestor

Before neural LMs, count models estimated $P(w_i \mid w_{i-n+1:i-1})$. Additive smoothing makes unseen events nonzero.

Code cell 8

corpus = [
    ["<BOS>", "the", "cat", "sat", "<EOS>"],
    ["<BOS>", "the", "cat", "slept", "<EOS>"],
    ["<BOS>", "the", "cat", "sat", "<EOS>"],
    ["<BOS>", "the", "dog", "sat", "<EOS>"],
]
vocab2 = sorted({tok for sent in corpus for tok in sent})
alpha = 0.5

from collections import Counter, defaultdict
counts = defaultdict(Counter)
for sent in corpus:
    for a, b in zip(sent[:-1], sent[1:]):
        counts[a][b] += 1

def smoothed_bigram(prev, tok):
    return (counts[prev][tok] + alpha) / (sum(counts[prev].values()) + alpha * len(vocab2))

for tok in vocab2:
    print(f"P({tok:>5s} | the) = {smoothed_bigram('the', tok):.3f}")
print("normalization:", sum(smoothed_bigram("the", tok) for tok in vocab2))

4. Stable softmax and log-softmax

Large logits can overflow if exponentiated directly. Subtracting the maximum logit preserves probabilities and stabilizes the computation.

Code cell 10

def softmax(z, tau=1.0):
    z = np.asarray(z, dtype=float) / tau
    z = z - z.max()
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

def log_softmax(z):
    z = np.asarray(z, dtype=float)
    m = z.max()
    lse = m + np.log(np.exp(z - m).sum())
    return z - lse

logits = np.array([1002.0, 1001.0, 999.0, 995.0])
p = softmax(logits)
lp = log_softmax(logits)
print("probabilities:", np.round(p, 6))
print("sum:", p.sum())
print("log probs:", np.round(lp, 6))
print("exp(log probs):", np.round(np.exp(lp), 6))

5. Temperature changes entropy

Temperature rescales logits before softmax. Low temperature sharpens; high temperature flattens.

Code cell 12

temps = [0.4, 0.7, 1.0, 1.5, 2.5]
base_logits = np.array([3.0, 2.0, 1.0, 0.0, -1.0])

fig, ax = plt.subplots(figsize=(9, 5))
for tau in temps:
    probs = softmax(base_logits, tau=tau)
    ax.plot(range(len(base_logits)), probs, marker="o", label=f"tau={tau}")
    entropy = -(probs * np.log(probs + 1e-12)).sum()
    print(f"tau={tau:>3}: entropy={entropy:.3f}, top prob={probs.max():.3f}")
ax.set_title("Temperature reshapes a next-token distribution")
ax.set_xlabel("token index")
ax.set_ylabel("probability")
ax.legend()
fig.tight_layout()
plt.show()
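
One invariant worth noting (checked on the same logits): positive temperatures rescale the spread but never change which token is most probable.

for tau in temps:
    assert softmax(base_logits, tau=tau).argmax() == base_logits.argmax()
print("argmax unchanged across all tested temperatures")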

6. Cross-entropy gradient

For one observed target token, the loss is $-\log p_y$. The derivative with respect to the logits is $p - y$, where $y$ is the one-hot target vector.

Code cell 14

z = np.array([1.2, -0.3, 0.7, 2.1])
target = 3
p = softmax(z)
y = np.eye(len(z))[target]
analytic_grad = p - y

def ce_loss(zvec):
    return -log_softmax(zvec)[target]

eps = 1e-5
numeric_grad = np.zeros_like(z)
for i in range(len(z)):
    step = np.zeros_like(z)
    step[i] = eps
    numeric_grad[i] = (ce_loss(z + step) - ce_loss(z - step)) / (2 * eps)

print("probabilities:", np.round(p, 5))
print("analytic grad:", np.round(analytic_grad, 5))
print("numeric grad: ", np.round(numeric_grad, 5))
print("max error:", np.max(np.abs(analytic_grad - numeric_grad)))

7. Masked token loss

Padded tokens should not contribute to the mean loss. Many training bugs are just masking bugs.

Code cell 16

token_losses = np.array([
    [1.1, 0.9, 0.7, 0.0],
    [1.5, 1.3, 0.0, 0.0],
])
mask = np.array([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
])
masked_mean = (token_losses * mask).sum() / mask.sum()
bad_mean = token_losses.mean()
print("masked mean loss:", masked_mean)
print("unmasked mean loss:", bad_mean)
print("difference:", masked_mean - bad_mean)

8. Entropy, cross-entropy, KL, and perplexity

Cross-entropy decomposes as $H(q, p) = H(q) + \mathrm{KL}(q \,\|\, p)$: irreducible data entropy plus model mismatch.

Code cell 18

q = np.array([0.70, 0.20, 0.10])
p_good = np.array([0.65, 0.25, 0.10])
p_bad = np.array([0.20, 0.30, 0.50])

def entropy(dist):
    return -(dist * np.log(dist + 1e-12)).sum()

def cross_entropy(q, p):
    return -(q * np.log(p + 1e-12)).sum()

def kl(q, p):
    return (q * (np.log(q + 1e-12) - np.log(p + 1e-12))).sum()

for name, model in [("good", p_good), ("bad", p_bad)]:
    ce = cross_entropy(q, model)
    print(name, "CE=", round(ce, 4), "H+KL=", round(entropy(q) + kl(q, model), 4), "PPL=", round(np.exp(ce), 4))

9. Length normalization

Raw log probability usually favors shorter outputs. Average log probability asks how surprising each token is.

Code cell 20

candidates = {
    "short": np.array([-0.20, -0.40]),
    "long_better_per_token": np.array([-0.32, -0.34, -0.35, -0.33, -0.34, -0.32]),
}
for name, logps in candidates.items():
    print(name)
    print("  total logp:", float(logps.sum()))
    print("  mean logp: ", float(logps.mean()))
    print("  perplexity:", float(np.exp(-logps.mean())))

10. Expected calibration error

Calibration compares confidence to empirical correctness. A model can have high accuracy and still be overconfident.

Code cell 22

conf = np.array([0.95, 0.90, 0.85, 0.80, 0.75, 0.70, 0.60, 0.55, 0.45, 0.30])
correct = np.array([1, 1, 0, 1, 0, 1, 0, 0, 0, 0])
bins = np.linspace(0.0, 1.0, 6)

ece = 0.0
for lo, hi in zip(bins[:-1], bins[1:]):
    in_bin = (conf > lo) & (conf <= hi)
    if not in_bin.any():
        continue
    acc = correct[in_bin].mean()
    avg_conf = conf[in_bin].mean()
    weight = in_bin.mean()
    ece += weight * abs(acc - avg_conf)
    print(f"({lo:.1f}, {hi:.1f}] n={in_bin.sum()} acc={acc:.3f} conf={avg_conf:.3f}")
print("ECE:", round(ece, 4))

11. Top-k and top-p filtering

Filtering changes the distribution used for decoding. It does not retrain the model.

Code cell 24

def top_k_filter(probs, k):
    probs = np.asarray(probs, dtype=float)
    keep = np.argsort(probs)[-k:]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

def top_p_filter(probs, p_cut):
    probs = np.asarray(probs, dtype=float)
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    keep_sorted = cumulative <= p_cut
    keep_sorted[0] = True
    if not keep_sorted.all():
        keep_sorted[np.argmax(cumulative > p_cut)] = True
    keep = order[keep_sorted]
    out = np.zeros_like(probs)
    out[keep] = probs[keep]
    return out / out.sum()

probs = softmax(np.array([4.0, 3.0, 2.2, 1.5, 0.5, -0.5]))
print("base: ", np.round(probs, 3))
print("top-k:", np.round(top_k_filter(probs, 3), 3))
print("top-p:", np.round(top_p_filter(probs, 0.82), 3))

12. Beam search as approximate sequence MAP

Beam search keeps several partial hypotheses, trading diversity for higher sequence score.

Code cell 26

transitions = {
    (): {"A": 0.55, "B": 0.45},
    ("A",): {"x": 0.52, "y": 0.48},
    ("B",): {"x": 0.90, "y": 0.10},
    ("A", "x"): {"<EOS>": 1.0},
    ("A", "y"): {"<EOS>": 1.0},
    ("B", "x"): {"<EOS>": 1.0},
    ("B", "y"): {"<EOS>": 1.0},
}

beam = [((), 0.0)]
for step in range(3):
    expanded = []
    for prefix, score in beam:
        if prefix and prefix[-1] == "<EOS>":
            expanded.append((prefix, score))
            continue
        for tok, prob in transitions[prefix].items():
            expanded.append((prefix + (tok,), score + np.log(prob)))
    beam = sorted(expanded, key=lambda x: x[1], reverse=True)[:2]
    print("step", step + 1, beam)

13. Tied embeddings and the LM head

Many LMs use the same matrix for input embeddings and output logits. The probability math is unchanged; the parameterization changes.

Code cell 28

vocab_size, d_model = 7, 4
E = np.random.normal(size=(vocab_size, d_model))
h = np.random.normal(size=(d_model,))
logits_tied = E @ h
probs_tied = softmax(logits_tied)
print("embedding table shape:", E.shape)
print("hidden state shape:", h.shape)
print("logits shape:", logits_tied.shape)
print("probability sum:", probs_tied.sum())

14. Conditional scoring

For prompt-answer scoring, average over answer tokens, not prompt tokens, unless the task explicitly asks for full string likelihood.

Code cell 30

prompt_logps = np.array([-0.1, -0.2, -0.1, -0.3])
answer_a = np.array([-0.4, -0.5, -0.4])
answer_b = np.array([-0.2])

print("full score A:", np.r_[prompt_logps, answer_a].mean())
print("full score B:", np.r_[prompt_logps, answer_b].mean())
print("answer-only score A:", answer_a.mean())
print("answer-only score B:", answer_b.mean())
print("Choose the scoring convention before comparing.")

15. Logit bias as a constraint

Adding a large negative bias can effectively ban a token. Adding a positive bias can encourage a token family.

Code cell 32

labels = ["safe", "maybe", "banned", "other"]
logits = np.array([2.0, 1.5, 1.2, 0.5])
bias = np.array([0.0, 0.0, -20.0, 0.0])

before = softmax(logits)
after = softmax(logits + bias)
for lab, b, a in zip(labels, before, after):
    print(f"{lab:>6s}: before={b:.4f}, after={a:.8f}")

16. Tokenizer-dependent perplexity

Per-token perplexity is not directly comparable across tokenizers with different token counts. Bits per byte or bits per character can be fairer when the text is identical.

Code cell 34

nll_same_text = 12.0
tokens_a = 6
tokens_b = 10
bytes_count = 32

ppl_a = np.exp(nll_same_text / tokens_a)
ppl_b = np.exp(nll_same_text / tokens_b)
bpb = nll_same_text / np.log(2) / bytes_count
print("PPL tokenizer A:", round(ppl_a, 3))
print("PPL tokenizer B:", round(ppl_b, 3))
print("bits per byte:", round(bpb, 3))

17. Final probability implementation checklist

Code cell 36

checks = [
    "softmax sums to one on the vocabulary axis",
    "log-softmax uses log-sum-exp stabilization",
    "targets are shifted by one position",
    "causal mask prevents future leakage",
    "padding and prompt masks are applied before averaging",
    "sequence comparisons state the length convention",
    "perplexity comparisons state tokenizer and dataset",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")