Converted from exercises.ipynb for web reading.

KL Divergence — Exercises

Eight core exercises, plus two short follow-ups, covering KL divergence from first principles to modern AI applications.

Format | Description
Problem | Markdown cell with task description
Your Solution | Code cell with scaffolding
Solution | Code cell with reference solution and checks

Difficulty Levels

Level | Exercises | Focus
★ | 1–3 | Core mechanics and definitions
★★ | 4–6 | Properties, proofs, and distributions
★★★ | 7–8 | AI applications: RLHF and distillation

Topic Map

Topic | Exercise
Definition and coding interpretation | 1
Gibbs' inequality (non-negativity) | 2
Forward vs reverse KL | 3
Chain rule for KL | 4
Gaussian KL + VAE encoder | 5
f-Divergences and Pinsker | 6
RLHF optimal policy | 7
Knowledge distillation | 8

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import numpy.linalg as la

try:
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

from scipy.special import expit  # sigmoid
from scipy.optimize import brentq, minimize

np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)

def header(title):
    print()
    print('=' * len(title))
    print(title)
    print('=' * len(title))

def check_close(name, got, expected, tol=1e-6):
    ok = np.allclose(got, expected, atol=tol, rtol=tol)
    print(f"{'PASS' if ok else 'FAIL'}{name}")
    if not ok:
        print(f'  expected: {expected}')
        print(f'  got:      {got}')
    return ok

def check_true(name, cond):
    print(f"{'PASS' if cond else 'FAIL'}{name}")
    return cond

def kl_divergence(p, q, eps=1e-300):
    """D_KL(p || q) for discrete distributions"""
    p, q = np.asarray(p, float), np.asarray(q, float)
    p = p / p.sum(); q = q / q.sum()
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))

def softmax(z, tau=1.0):
    z = np.asarray(z, float)
    z_shift = (z - z.max()) / tau
    e = np.exp(z_shift)
    return e / e.sum()

print('Setup complete.')
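
A quick sanity check of the helpers above — not part of the exercises, and the values are arbitrary illustrations: kl_divergence of a distribution with itself should be exactly 0, and a very large temperature should flatten the softmax towards uniform.

p_demo = np.array([0.7, 0.2, 0.1])
print('KL(p||p)        =', kl_divergence(p_demo, p_demo))                 # exactly 0
print('KL(p||uniform)  =', round(kl_divergence(p_demo, np.ones(3) / 3), 6))
print('softmax(tau=1)  :', softmax([2.0, 1.0, 0.0], tau=1.0).round(4))
print('softmax(tau=100):', softmax([2.0, 1.0, 0.0], tau=100.0).round(4))  # nearly uniform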

Exercise 1 ★ — KL Divergence: Definition and Coding Interpretation

A language model assigns probabilities to the next token in the sentence "The Eiffel Tower is in ___":

Token | True $p$ | Model $q$
'Paris' | 0.90 | 0.65
'London' | 0.04 | 0.15
'Rome' | 0.03 | 0.10
'Berlin' | 0.02 | 0.05
'other' | 0.01 | 0.05

(a) Compute $D_{\mathrm{KL}}(p\|q)$ and $D_{\mathrm{KL}}(q\|p)$ in nats.

(b) Compute the cross-entropy $H(p, q)$ and verify $H(p,q) = H(p) + D_{\mathrm{KL}}(p\|q)$.

(c) Interpret $D_{\mathrm{KL}}(p\|q)$: how many extra nats per token does the model waste?

(d) Which direction ($p\|q$ or $q\|p$) is larger? Explain why in one sentence.

Code cell 5

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 6

# Solution
# Exercise 1: Solution

p = np.array([0.90, 0.04, 0.03, 0.02, 0.01])
q = np.array([0.65, 0.15, 0.10, 0.05, 0.05])

# (a)
kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)

# (b)
H_p  = -np.sum(p * np.log(p))
H_pq = -np.sum(p * np.log(q))

header('Exercise 1: KL Divergence — Definition and Coding')
print(f'D_KL(p||q) = {kl_pq:.6f} nats  [{kl_pq/np.log(2):.4f} bits]')
print(f'D_KL(q||p) = {kl_qp:.6f} nats  [{kl_qp/np.log(2):.4f} bits]')
print(f'H(p) = {H_p:.6f} nats')
print(f'H(p,q) = {H_pq:.6f} nats')

check_true('D_KL(p||q) >= 0', kl_pq >= 0)
check_true('D_KL(q||p) >= 0', kl_qp >= 0)
check_close('H(p,q) = H(p) + D_KL(p||q)', H_pq, H_p + kl_pq)
check_true('D_KL(p||q) != D_KL(q||p) (asymmetric)', not np.isclose(kl_pq, kl_qp))

print(f'\n(c) Model wastes {kl_pq:.4f} nats = {kl_pq/np.log(2):.4f} bits per prediction')
print(f'(d) D_KL(q||p) = {kl_qp:.4f} > D_KL(p||q) = {kl_pq:.4f}')
print('    q assigns 15% to London but p=4%: reverse KL penalizes this heavily')
print('\nTakeaway: D_KL(p||q) = extra nats of code length from model mismatch; exp(KL) = perplexity ratio')

Exercise 2 ★ — Gibbs' Inequality: Proof and Verification

Prove $D_{\mathrm{KL}}(p\|q) \ge 0$ for the specific distributions $p = (0.6, 0.3, 0.1)$ and $q = (0.2, 0.5, 0.3)$.

(a) Verify numerically: compute both KL directions and confirm both are $\ge 0$.

(b) Verify the log-inequality: for each symbol $x$, compute $\ln(q(x)/p(x))$ and $q(x)/p(x) - 1$. Confirm $\ln(q/p) \le q/p - 1$ for each.

(c) Use (b) to construct the proof: show that $\sum_x p(x)\ln\frac{q(x)}{p(x)} \le \sum_x p(x)\left(\frac{q(x)}{p(x)}-1\right) = 0$.

(d) Under what conditions does equality hold? Verify numerically.

Code cell 8

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 9

# Solution
# Exercise 2: Solution

p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.5, 0.3])

kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)

log_ratio    = np.log(q / p)  # ln(q/p)
linear_m1    = q / p - 1      # q/p - 1

# Proof step: sum p * ln(q/p) <= sum p * (q/p - 1)
lhs = np.sum(p * log_ratio)    # -D_KL(p||q)
rhs = np.sum(p * linear_m1)    # = sum q - sum p = 1 - 1 = 0

header('Exercise 2: Gibbs Inequality')
print(f'D_KL(p||q) = {kl_pq:.6f}  (should be >= 0)')
print(f'D_KL(q||p) = {kl_qp:.6f}  (should be >= 0)')
print()
print('Pointwise: ln(q/p) vs q/p - 1:')
for i, (l, r) in enumerate(zip(log_ratio, linear_m1)):
    print(f'  x={i}: ln(q/p)={l:.4f} <= q/p-1={r:.4f}: {l <= r + 1e-12}')
print()
print(f'sum p*ln(q/p) = {lhs:.6f}  (= -D_KL(p||q) = {-kl_pq:.6f})')
print(f'sum p*(q/p-1) = {rhs:.6f}  (= 0 by normalization)')

check_true('D_KL(p||q) >= 0', kl_pq >= 0)
check_true('D_KL(q||p) >= 0', kl_qp >= 0)
check_true('pointwise: ln(q/p) <= q/p - 1', (log_ratio <= linear_m1 + 1e-12).all())
check_close('sum p*(q/p-1) = 0', rhs, 0.0)

# Equality: D_KL(p||p) = 0
kl_self = kl_divergence(p, p)
check_close('D_KL(p||p) = 0 (equality at p=q)', kl_self, 0.0)

print('\nTakeaway: Gibbs = Jensen on concave log; equality iff p=q everywhere')
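
As an optional extra check of the key step (not part of the original exercise), the inequality $\ln t \le t - 1$ can be verified on a dense grid of $t > 0$; the gap closes only at $t = 1$.

t_grid = np.linspace(1e-3, 10.0, 2001)
gap = (t_grid - 1.0) - np.log(t_grid)     # should be >= 0, with equality only at t = 1
print(f'min gap (t - 1 - ln t) on grid = {gap.min():.6f} at t = {t_grid[gap.argmin()]:.3f}')
check_true('ln(t) <= t - 1 on the grid', bool((gap >= -1e-12).all()))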

Exercise 3 ★ — Forward vs Reverse KL on a Bimodal Target

The true distribution $p$ is a mixture of two Gaussians discretised onto a grid: $p \propto \mathcal{N}(-3, 0.8^2) + \mathcal{N}(3, 0.8^2)$. We fit a unimodal Gaussian $q_{\mu,\sigma}$ by minimising KL in each direction.

(a) Build the discrete grid $x \in [-8, 8]$ (501 points) and construct $p$.

(b) Minimise $D_{\mathrm{KL}}(p \| q)$ (forward KL) over $(\mu, \sigma)$ using scipy.optimize.minimize. Print the optimal $(\mu^*, \sigma^*)$ and call the result $q_{\text{fwd}}$.

(c) Minimise $D_{\mathrm{KL}}(q \| p)$ (reverse KL) in the same way. Print the optimal $(\mu^*, \sigma^*)$ and call the result $q_{\text{rev}}$.

(d) Explain in 1–2 sentences why forward KL is mass-covering and reverse KL is mode-seeking.

(e) (Optional) Plot $p$, $q_{\text{fwd}}$, and $q_{\text{rev}}$ on one figure.

Code cell 11

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 12

# Solution
# Exercise 3: Solution
from scipy.stats import norm as sp_norm
from scipy.optimize import minimize as sp_minimize

x = np.linspace(-8, 8, 501)
dx = x[1] - x[0]

# (a) Bimodal target
p_raw = sp_norm.pdf(x, -3, 0.8) + sp_norm.pdf(x, 3, 0.8)
p = p_raw / (p_raw.sum() * dx)
p_prob = p * dx          # discrete probabilities
p_prob /= p_prob.sum()

def gaussian_q_prob(params):
    mu, log_s = params
    sigma = np.exp(log_s)
    q_raw = sp_norm.pdf(x, mu, sigma)
    q_prob = q_raw * dx
    return q_prob / q_prob.sum()

eps = 1e-300

def fwd_kl(params):
    q = gaussian_q_prob(params)
    mask = p_prob > eps
    return float(np.sum(p_prob[mask] * np.log(p_prob[mask] / np.maximum(q[mask], eps))))

def rev_kl(params):
    q = gaussian_q_prob(params)
    mask = q > eps
    return float(np.sum(q[mask] * np.log(q[mask] / np.maximum(p_prob[mask], eps))))

res_fwd = sp_minimize(fwd_kl, [0.0, np.log(2.0)], method='Nelder-Mead',
                      options={'xatol': 1e-6, 'fatol': 1e-8, 'maxiter': 5000})
res_rev = sp_minimize(rev_kl, [-3.0, np.log(0.8)], method='Nelder-Mead',
                      options={'xatol': 1e-6, 'fatol': 1e-8, 'maxiter': 5000})

mu_fwd, sig_fwd = res_fwd.x[0], np.exp(res_fwd.x[1])
mu_rev, sig_rev = res_rev.x[0], np.exp(res_rev.x[1])

header('Exercise 3: Forward vs Reverse KL')
print(f'Forward KL optimal: mu={mu_fwd:.3f}, sigma={sig_fwd:.3f}  D_KL={res_fwd.fun:.4f}')
print(f'Reverse KL optimal: mu={mu_rev:.3f}, sigma={sig_rev:.3f}  D_KL={res_rev.fun:.4f}')

check_true('Forward KL: |mu| < 0.5 (mean of bimodal -> near 0)',  abs(mu_fwd) < 0.5)
check_true('Forward KL: sigma > 2.5 (mass-covering, wide)',       sig_fwd > 2.5)
check_true('Reverse KL: |mu| > 1.5 (mode-seeking, latches a mode)', abs(mu_rev) > 1.5)
check_true('Reverse KL: sigma < 1.5 (mode-seeking, narrow)',       sig_rev < 1.5)

print('\n(d) Forward KL penalises p(x)>0, q(x)=0 (zero-avoidance),\n'
'    so q must cover both modes -> broad, mean-seeking fit.')
print('    Reverse KL penalises q(x)>0, p(x)=0,\n'
'    so q avoids the inter-modal gap -> narrow, mode-seeking fit.')

print('\nTakeaway: forward KL=MLE objective; reverse KL=variational inference (ELBO)')
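
A minimal plotting sketch for the optional part (e), guarded so it only runs when matplotlib imported successfully in the setup cell:

if HAS_MPL:
    # Convert the fitted discrete probabilities back to density scale for plotting
    q_fwd = gaussian_q_prob(res_fwd.x) / dx
    q_rev = gaussian_q_prob(res_rev.x) / dx
    plt.figure()
    plt.plot(x, p, label='target p (bimodal)')
    plt.plot(x, q_fwd, '--', label=f'forward KL fit (mu={mu_fwd:.2f}, sigma={sig_fwd:.2f})')
    plt.plot(x, q_rev, ':', label=f'reverse KL fit (mu={mu_rev:.2f}, sigma={sig_rev:.2f})')
    plt.xlabel('x'); plt.ylabel('density')
    plt.legend(); plt.title('Forward KL covers both modes; reverse KL latches onto one')
    plt.show()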

Exercise 4 ★★ — Chain Rule Decomposition

Let $(X, Y)$ be a pair of discrete random variables with joint PMF $p(x, y)$. The chain rule for KL states:

$$D_{\mathrm{KL}}(p(x,y) \| q(x,y)) = D_{\mathrm{KL}}(p(x) \| q(x)) + \mathbb{E}_{x \sim p}\bigl[D_{\mathrm{KL}}(p(y|x) \| q(y|x))\bigr]$$

Use the following $2 \times 3$ joint PMFs:

$$p(x,y) = \begin{pmatrix}0.30 & 0.15 & 0.05\\ 0.20 & 0.20 & 0.10\end{pmatrix}, \qquad q(x,y) = \begin{pmatrix}0.20 & 0.20 & 0.10\\ 0.15 & 0.25 & 0.10\end{pmatrix}$$

(a) Compute $D_{\mathrm{KL}}(p_{XY} \| q_{XY})$ directly (flatten the joint).

(b) Compute the marginals $p_X, q_X$ and the conditionals $p_{Y|X=0}$, $p_{Y|X=1}$, etc.

(c) Compute the right-hand side of the chain rule and verify equality.

(d) State why the chain rule implies $D_{\mathrm{KL}}(p_{XY} \| q_{XY}) \ge D_{\mathrm{KL}}(p_X \| q_X)$ (the Data Processing Inequality).

Code cell 14

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 15

# Solution
# Exercise 4: Solution

P = np.array([[0.30, 0.15, 0.05],
              [0.20, 0.20, 0.10]])
Q = np.array([[0.20, 0.20, 0.10],
              [0.15, 0.25, 0.10]])

# (a)
kl_joint = kl_divergence(P.ravel(), Q.ravel())

# (b) Marginals
p_x = P.sum(axis=1)      # shape (2,)
q_x = Q.sum(axis=1)

# Conditionals p(y|x) = p(x,y) / p(x)
p_y_given_x = P / p_x[:, None]   # shape (2, 3)
q_y_given_x = Q / q_x[:, None]

# (c) Chain rule
kl_marg = kl_divergence(p_x, q_x)
kl_cond_avg = sum(
    p_x[i] * kl_divergence(p_y_given_x[i], q_y_given_x[i])
    for i in range(2)
)
chain_rule_rhs = kl_marg + kl_cond_avg

header('Exercise 4: Chain Rule for KL')
print(f'D_KL(p_XY || q_XY) direct  = {kl_joint:.8f}')
print(f'D_KL(p_X  || q_X)  marginal = {kl_marg:.8f}')
print(f'E_p[D_KL(p_Y|X || q_Y|X)]  = {kl_cond_avg:.8f}')
print(f'Chain rule RHS              = {chain_rule_rhs:.8f}')

check_close('Chain rule: joint = marginal + expected conditional', kl_joint, chain_rule_rhs)
check_true('DPI: D_KL(joint) >= D_KL(marginal)', kl_joint >= kl_marg - 1e-10)

print('\n(d) E_p[D_KL(p_Y|X || q_Y|X)] >= 0 (Gibbs), so')
print('    D_KL(p_XY || q_XY) = D_KL(p_X || q_X) + non-negative term >= D_KL(p_X || q_X)')
print('    Marginalisation (a deterministic function of (X,Y)) can only lose information.')
print('\nTakeaway: Chain rule => DPI; conditioning adds non-negative KL contributions')

Exercise 5 ★★ — Gaussian KL and the VAE Encoder

For two univariate Gaussians $p = \mathcal{N}(\mu_1, \sigma_1^2)$ and $q = \mathcal{N}(\mu_2, \sigma_2^2)$, the closed-form KL is:

$$D_{\mathrm{KL}}(p \| q) = \ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

In a VAE, the encoder outputs $(\mu_\phi, \sigma_\phi)$ and the KL penalty is $D_{\mathrm{KL}}(\mathcal{N}(\mu_\phi, \sigma_\phi^2) \| \mathcal{N}(0, 1))$.

(a) Implement kl_gaussian(mu1, sigma1, mu2, sigma2) using the closed-form formula.

(b) Verify with $p = \mathcal{N}(2, 1.5^2)$, $q = \mathcal{N}(0, 1)$: compute analytically and cross-check with a Monte Carlo estimate using $10^6$ samples.

(c) Implement the VAE KL penalty kl_vae(mu_phi, sigma_phi) $= D_{\mathrm{KL}}(\mathcal{N}(\mu_\phi, \sigma_\phi^2) \| \mathcal{N}(0,1))$. Verify it simplifies to $\tfrac{1}{2}(\mu_\phi^2 + \sigma_\phi^2 - \ln\sigma_\phi^2 - 1)$.

(d) Compute the gradient $\partial/\partial\mu_\phi$ of the VAE KL. Show it equals $\mu_\phi$ and verify numerically for $\mu_\phi = 2$, $\sigma_\phi = 1.5$.

Code cell 17

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 18

# Solution
# Exercise 5: Solution
from scipy.stats import norm as sp_norm

# (a)
def kl_gaussian(mu1, sigma1, mu2, sigma2):
    return (np.log(sigma2 / sigma1)
            + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
            - 0.5)

# (b)
mu1, s1 = 2.0, 1.5
mu2, s2 = 0.0, 1.0
kl_analytic = kl_gaussian(mu1, s1, mu2, s2)

samples = np.random.normal(mu1, s1, 1_000_000)
log_p = sp_norm.logpdf(samples, mu1, s1)
log_q = sp_norm.logpdf(samples, mu2, s2)
kl_mc = float(np.mean(log_p - log_q))

# (c) VAE KL = D_KL(N(mu, sigma^2) || N(0,1))
#   = 0.5 * (mu^2 + sigma^2 - ln(sigma^2) - 1)
def kl_vae(mu_phi, sigma_phi):
    return 0.5 * (mu_phi**2 + sigma_phi**2 - np.log(sigma_phi**2) - 1.0)

kl_vae_val = kl_vae(mu1, s1)

# (d) Gradient wrt mu_phi: d/d(mu) [0.5*(mu^2 + ...)] = mu
eps_fd = 1e-5
grad_mu_fd = (kl_vae(mu1 + eps_fd, s1) - kl_vae(mu1 - eps_fd, s1)) / (2 * eps_fd)
grad_mu_analytic = mu1  # = 2.0

header('Exercise 5: Gaussian KL and VAE Encoder')
print(f'D_KL(N({mu1},{s1}^2)||N(0,1)) analytic = {kl_analytic:.6f}')
print(f'D_KL Monte Carlo (1M samples)            = {kl_mc:.6f}')
print(f'VAE formula 0.5*(mu^2+sigma^2-ln(sigma^2)-1) = {kl_vae_val:.6f}')

check_close('Analytic == MC (within 1e-2)', kl_analytic, kl_mc, tol=0.02)
check_close('Analytic == VAE formula', kl_analytic, kl_vae_val)
check_close('Gradient d/d(mu) = mu_phi', grad_mu_fd, grad_mu_analytic)

print(f'\nGradient d KL/d mu_phi: finite-diff={grad_mu_fd:.6f}, analytic={grad_mu_analytic:.6f}')
print('\nTakeaway: VAE encoder trained to minimise recon + 0.5*(mu^2+sigma^2-ln(sigma^2)-1)')
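
In practice the VAE penalty is computed per latent dimension and summed over a diagonal-Gaussian posterior. A small sketch with illustrative (assumed) encoder outputs, reusing kl_vae from above:

mu_vec    = np.array([0.5, -1.2, 0.0, 2.0])   # assumed encoder means (illustrative)
sigma_vec = np.array([1.0, 0.7, 1.3, 1.5])    # assumed encoder standard deviations
kl_per_dim = 0.5 * (mu_vec**2 + sigma_vec**2 - np.log(sigma_vec**2) - 1.0)
print('KL per latent dimension:', kl_per_dim.round(4))
print('Total KL penalty       :', round(float(kl_per_dim.sum()), 4))
check_close('last dimension matches kl_vae(2.0, 1.5)', kl_per_dim[-1], kl_vae(2.0, 1.5))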

Exercise 6 ★★ — f-Divergences and Pinsker's Inequality

The Csiszár f-divergence is $D_f(p \| q) = \sum_x q(x)\,f\!\left(\frac{p(x)}{q(x)}\right)$ for a convex $f$ with $f(1) = 0$.

Divergence | $f(t)$
KL | $t \ln t$
Reverse KL | $-\ln t$
Jensen–Shannon | $-\frac{t+1}{2}\ln\frac{t+1}{2} + \frac{t}{2}\ln t$
Total Variation | $\frac{1}{2}\lvert t-1\rvert$
$\chi^2$ | $(t-1)^2$

Pinsker's inequality: $\mathrm{TV}(p,q) \le \sqrt{\tfrac{1}{2}D_{\mathrm{KL}}(p\|q)}$.

Use $p = (0.5, 0.3, 0.2)$, $q = (0.2, 0.5, 0.3)$.

(a) Implement js_divergence(p, q) and chi2_divergence(p, q). Verify $0 \le \mathrm{JS} \le \ln 2$.

(b) Compute the Total Variation distance $\mathrm{TV}(p,q) = \tfrac{1}{2}\|p-q\|_1$.

(c) Verify Pinsker's inequality: $\mathrm{TV}^2 \le \tfrac{1}{2}D_{\mathrm{KL}}(p\|q)$.

(d) Compute $\mathrm{JS}$ and verify the JS analogue of Pinsker, $\mathrm{TV}(p,q) \le \sqrt{2\,\mathrm{JS}(p,q)}$, which follows by applying Pinsker to both halves of JS (note $\mathrm{TV}(p,m) = \tfrac{1}{2}\mathrm{TV}(p,q)$ for $m = \tfrac{1}{2}(p+q)$). Is this bound tighter than the standard Pinsker bound for these $p, q$?

Code cell 20

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 21

# Solution
# Exercise 6: Solution

p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])

# (a)
def js_divergence(p, q):
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def chi2_divergence(p, q):
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum((p - q)**2 / np.maximum(q, 1e-300)))

js = js_divergence(p, q)
chi2 = chi2_divergence(p, q)

# (b)
tv = 0.5 * np.sum(np.abs(p - q))

# (c) Pinsker
kl_pq = kl_divergence(p, q)
pinsker_bound = np.sqrt(0.5 * kl_pq)

# (d) JS analogue of Pinsker: TV <= sqrt(2*JS)
js_pinsker_bound = np.sqrt(2 * js)

header('Exercise 6: f-Divergences and Pinsker\'s Inequality')
print(f'D_KL(p||q)         = {kl_pq:.6f} nats')
print(f'JS(p,q)            = {js:.6f} nats  (in [0, ln2={np.log(2):.4f}])')
print(f'chi^2(p||q)        = {chi2:.6f}')
print(f'TV(p,q)            = {tv:.6f}')
print(f'Pinsker bound sqrt(KL/2) = {pinsker_bound:.6f}')
print(f'JS bound sqrt(2*JS)      = {js_pinsker_bound:.6f}')

check_true('JS >= 0', js >= 0)
check_true('JS <= ln(2)', js <= np.log(2) + 1e-10)
check_true('Pinsker: TV <= sqrt(KL/2)', tv <= pinsker_bound + 1e-10)
check_true('JS Pinsker: TV <= sqrt(2*JS)', tv <= js_pinsker_bound + 1e-10)
check_true('JS-based bound is tighter here (smaller bound)', js_pinsker_bound <= pinsker_bound + 1e-10)

print('\nTakeaway: JS symmetric, bounded in [0,ln2], used in GANs; Pinsker links TV to KL')
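
As an optional cross-check (not required by the exercise), the generic Csiszár formula $D_f(p\|q) = \sum_x q(x)\,f(p(x)/q(x))$ should reproduce the KL, $\chi^2$, and TV values computed above when given the corresponding $f$:

def f_divergence(p, q, f):
    """Generic Csiszar f-divergence for strictly positive discrete p, q."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(q * f(p / q)))

check_close('f(t)=t ln t    reproduces D_KL(p||q)', f_divergence(p, q, lambda t: t * np.log(t)), kl_pq)
check_close('f(t)=(t-1)^2   reproduces chi^2(p||q)', f_divergence(p, q, lambda t: (t - 1.0)**2), chi2)
check_close('f(t)=0.5|t-1|  reproduces TV(p,q)', f_divergence(p, q, lambda t: 0.5 * np.abs(t - 1.0)), tv)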

Exercise 7 ★★★ — RLHF Optimal Policy and DPO

In RLHF the KL-regularised objective is:

$$\max_{\pi}\; \mathbb{E}_{y \sim \pi}[r(x,y)] - \beta\, D_{\mathrm{KL}}(\pi \| \pi_{\mathrm{ref}})$$

The analytic solution is the Gibbs / softmax policy: $\pi^*(y|x) = \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y|x)\exp\!\left(\frac{r(x,y)}{\beta}\right)$.

DPO eliminates the need to train a separate reward model: it shows that the implicit reward is $r(x,y) = \beta\ln\frac{\pi_\theta(y|x)}{\pi_{\mathrm{ref}}(y|x)} + \beta\ln Z$.

Use a toy example with 5 candidate responses: r = [3.0, 1.5, 0.5, -0.5, -2.0], log_pi_ref = [-1.0, -1.5, -2.0, -2.5, -3.0], $\beta = 1.0$.

(a) Compute $\pi^*$ in log-space (numerically stable). Print the optimal policy.

(b) Compute $D_{\mathrm{KL}}(\pi^* \| \pi_{\mathrm{ref}})$ and the expected reward under $\pi^*$.

(c) Verify that $\pi^*$ satisfies the first-order condition: $\ln \pi^*(y) - \ln \pi_{\mathrm{ref}}(y) = r(y)/\beta - \ln Z$.

(d) Compute the DPO implicit reward $\hat{r}(y) = \beta\ln(\pi^*(y)/\pi_{\mathrm{ref}}(y))$. Show it ranks responses in the same order as $r(y)$ (it differs from $r(y)$ only by the constant $\beta\ln Z$).

Code cell 23

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 24

# Solution
# Exercise 7: Solution

r       = np.array([3.0, 1.5, 0.5, -0.5, -2.0])
log_ref = np.array([-1.0, -1.5, -2.0, -2.5, -3.0])
beta    = 1.0

# (a) Optimal policy — numerically stable log-space computation
log_pi_unnorm = log_ref + r / beta
log_Z = np.log(np.sum(np.exp(log_pi_unnorm - log_pi_unnorm.max()))) + log_pi_unnorm.max()
log_pi_star = log_pi_unnorm - log_Z
pi_star = np.exp(log_pi_star)
pi_ref  = np.exp(log_ref - np.log(np.exp(log_ref).sum()))

# (b)
kl_star_ref = kl_divergence(pi_star, pi_ref)
E_reward    = float(np.sum(pi_star * r))
E_reward_ref= float(np.sum(pi_ref * r))

# (c) First-order condition
lhs = log_pi_star - log_ref            # should equal r/beta - log_Z
rhs = r / beta - log_Z

# (d) DPO implicit reward
r_dpo = beta * (log_pi_star - log_ref)  # = r - beta*log_Z (same ranking)

header('Exercise 7: RLHF Optimal Policy and DPO')
print('pi_ref (normalised):', pi_ref.round(4))
print('pi*:                ', pi_star.round(4))
print(f'log Z = {log_Z:.4f}')
print(f'E[r] under pi_ref = {E_reward_ref:.4f}')
print(f'E[r] under pi*    = {E_reward:.4f}  (improved by {E_reward - E_reward_ref:.4f})')
print(f'KL(pi*||pi_ref)   = {kl_star_ref:.4f}')
print(f'RLHF objective    = {E_reward - beta*kl_star_ref:.4f}')

check_close('pi* sums to 1', pi_star.sum(), 1.0)
check_close('First-order condition: log pi* - log ref = r/beta - logZ', lhs, rhs)
check_true('DPO reward ranks same as r', np.all(np.argsort(r_dpo)[::-1] == np.argsort(r)[::-1]))
check_true('pi* concentrates on high-reward responses', pi_star[0] > pi_ref[0])

print('\nDPO implicit reward (beta * log ratio):', r_dpo.round(4))
print('True reward r:                          ', r.round(4))
print('\nTakeaway: RLHF optimal policy = reference * exp(r/beta); DPO uses this analytically')
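
Building on the implicit reward above, a minimal sketch of the DPO pairwise loss $-\ln\sigma\bigl(\beta(\Delta_w - \Delta_l)\bigr)$ with $\Delta = \ln(\pi_\theta/\pi_{\mathrm{ref}})$; the preference pair used here (response 0 preferred over response 3) is purely illustrative, and expit is the sigmoid imported in the setup cell.

w, l = 0, 3                                   # assumed preference pair: y_w = response 0, y_l = response 3
delta_w = log_pi_star[w] - log_ref[w]         # log-ratio of the preferred response
delta_l = log_pi_star[l] - log_ref[l]         # log-ratio of the rejected response
dpo_loss = -np.log(expit(beta * (delta_w - delta_l)))
print(f'DPO margin beta*(dw - dl) = {beta * (delta_w - delta_l):.4f}')
print(f'DPO loss for this pair    = {dpo_loss:.4f}')
check_true('positive margin for the preferred response', delta_w > delta_l)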

Exercise 8 ★★★ — Knowledge Distillation

In knowledge distillation a small student network $q_\theta$ is trained to match a large teacher $p_T$ by minimising a KL-based loss.

Two common objectives:

Objective | Formula | Behaviour
Forward KL | $D_{\mathrm{KL}}(p_T \| q_\theta)$ | Mean-seeking; preserves soft labels
Reverse KL | $D_{\mathrm{KL}}(q_\theta \| p_T)$ | Mode-seeking; may ignore dark knowledge

A teacher produces logits over 6 classes. A student is parameterised by its own 6 logits. We study how the temperature $\tau$ affects the distillation KL.

Teacher logits: z_T = [4.0, 2.0, 1.0, 0.5, 0.0, -1.0]

(a) Compute $p_T(\tau)$ = softmax(z_T, tau) for $\tau \in \{0.5, 1, 2, 5\}$. Print the entropy $H(p_T(\tau))$ for each.

(b) For each $\tau$, find the student logits z_S that minimise $D_{\mathrm{KL}}(p_T(\tau) \| q_S)$ when $q_S$ = softmax(z_S, 1). (Hint: the minimum is attained when $q_S = p_T(\tau)$; since $\log p_T(\tau) = z_T/\tau + \text{const}$, this means $z_S = z_T/\tau + c$ for any constant $c$. Verify this analytically.)

(c) Compute the KL loss at the optimal $z_S$ for each $\tau$ — it should equal 0.

(d) Now fix a suboptimal student with logits z_S_bad = [3.5, 1.8, 0.8, 0.4, 0.0, -0.6], evaluated at temperature 1. Compute $D_{\mathrm{KL}}(p_T(\tau) \| q_{S,\mathrm{bad}})$ for each $\tau$. Explain why higher $\tau$ gives a larger or smaller KL.

(e) Which direction of KL should a practitioner use for distillation and why?

Code cell 26

# Your Solution
print("Write your solution here, then compare with the reference solution below.")

Code cell 27

# Solution
# Exercise 8: Solution

z_T = np.array([4.0, 2.0, 1.0, 0.5, 0.0, -1.0])
taus = [0.5, 1.0, 2.0, 5.0]

header('Exercise 8: Knowledge Distillation')

# (a) Teacher distributions and entropy
print('(a) Teacher softmax distributions:')
print("{:>6} | {:>8} | {:>10} | {}".format("tau", "H(p_T)", "max prob", "distribution"))
print('-' * 60)
teacher_dists = {}
for tau in taus:
    p_tau = softmax(z_T, tau)
    teacher_dists[tau] = p_tau
    H_tau = -np.sum(p_tau * np.log(p_tau + 1e-300))
    print(f"{tau:>6.1f} | {H_tau:>8.4f} | {p_tau.max():>10.4f} | {p_tau.round(3)}")

# (b) Optimal student: z_S = z_T / tau satisfies q_S = softmax(z_S, 1) = softmax(z_T, tau)
print('\n(b-c) Optimal student: z_S = z_T / tau')
for tau in taus:
    p_tau = teacher_dists[tau]
    z_opt = z_T / tau  # log p_tau = z_T/tau + const
    q_opt = softmax(z_opt, 1.0)
    kl_opt = kl_divergence(p_tau, q_opt)
    print(f'  tau={tau}: KL(p_T||q_opt) = {kl_opt:.2e}  (should be ~0)')

check_true('Optimal student achieves KL~0 for all temps',
           all(kl_divergence(softmax(z_T, t), softmax(z_T / t, 1.0)) < 1e-10 for t in taus))

# (d) Suboptimal student
z_bad = np.array([3.5, 1.8, 0.8, 0.4, 0.0, -0.6])
q_bad = softmax(z_bad, 1.0)

print('\n(d) KL loss with suboptimal student:')
kl_vals = []
for tau in taus:
    p_tau = teacher_dists[tau]
    kl_bad = kl_divergence(p_tau, q_bad)
    kl_vals.append(kl_bad)
    print(f'  tau={tau}: KL(p_T(tau)||q_bad) = {kl_bad:.6f}')

# Higher tau flattens the teacher towards uniform, while the student (fixed at temperature 1)
# stays peaked, so the mismatch grows: the KL is smallest near tau=1 and increases for tau > 1.
kl_grows_above_one = kl_vals[1] < kl_vals[2] < kl_vals[3]
check_true('For tau >= 1, higher tau => larger KL against the fixed temperature-1 student', kl_grows_above_one)

print('\n(e) Recommendation: use FORWARD KL = D_KL(p_T || q_student).')
print('    Forward KL forces q to cover all modes where p_T > 0,')
print('    preserving dark knowledge (non-argmax probabilities).')
print('    Reverse KL is mode-seeking: student ignores secondary modes.')
print('\nTakeaway: temperature controls how soft the teacher targets are; forward KL preserves dark knowledge')
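
One practical convention from the distillation literature, not covered by the exercise: when both teacher and student are softened at the same $\tau$, the soft-target KL term is usually scaled by $\tau^2$ so its gradients stay comparable to the hard-label cross-entropy term. A small sketch with an assumed hard label and mixing weight:

tau_demo   = 2.0
hard_label = 0                                  # assumed ground-truth class (illustrative)
alpha      = 0.5                                # assumed mixing weight between soft and hard terms
p_soft = softmax(z_T, tau_demo)                 # teacher softened at tau
q_soft = softmax(z_bad, tau_demo)               # student softened at the same tau
q_hard = softmax(z_bad, 1.0)                    # student at temperature 1 for the hard-label term
soft_loss = tau_demo**2 * kl_divergence(p_soft, q_soft)
hard_loss = -np.log(q_hard[hard_label])
total_loss = alpha * soft_loss + (1.0 - alpha) * hard_loss
print(f'tau^2-scaled soft KL = {soft_loss:.4f}, hard CE = {hard_loss:.4f}, combined = {total_loss:.4f}')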

Exercise 9: Temperature Distillation KL

Compare teacher and student distributions under a distillation temperature and compute the KL penalty used to train soft targets.

Code cell 29

# Your Solution
print("Compute KL between softened teacher and student probabilities.")

Code cell 30

# Solution
header("Exercise 9: Distillation KL")
teacher_logits = np.array([4.0, 1.5, -1.0])
student_logits = np.array([2.5, 1.0, 0.0])
T = 2.0
p_t = softmax(teacher_logits, tau=T)
p_s = softmax(student_logits, tau=T)
kl_ts = kl_divergence(p_t, p_s)
print("teacher:", np.round(p_t, 4))
print("student:", np.round(p_s, 4))
print("KL teacher||student:", round(kl_ts, 6))
check_true("KL is nonnegative", kl_ts >= -1e-12)
print("Takeaway: distillation transfers dark knowledge through the full softened distribution.")

Exercise 10: KL-Regularized Policy Update

Given a reference policy and rewards, compute the closed-form KL-regularized optimal policy $\pi^*(a) \propto \pi_{\mathrm{ref}}(a)\, e^{r(a)/\beta}$.

Code cell 32

# Your Solution
print("Compute the KL-regularized policy update.")

Code cell 33

# Solution
header("Exercise 10: KL-Regularized Policy")
pi_ref = np.array([0.5, 0.3, 0.2])
r = np.array([0.0, 1.0, 1.8])
beta = 0.7
unnorm = pi_ref * np.exp(r / beta)
pi_star = unnorm / unnorm.sum()
print("pi_ref:", np.round(pi_ref, 4))
print("pi_star:", np.round(pi_star, 4))
check_close("policy normalizes", pi_star.sum(), 1.0)
check_true("rewarded action gains mass", pi_star[-1] > pi_ref[-1])
print("Takeaway: RLHF-style KL control tilts a reference model toward high reward without abandoning it.")

Summary

Exercise | Topic | Key Result
1 ★ | KL definition | $H(p,q) = H(p) + D_{\mathrm{KL}}(p\|q)$; model wastes KL nats
2 ★ | Gibbs' inequality | $\ln t \le t-1 \Rightarrow D_{\mathrm{KL}} \ge 0$; equality iff $p=q$
3 ★ | Forward vs reverse KL | Forward = mass-covering; reverse = mode-seeking
4 ★★ | Chain rule | $D_{\mathrm{KL}}(p_{XY}\|q_{XY}) = D_{\mathrm{KL}}(p_X\|q_X) + \mathbb{E}_x[D_{\mathrm{KL}}(p_{Y|x}\|q_{Y|x})]$
5 ★★ | Gaussian KL + VAE | $\tfrac{1}{2}(\mu^2+\sigma^2-\ln\sigma^2-1)$; gradient $= \mu$
6 ★★ | f-Divergences + Pinsker | $\mathrm{TV}\le\sqrt{\tfrac{1}{2}D_{\mathrm{KL}}}$; JS-based bound tighter here
7 ★★★ | RLHF + DPO | $\pi^*\propto\pi_{\mathrm{ref}}\exp(r/\beta)$; DPO implicit reward
8 ★★★ | Knowledge distillation | Forward KL preserves dark knowledge; higher $\tau$ softens the teacher

Next section: Mutual Information →