Exercises Notebook
Converted from exercises.ipynb for web reading.
KL Divergence — Exercises
Eight graded exercises, plus two short bonus drills, covering KL divergence from first principles to modern AI applications.
| Format | Description |
|---|---|
| Problem | Markdown cell with task description |
| Your Solution | Code cell with scaffolding |
| Solution | Code cell with reference solution and checks |
Difficulty Levels
| Level | Exercises | Focus |
|---|---|---|
| ★ | 1–3 | Core mechanics and definitions |
| ★★ | 4–6 | Properties, proofs, and distributions |
| ★★★ | 7–8 | AI applications: RLHF and distillation |
Topic Map
| Topic | Exercise |
|---|---|
| Definition and coding interpretation | 1 |
| Gibbs' inequality (non-negativity) | 2 |
| Forward vs reverse KL | 3 |
| Chain rule for KL | 4 |
| Gaussian KL + VAE encoder | 5 |
| f-Divergences and Pinsker | 6 |
| RLHF optimal policy | 7 |
| Knowledge distillation | 8 |
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
import seaborn as sns
sns.set_theme(style="whitegrid", palette="colorblind")
HAS_SNS = True
except ImportError:
plt.style.use("seaborn-v0_8-whitegrid")
HAS_SNS = False
mpl.rcParams.update({
"figure.figsize": (10, 6),
"figure.dpi": 120,
"font.size": 13,
"axes.titlesize": 15,
"axes.labelsize": 13,
"xtick.labelsize": 11,
"ytick.labelsize": 11,
"legend.fontsize": 11,
"legend.framealpha": 0.85,
"lines.linewidth": 2.0,
"axes.spines.top": False,
"axes.spines.right": False,
"savefig.bbox": "tight",
"savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
import numpy as np
import numpy.linalg as la
try:
import matplotlib.pyplot as plt
HAS_MPL = True
except ImportError:
HAS_MPL = False
from scipy.special import expit # sigmoid
from scipy.optimize import brentq, minimize
np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)
def header(title):
print()
print('=' * len(title))
print(title)
print('=' * len(title))
def check_close(name, got, expected, tol=1e-6):
ok = np.allclose(got, expected, atol=tol, rtol=tol)
print(f"{'PASS' if ok else 'FAIL'} — {name}")
if not ok:
print(f' expected: {expected}')
print(f' got: {got}')
return ok
def check_true(name, cond):
print(f"{'PASS' if cond else 'FAIL'} — {name}")
return cond
def kl_divergence(p, q, eps=1e-300):
"""D_KL(p || q) for discrete distributions"""
p, q = np.asarray(p, float), np.asarray(q, float)
p = p / p.sum(); q = q / q.sum()
mask = p > 0
return float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], eps))))
def softmax(z, tau=1.0):
z = np.asarray(z, float)
z_shift = (z - z.max()) / tau
e = np.exp(z_shift)
return e / e.sum()
print('Setup complete.')
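A quick sanity check of the helpers defined above (a minimal sketch, assuming this setup cell has been run): KL of a distribution against itself should be exactly zero, and softmax at a very high temperature should be close to uniform.
# Sanity check for kl_divergence and softmax (run after the setup cell above)
p_demo = np.array([0.7, 0.2, 0.1])
print('KL(p||p) =', kl_divergence(p_demo, p_demo))                                   # expect 0.0
print('softmax([2,1,0], tau=100) =', softmax([2.0, 1.0, 0.0], tau=100.0).round(4))   # near-uniform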
Exercise 1 ★ — KL Divergence: Definition and Coding Interpretation
A language model assigns probabilities to the next token in the sentence "The Eiffel Tower is in ___":
| Token | True | Model |
|---|---|---|
| 'Paris' | 0.90 | 0.65 |
| 'London' | 0.04 | 0.15 |
| 'Rome' | 0.03 | 0.10 |
| 'Berlin' | 0.02 | 0.05 |
| 'other' | 0.01 | 0.05 |
(a) Compute $D_{\mathrm{KL}}(p \,\|\, q)$ and $D_{\mathrm{KL}}(q \,\|\, p)$ in nats.
(b) Compute the cross-entropy $H(p, q) = -\sum_x p(x) \ln q(x)$ and verify $H(p, q) = H(p) + D_{\mathrm{KL}}(p \,\|\, q)$.
(c) Interpret $D_{\mathrm{KL}}(p \,\|\, q)$: how many extra nats per token does the model waste?
(d) Which direction ($D_{\mathrm{KL}}(p \,\|\, q)$ or $D_{\mathrm{KL}}(q \,\|\, p)$) is larger? Explain why in one sentence.
Code cell 5
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 6
# Solution
# Exercise 1: Solution
p = np.array([0.90, 0.04, 0.03, 0.02, 0.01])
q = np.array([0.65, 0.15, 0.10, 0.05, 0.05])
# (a)
kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)
# (b)
H_p = -np.sum(p * np.log(p))
H_pq = -np.sum(p * np.log(q))
header('Exercise 1: KL Divergence — Definition and Coding')
print(f'D_KL(p||q) = {kl_pq:.6f} nats [{kl_pq/np.log(2):.4f} bits]')
print(f'D_KL(q||p) = {kl_qp:.6f} nats [{kl_qp/np.log(2):.4f} bits]')
print(f'H(p) = {H_p:.6f} nats')
print(f'H(p,q) = {H_pq:.6f} nats')
check_true('D_KL(p||q) >= 0', kl_pq >= 0)
check_true('D_KL(q||p) >= 0', kl_qp >= 0)
check_close('H(p,q) = H(p) + D_KL(p||q)', H_pq, H_p + kl_pq)
check_true('D_KL(p||q) != D_KL(q||p) (asymmetric)', not np.isclose(kl_pq, kl_qp))
print(f'\n(c) Model wastes {kl_pq:.4f} nats = {kl_pq/np.log(2):.4f} bits per prediction')
print(f'(d) D_KL(q||p) = {kl_qp:.4f} > D_KL(p||q) = {kl_pq:.4f}')
print(' q assigns 15% to London but p=4%: reverse KL penalizes this heavily')
print('\nTakeaway: D_KL(p||q) = extra bits from model mismatch; KL = perplexity gap')
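The identity verified above is worth writing out once, since it is the reason KL measures the "extra nats" on top of the entropy:
$$H(p, q) = -\sum_x p(x) \ln q(x) = -\sum_x p(x) \ln p(x) + \sum_x p(x) \ln\frac{p(x)}{q(x)} = H(p) + D_{\mathrm{KL}}(p \,\|\, q)$$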
Exercise 2 ★ — Gibbs' Inequality: Proof and Verification
Prove Gibbs' inequality $D_{\mathrm{KL}}(p \,\|\, q) \ge 0$ for the specific distributions $p = [0.6, 0.3, 0.1]$ and $q = [0.2, 0.5, 0.3]$.
(a) Verify numerically: compute both KL directions and confirm both are $\ge 0$.
(b) Verify the log-inequality $\ln t \le t - 1$: for each symbol $x$, compute $\ln\frac{q(x)}{p(x)}$ and $\frac{q(x)}{p(x)} - 1$. Confirm $\ln\frac{q(x)}{p(x)} \le \frac{q(x)}{p(x)} - 1$ for each.
(c) Use (b) to construct the proof: show that $-D_{\mathrm{KL}}(p \,\|\, q) = \sum_x p(x) \ln\frac{q(x)}{p(x)} \le \sum_x p(x)\left(\frac{q(x)}{p(x)} - 1\right) = \sum_x q(x) - \sum_x p(x) = 0$.
(d) Under what conditions does equality hold? Verify numerically.
Code cell 8
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 9
# Solution
# Exercise 2: Solution
p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.5, 0.3])
kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)
log_ratio = np.log(q / p) # ln(q/p)
linear_m1 = q / p - 1 # q/p - 1
# Proof step: sum p * ln(q/p) <= sum p * (q/p - 1)
lhs = np.sum(p * log_ratio) # -D_KL(p||q)
rhs = np.sum(p * linear_m1) # = sum q - sum p = 1 - 1 = 0
header('Exercise 2: Gibbs Inequality')
print(f'D_KL(p||q) = {kl_pq:.6f} (should be >= 0)')
print(f'D_KL(q||p) = {kl_qp:.6f} (should be >= 0)')
print()
print('Pointwise: ln(q/p) vs q/p - 1:')
for i, (l, r) in enumerate(zip(log_ratio, linear_m1)):
print(f' x={i}: ln(q/p)={l:.4f} <= q/p-1={r:.4f}: {l <= r + 1e-12}')
print()
print(f'sum p*ln(q/p) = {lhs:.6f} (= -D_KL(p||q) = {-kl_pq:.6f})')
print(f'sum p*(q/p-1) = {rhs:.6f} (= 0 by normalization)')
check_true('D_KL(p||q) >= 0', kl_pq >= 0)
check_true('D_KL(q||p) >= 0', kl_qp >= 0)
check_true('pointwise: ln(q/p) <= q/p - 1', (log_ratio <= linear_m1 + 1e-12).all())
check_close('sum p*(q/p-1) = 0', rhs, 0.0)
# Equality: D_KL(p||p) = 0
kl_self = kl_divergence(p, p)
check_close('D_KL(p||p) = 0 (equality at p=q)', kl_self, 0.0)
print('\nTakeaway: Gibbs = Jensen on concave log; equality iff p=q everywhere')
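As a small follow-up sketch (not part of the original exercise), near the equality case $q = p$ the KL is locally quadratic: for a perturbation $q = p + \varepsilon$ with $\sum_x \varepsilon(x) = 0$, $D_{\mathrm{KL}}(p \,\|\, q) \approx \frac{1}{2}\sum_x \varepsilon(x)^2 / p(x)$, which previews the $\chi^2$ divergence of Exercise 6. This assumes kl_divergence from the setup cell is available.
# Local quadratic behaviour of KL near p = q (uses kl_divergence from the setup cell)
p = np.array([0.6, 0.3, 0.1])
eps_vec = np.array([0.01, -0.004, -0.006])   # sums to zero, so q_near is still a distribution
q_near = p + eps_vec
print(f'KL(p||q_near)      = {kl_divergence(p, q_near):.8f}')
print(f'0.5*sum(eps^2 / p) = {0.5 * np.sum(eps_vec**2 / p):.8f}   (second-order approximation)')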
Exercise 3 ★ — Forward vs Reverse KL on a Bimodal Target
The true distribution is a mixture of two Gaussians discretised onto a grid: $p(x) \propto \mathcal{N}(x;\, -3,\, 0.8^2) + \mathcal{N}(x;\, 3,\, 0.8^2)$. We fit a unimodal Gaussian $q = \mathcal{N}(\mu, \sigma^2)$ by minimising KL in each direction.
(a) Build the discrete grid $x \in [-8, 8]$ (501 points) and construct $p$ as a normalised probability vector.
(b) Minimise $D_{\mathrm{KL}}(p \,\|\, q)$ (forward KL) over $(\mu, \log\sigma)$ using scipy.optimize.minimize.
Print the optimal $(\mu, \sigma)$ and call the result $q_{\mathrm{fwd}}$.
(c) Minimise $D_{\mathrm{KL}}(q \,\|\, p)$ (reverse KL) in the same way. Print the optimal $(\mu, \sigma)$ and call the result $q_{\mathrm{rev}}$.
(d) Explain in 1–2 sentences why forward KL is mass-covering and reverse KL is mode-seeking.
(e) (Optional) Plot $p$, $q_{\mathrm{fwd}}$, $q_{\mathrm{rev}}$ on one figure.
Code cell 11
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 12
# Solution
# Exercise 3: Solution
from scipy.stats import norm as sp_norm
from scipy.optimize import minimize as sp_minimize
x = np.linspace(-8, 8, 501)
dx = x[1] - x[0]
# (a) Bimodal target
p_raw = sp_norm.pdf(x, -3, 0.8) + sp_norm.pdf(x, 3, 0.8)
p = p_raw / (p_raw.sum() * dx)
p_prob = p * dx # discrete probabilities
p_prob /= p_prob.sum()
def gaussian_q_prob(params):
mu, log_s = params
sigma = np.exp(log_s)
q_raw = sp_norm.pdf(x, mu, sigma)
q_prob = q_raw * dx
return q_prob / q_prob.sum()
eps = 1e-300
def fwd_kl(params):
q = gaussian_q_prob(params)
mask = p_prob > eps
return float(np.sum(p_prob[mask] * np.log(p_prob[mask] / np.maximum(q[mask], eps))))
def rev_kl(params):
q = gaussian_q_prob(params)
mask = q > eps
return float(np.sum(q[mask] * np.log(q[mask] / np.maximum(p_prob[mask], eps))))
res_fwd = sp_minimize(fwd_kl, [0.0, np.log(2.0)], method='Nelder-Mead',
options={'xatol': 1e-6, 'fatol': 1e-8, 'maxiter': 5000})
res_rev = sp_minimize(rev_kl, [-3.0, np.log(0.8)], method='Nelder-Mead',
options={'xatol': 1e-6, 'fatol': 1e-8, 'maxiter': 5000})
mu_fwd, sig_fwd = res_fwd.x[0], np.exp(res_fwd.x[1])
mu_rev, sig_rev = res_rev.x[0], np.exp(res_rev.x[1])
header('Exercise 3: Forward vs Reverse KL')
print(f'Forward KL optimal: mu={mu_fwd:.3f}, sigma={sig_fwd:.3f} D_KL={res_fwd.fun:.4f}')
print(f'Reverse KL optimal: mu={mu_rev:.3f}, sigma={sig_rev:.3f} D_KL={res_rev.fun:.4f}')
check_true('Forward KL: |mu| < 0.5 (mean of bimodal -> near 0)', abs(mu_fwd) < 0.5)
check_true('Forward KL: sigma > 2.5 (mass-covering, wide)', sig_fwd > 2.5)
check_true('Reverse KL: |mu| > 1.5 (mode-seeking, latches a mode)', abs(mu_rev) > 1.5)
check_true('Reverse KL: sigma < 1.5 (mode-seeking, narrow)', sig_rev < 1.5)
print('\n(d) Forward KL penalises p(x)>0, q(x)=0 (zero-avoidance),\n'
' so q must cover both modes -> broad, mean-seeking fit.')
print(' Reverse KL penalises q(x)>0, p(x)=0,\n'
' so q avoids the inter-modal gap -> narrow, mode-seeking fit.')
print('\nTakeaway: forward KL=MLE objective; reverse KL=variational inference (ELBO)')
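As a cross-check on (b) (a small sketch reusing x, p_prob, mu_fwd and sig_fwd from the solution above): minimising forward KL over a Gaussian family is equivalent to matching the mean and variance of the target, so the fitted parameters should agree with the moments of the discretised bimodal $p$.
# Moment-matching cross-check for the forward-KL fit (run after the solution cell above)
mean_p = float(np.sum(x * p_prob))
std_p = float(np.sqrt(np.sum((x - mean_p)**2 * p_prob)))
print(f'target moments:  mean = {mean_p:.4f}, std = {std_p:.4f}')
print(f'forward-KL fit:  mu   = {mu_fwd:.4f}, sigma = {sig_fwd:.4f}')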
Exercise 4 ★★ — Chain Rule Decomposition
Let $(X, Y)$ be a pair of discrete random variables with joint PMFs $p_{XY}$ and $q_{XY}$. The chain rule for KL states:
$$D_{\mathrm{KL}}(p_{XY} \,\|\, q_{XY}) = D_{\mathrm{KL}}(p_X \,\|\, q_X) + \mathbb{E}_{x \sim p_X}\!\left[ D_{\mathrm{KL}}\big(p_{Y \mid X = x} \,\|\, q_{Y \mid X = x}\big) \right]$$
Use the following joint PMFs (rows index $X$, columns index $Y$):
$$p_{XY} = \begin{pmatrix} 0.30 & 0.15 & 0.05 \\ 0.20 & 0.20 & 0.10 \end{pmatrix}, \qquad q_{XY} = \begin{pmatrix} 0.20 & 0.20 & 0.10 \\ 0.15 & 0.25 & 0.10 \end{pmatrix}$$
(a) Compute $D_{\mathrm{KL}}(p_{XY} \,\|\, q_{XY})$ directly (flatten the joint).
(b) Compute the marginals $p_X$, $q_X$ and the conditionals $p_{Y \mid X}$, $q_{Y \mid X}$.
(c) Compute the right-hand side of the chain rule and verify equality.
(d) State why the chain rule implies $D_{\mathrm{KL}}(p_{XY} \,\|\, q_{XY}) \ge D_{\mathrm{KL}}(p_X \,\|\, q_X)$ (the Data Processing Inequality).
Code cell 14
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 15
# Solution
# Exercise 4: Solution
P = np.array([[0.30, 0.15, 0.05],
[0.20, 0.20, 0.10]])
Q = np.array([[0.20, 0.20, 0.10],
[0.15, 0.25, 0.10]])
# (a)
kl_joint = kl_divergence(P.ravel(), Q.ravel())
# (b) Marginals
p_x = P.sum(axis=1) # shape (2,)
q_x = Q.sum(axis=1)
# Conditionals p(y|x) = p(x,y) / p(x)
p_y_given_x = P / p_x[:, None] # shape (2, 3)
q_y_given_x = Q / q_x[:, None]
# (c) Chain rule
kl_marg = kl_divergence(p_x, q_x)
kl_cond_avg = sum(
p_x[i] * kl_divergence(p_y_given_x[i], q_y_given_x[i])
for i in range(2)
)
chain_rule_rhs = kl_marg + kl_cond_avg
header('Exercise 4: Chain Rule for KL')
print(f'D_KL(p_XY || q_XY) direct = {kl_joint:.8f}')
print(f'D_KL(p_X || q_X) marginal = {kl_marg:.8f}')
print(f'E_p[D_KL(p_Y|X || q_Y|X)] = {kl_cond_avg:.8f}')
print(f'Chain rule RHS = {chain_rule_rhs:.8f}')
check_close('Chain rule: joint = marginal + expected conditional', kl_joint, chain_rule_rhs)
check_true('DPI: D_KL(joint) >= D_KL(marginal)', kl_joint >= kl_marg - 1e-10)
print('\n(d) E_p[D_KL(p_Y|X || q_Y|X)] >= 0 (Gibbs), so')
print(' D_KL(p_XY || q_XY) = D_KL(p_X || q_X) + non-negative term >= D_KL(p_X || q_X)')
print(' Marginalisation (a deterministic function of (X,Y)) can only lose information.')
print('\nTakeaway: Chain rule => DPI; conditioning adds non-negative KL contributions')
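The decomposition works in either order. A small sketch reusing P, Q and kl_divergence from above: conditioning on $Y$ instead of $X$ gives a different split of the same joint KL.
# Chain rule in the other order: joint = marginal over Y + expected conditional KL of X given Y
p_y = P.sum(axis=0)                       # shape (3,)
q_y = Q.sum(axis=0)
p_x_given_y = P / p_y[None, :]            # column j holds p(x | y=j)
q_x_given_y = Q / q_y[None, :]
kl_y = kl_divergence(p_y, q_y)
kl_x_given_y = sum(p_y[j] * kl_divergence(p_x_given_y[:, j], q_x_given_y[:, j]) for j in range(3))
print(f'marginal-over-Y split = {kl_y + kl_x_given_y:.8f}  (should equal the joint KL above)')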
Exercise 5 ★★ — Gaussian KL and the VAE Encoder
For two univariate Gaussians $p = \mathcal{N}(\mu_1, \sigma_1^2)$ and $q = \mathcal{N}(\mu_2, \sigma_2^2)$, the closed-form KL is:
$$D_{\mathrm{KL}}(p \,\|\, q) = \ln\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
In a VAE, the encoder outputs $q_\phi(z \mid x) = \mathcal{N}(\mu_\phi, \sigma_\phi^2)$ and the KL penalty is $D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, 1)\big)$.
(a) Implement kl_gaussian(mu1, sigma1, mu2, sigma2) using the closed-form formula.
(b) Verify with $p = \mathcal{N}(2.0, 1.5^2)$, $q = \mathcal{N}(0, 1)$: compute the KL analytically and cross-check with a Monte Carlo estimate $\mathbb{E}_{x \sim p}[\ln p(x) - \ln q(x)]$ using $10^6$ samples.
(c) Implement the VAE KL penalty kl_vae(mu_phi, sigma_phi) = $D_{\mathrm{KL}}\big(\mathcal{N}(\mu_\phi, \sigma_\phi^2) \,\|\, \mathcal{N}(0, 1)\big)$.
Verify it simplifies to $\frac{1}{2}\left(\mu_\phi^2 + \sigma_\phi^2 - \ln\sigma_\phi^2 - 1\right)$.
(d) Compute the gradient of the VAE KL with respect to $\mu_\phi$. Show it equals $\mu_\phi$ and verify numerically for $\mu_\phi = 2.0$, $\sigma_\phi = 1.5$.
Code cell 17
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 18
# Solution
# Exercise 5: Solution
from scipy.stats import norm as sp_norm
# (a)
def kl_gaussian(mu1, sigma1, mu2, sigma2):
return (np.log(sigma2 / sigma1)
+ (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
- 0.5)
# (b)
mu1, s1 = 2.0, 1.5
mu2, s2 = 0.0, 1.0
kl_analytic = kl_gaussian(mu1, s1, mu2, s2)
samples = np.random.normal(mu1, s1, 1_000_000)
log_p = sp_norm.logpdf(samples, mu1, s1)
log_q = sp_norm.logpdf(samples, mu2, s2)
kl_mc = float(np.mean(log_p - log_q))
# (c) VAE KL = D_KL(N(mu, sigma^2) || N(0,1))
# = 0.5 * (mu^2 + sigma^2 - ln(sigma^2) - 1)
def kl_vae(mu_phi, sigma_phi):
return 0.5 * (mu_phi**2 + sigma_phi**2 - np.log(sigma_phi**2) - 1.0)
kl_vae_val = kl_vae(mu1, s1)
# (d) Gradient wrt mu_phi: d/d(mu) [0.5*(mu^2 + ...)] = mu
eps_fd = 1e-5
grad_mu_fd = (kl_vae(mu1 + eps_fd, s1) - kl_vae(mu1 - eps_fd, s1)) / (2 * eps_fd)
grad_mu_analytic = mu1 # = 2.0
header('Exercise 5: Gaussian KL and VAE Encoder')
print(f'D_KL(N({mu1},{s1}^2)||N(0,1)) analytic = {kl_analytic:.6f}')
print(f'D_KL Monte Carlo (1M samples) = {kl_mc:.6f}')
print(f'VAE formula 0.5*(mu^2+sigma^2-ln(sigma^2)-1) = {kl_vae_val:.6f}')
check_close('Analytic == MC (within 2e-2)', kl_analytic, kl_mc, tol=0.02)
check_close('Analytic == VAE formula', kl_analytic, kl_vae_val)
check_close('Gradient d/d(mu) = mu_phi', grad_mu_fd, grad_mu_analytic)
print(f'\nGradient d KL/d mu_phi: finite-diff={grad_mu_fd:.6f}, analytic={grad_mu_analytic:.6f}')
print('\nTakeaway: VAE encoder trained to minimise recon + 0.5*(mu^2+sigma^2-ln(sigma^2)-1)')
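Setting $\mu_2 = 0$ and $\sigma_2 = 1$ in the closed form makes the simplification in (c) explicit:
$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu_\phi, \sigma_\phi^2) \,\|\, \mathcal{N}(0, 1)\big) = \ln\frac{1}{\sigma_\phi} + \frac{\sigma_\phi^2 + \mu_\phi^2}{2} - \frac{1}{2} = \frac{1}{2}\left(\mu_\phi^2 + \sigma_\phi^2 - \ln\sigma_\phi^2 - 1\right)$$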
Exercise 6 ★★ — f-Divergences and Pinsker's Inequality
The Csiszár f-divergence is $D_f(p \,\|\, q) = \sum_x q(x)\, f\!\left(\frac{p(x)}{q(x)}\right)$ for a convex $f$ with $f(1) = 0$.
| Divergence | $f(t)$ |
|---|---|
| KL | $t \ln t$ |
| Reverse KL | $-\ln t$ |
| Jensen–Shannon | $\frac{t}{2}\ln\frac{2t}{1+t} + \frac{1}{2}\ln\frac{2}{1+t}$ |
| Total Variation | $\frac{1}{2}\lvert t - 1 \rvert$ |
Pinsker's inequality: $\mathrm{TV}(p, q) \le \sqrt{\tfrac{1}{2} D_{\mathrm{KL}}(p \,\|\, q)}$.
Use $p = [0.5, 0.3, 0.2]$, $q = [0.2, 0.5, 0.3]$.
(a) Implement js_divergence(p, q) and chi2_divergence(p, q) (Pearson $\chi^2$, generator $f(t) = (t-1)^2$). Verify $0 \le \mathrm{JS}(p, q) \le \ln 2$.
(b) Compute the Total Variation distance $\mathrm{TV}(p, q) = \frac{1}{2}\sum_x \lvert p(x) - q(x) \rvert$.
(c) Verify Pinsker's inequality: $\mathrm{TV}(p, q) \le \sqrt{\tfrac{1}{2} D_{\mathrm{KL}}(p \,\|\, q)}$.
(d) Compute $\mathrm{JS}(p, q)$ and verify the JS analogue of Pinsker, $\mathrm{TV}(p, q) \le \sqrt{2\,\mathrm{JS}(p, q)}$, checking that this bound is tighter than $\sqrt{\tfrac{1}{2} D_{\mathrm{KL}}(p \,\|\, q)}$ for these distributions.
Code cell 20
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 21
# Solution
# Exercise 6: Solution
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.5, 0.3])
# (a)
def js_divergence(p, q):
p, q = np.asarray(p, float), np.asarray(q, float)
m = 0.5 * (p + q)
return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)
def chi2_divergence(p, q):
p, q = np.asarray(p, float), np.asarray(q, float)
return float(np.sum((p - q)**2 / np.maximum(q, 1e-300)))
js = js_divergence(p, q)
chi2 = chi2_divergence(p, q)
# (b)
tv = 0.5 * np.sum(np.abs(p - q))
# (c) Pinsker
kl_pq = kl_divergence(p, q)
pinsker_bound = np.sqrt(0.5 * kl_pq)
# (d) JS analogue of Pinsker: TV <= sqrt(2*JS)
js_pinsker_bound = np.sqrt(2 * js)
header('Exercise 6: f-Divergences and Pinsker\'s Inequality')
print(f'D_KL(p||q) = {kl_pq:.6f} nats')
print(f'JS(p,q) = {js:.6f} nats (in [0, ln2={np.log(2):.4f}])')
print(f'chi^2(p||q) = {chi2:.6f}')
print(f'TV(p,q) = {tv:.6f}')
print(f'Pinsker bound sqrt(KL/2) = {pinsker_bound:.6f}')
print(f'JS Pinsker sqrt(2*JS) = {js_pinsker_bound:.6f}')
check_true('JS >= 0', js >= 0)
check_true('JS <= ln(2)', js <= np.log(2) + 1e-10)
check_true('Pinsker: TV <= sqrt(KL/2)', tv <= pinsker_bound + 1e-10)
check_true('JS Pinsker: TV <= sqrt(2*JS)', tv <= js_pinsker_bound + 1e-10)
check_true('JS bound is tighter here (smaller than sqrt(KL/2))', js_pinsker_bound <= pinsker_bound + 1e-10)
print('\nTakeaway: JS symmetric, bounded in [0,ln2], used in GANs; Pinsker links TV to KL')
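As a consistency check on the f-divergence framing (a small sketch, assuming p, q and kl_divergence from the solution above): plugging the generators from the table into a generic $D_f$ reproduces KL and reverse KL.
# Generic f-divergence D_f(p||q) = sum_x q(x) * f(p(x)/q(x)); p and q are strictly positive here
def f_divergence(p, q, f):
    p, q = np.asarray(p, float), np.asarray(q, float)
    return float(np.sum(q * f(p / q)))
f_kl = lambda t: t * np.log(t)      # generator for KL(p||q)
f_rkl = lambda t: -np.log(t)        # generator for KL(q||p)
print(f'D_f with t*ln t = {f_divergence(p, q, f_kl):.6f}  vs kl_divergence(p, q) = {kl_divergence(p, q):.6f}')
print(f'D_f with -ln t  = {f_divergence(p, q, f_rkl):.6f}  vs kl_divergence(q, p) = {kl_divergence(q, p):.6f}')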
Exercise 7 ★★★ — RLHF Optimal Policy and DPO
In RLHF the constrained optimisation problem is:
$$\max_\pi \;\; \mathbb{E}_{y \sim \pi}\left[r(y)\right] - \beta\, D_{\mathrm{KL}}(\pi \,\|\, \pi_{\mathrm{ref}})$$
The analytic solution is the Gibbs / softmax policy: $\pi^*(y) = \frac{1}{Z}\, \pi_{\mathrm{ref}}(y)\, e^{r(y)/\beta}$ with $Z = \sum_y \pi_{\mathrm{ref}}(y)\, e^{r(y)/\beta}$.
DPO eliminates the need to train a separate reward model: it shows that the implicit reward is $\hat{r}(y) = \beta \ln\frac{\pi^*(y)}{\pi_{\mathrm{ref}}(y)}$, which equals $r(y)$ up to the constant $\beta \ln Z$.
Use a toy example with 5 candidate responses: r = [3.0, 1.5, 0.5, -0.5, -2.0], log_pi_ref = [-1.0, -1.5, -2.0, -2.5, -3.0], $\beta = 1.0$.
(a) Compute $\pi^*$ in log-space (numerically stable). Print the optimal policy.
(b) Compute $D_{\mathrm{KL}}(\pi^* \,\|\, \pi_{\mathrm{ref}})$ and the expected reward under $\pi^*$.
(c) Verify that $\pi^*$ satisfies the first-order condition: $\ln \pi^*(y) - \ln \pi_{\mathrm{ref}}(y) = r(y)/\beta - \ln Z$.
(d) Compute the DPO implicit reward $\beta\left(\ln \pi^*(y) - \ln \pi_{\mathrm{ref}}(y)\right)$. Show it ranks responses in the same order as $r$ (up to the constant $\beta \ln Z$).
Code cell 23
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 24
# Solution
# Exercise 7: Solution
r = np.array([3.0, 1.5, 0.5, -0.5, -2.0])
log_ref = np.array([-1.0, -1.5, -2.0, -2.5, -3.0])
beta = 1.0
# (a) Optimal policy — numerically stable log-space computation
log_pi_unnorm = log_ref + r / beta
log_Z = np.log(np.sum(np.exp(log_pi_unnorm - log_pi_unnorm.max()))) + log_pi_unnorm.max()
log_pi_star = log_pi_unnorm - log_Z
pi_star = np.exp(log_pi_star)
pi_ref = np.exp(log_ref - np.log(np.exp(log_ref).sum()))
# (b)
kl_star_ref = kl_divergence(pi_star, pi_ref)
E_reward = float(np.sum(pi_star * r))
E_reward_ref= float(np.sum(pi_ref * r))
# (c) First-order condition
lhs = log_pi_star - log_ref # should equal r/beta - log_Z
rhs = r / beta - log_Z
# (d) DPO implicit reward
r_dpo = beta * (log_pi_star - log_ref) # = r - beta*log_Z (same ranking)
header('Exercise 7: RLHF Optimal Policy and DPO')
print('pi_ref (normalised):', pi_ref.round(4))
print('pi*: ', pi_star.round(4))
print(f'log Z = {log_Z:.4f}')
print(f'E[r] under pi_ref = {E_reward_ref:.4f}')
print(f'E[r] under pi* = {E_reward:.4f} (improved by {E_reward - E_reward_ref:.4f})')
print(f'KL(pi*||pi_ref) = {kl_star_ref:.4f}')
print(f'RLHF objective = {E_reward - beta*kl_star_ref:.4f}')
check_close('pi* sums to 1', pi_star.sum(), 1.0)
check_close('First-order condition: log pi* - log ref = r/beta - logZ', lhs, rhs)
check_true('DPO reward ranks same as r', np.all(np.argsort(r_dpo)[::-1] == np.argsort(r)[::-1]))
check_true('pi* concentrates on high-reward responses', pi_star[0] > pi_ref[0])
print('\nDPO implicit reward (beta * log ratio):', r_dpo.round(4))
print('True reward r: ', r.round(4))
print('\nTakeaway: RLHF optimal policy = reference * exp(r/beta); DPO uses this analytically')
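To see how the implicit reward is actually used, here is a minimal sketch of the DPO pairwise loss on this toy example, reusing log_pi_star, log_ref, beta and the expit (sigmoid) import from the setup cell. The preferred/rejected pair is a hypothetical choice for illustration, not part of the exercise.
# DPO pairwise loss sketch: -log sigmoid(beta * (log-ratio of preferred - log-ratio of rejected))
i_pref, i_rej = 0, 3                          # hypothetical preference: response 0 beats response 3
log_ratio = log_pi_star - log_ref             # beta * log_ratio is the DPO implicit reward
margin = beta * (log_ratio[i_pref] - log_ratio[i_rej])
dpo_loss = -np.log(expit(margin))
print(f'implicit-reward margin = {margin:.4f}, DPO loss = {dpo_loss:.4f}')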
Exercise 8 ★★★ — Knowledge Distillation
In knowledge distillation a small student network is trained to match a large teacher by minimising a KL-based loss.
Two common objectives:
| Objective | Formula | Behaviour |
|---|---|---|
| Forward KL | $D_{\mathrm{KL}}(p_T \,\Vert\, q_S)$ | Mean-seeking; preserves soft labels |
| Reverse KL | $D_{\mathrm{KL}}(q_S \,\Vert\, p_T)$ | Mode-seeking; may ignore dark knowledge |
A teacher produces logits over 6 classes. A student is parameterised by its own 6 logits. We study how temperature affects the distillation KL.
Teacher logits: z_T = [4.0, 2.0, 1.0, 0.5, 0.0, -1.0]
(a) Compute $p_T^{(\tau)}$ = softmax(z_T, tau) for $\tau \in \{0.5, 1, 2, 5\}$. Print the entropy $H\big(p_T^{(\tau)}\big)$ for each.
(b) For each $\tau$, find the student logits z_S that minimise $D_{\mathrm{KL}}\big(p_T^{(\tau)} \,\|\, q_S\big)$ when $q_S$ = softmax(z_S, 1). (Hint: the minimum is attained at $q_S = p_T^{(\tau)}$, i.e., $z_S = z_T/\tau$ up to an additive constant. Verify this analytically.)
(c) Compute the KL loss at the optimal $q_S$ for each $\tau$ — it should equal 0.
(d) Now fix a suboptimal student with logits z_S_bad = [3.5, 1.8, 0.8, 0.4, 0.0, -0.6] and soften it at the same temperature, $q_S^{(\tau)}$ = softmax(z_S_bad, tau). Compute $D_{\mathrm{KL}}\big(p_T^{(\tau)} \,\|\, q_S^{(\tau)}\big)$ for each $\tau$. Explain why higher $\tau$ gives a larger or smaller KL.
(e) Which direction of KL should a practitioner use for distillation and why?
Code cell 26
# Your Solution
print("Write your solution here, then compare with the reference solution below.")
Code cell 27
# Solution
# Exercise 8: Solution
z_T = np.array([4.0, 2.0, 1.0, 0.5, 0.0, -1.0])
taus = [0.5, 1.0, 2.0, 5.0]
header('Exercise 8: Knowledge Distillation')
# (a) Teacher distributions and entropy
print('(a) Teacher softmax distributions:')
print("{:>6} | {:>8} | {:>10} | {}".format("tau", "H(p_T)", "max prob", "distribution"))
print('-' * 60)
teacher_dists = {}
for tau in taus:
p_tau = softmax(z_T, tau)
teacher_dists[tau] = p_tau
H_tau = -np.sum(p_tau * np.log(p_tau + 1e-300))
print(f"{tau:>6.1f} | {H_tau:>8.4f} | {p_tau.max():>10.4f} | {p_tau.round(3)}")
# (b) Optimal student: z_S = z_T / tau satisfies q_S = softmax(z_S, 1) = softmax(z_T, tau)
print('\n(b-c) Optimal student: z_S = z_T / tau')
for tau in taus:
p_tau = teacher_dists[tau]
    z_opt = z_T / tau  # softmax(z_T / tau, 1) reproduces softmax(z_T, tau)
q_opt = softmax(z_opt, 1.0)
kl_opt = kl_divergence(p_tau, q_opt)
print(f' tau={tau}: KL(p_T||q_opt) = {kl_opt:.2e} (should be ~0)')
check_true('Optimal student achieves KL~0 for all temps',
           all(kl_divergence(softmax(z_T, t), softmax(z_T / t, 1.0)) < 1e-10 for t in taus))
# (d) Suboptimal student: soften teacher AND student at the same temperature tau
z_bad = np.array([3.5, 1.8, 0.8, 0.4, 0.0, -0.6])
print('\n(d) KL loss with suboptimal student (teacher and student both at tau):')
kl_vals = []
for tau in taus:
    p_tau = teacher_dists[tau]
    q_bad = softmax(z_bad, tau)
    kl_bad = kl_divergence(p_tau, q_bad)
    kl_vals.append(kl_bad)
    print(f'  tau={tau}: KL(p_T(tau)||q_bad(tau)) = {kl_bad:.6f}')
# Very low tau collapses both onto the same argmax (small KL); for tau >= 1, raising tau
# shrinks the scaled logit gap (z_T - z_bad)/tau, so the KL falls roughly like 1/tau^2.
check_true('tau=5 gives a smaller KL than tau=0.5', kl_vals[-1] < kl_vals[0])
check_true('tau=5 gives a smaller KL than tau=1', kl_vals[-1] < kl_vals[1])
print('\n(e) Recommendation: use FORWARD KL = D_KL(p_T || q_student).')
print(' Forward KL forces q to cover all modes where p_T > 0,')
print(' preserving dark knowledge (non-argmax probabilities).')
print(' Reverse KL is mode-seeking: student ignores secondary modes.')
print('\nTakeaway: Higher temp softens teacher and student -> smaller KL gap; forward KL = dark knowledge')
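In practice the soft KL term is usually multiplied by $\tau^2$, because softening the logits shrinks its gradients by roughly that factor (the standard recipe from Hinton et al.'s distillation paper), and it is combined with an ordinary cross-entropy on the hard label. Below is a minimal sketch of that combined loss, reusing z_T, z_bad, softmax and kl_divergence from above; the weighting alpha and the hard label are illustrative choices, not specified in the exercise.
# Combined distillation loss sketch: hard cross-entropy + tau^2-scaled soft KL
# alpha and hard_label are illustrative choices, not from the exercise.
alpha, tau_d, hard_label = 0.5, 2.0, 0
p_soft = softmax(z_T, tau_d)                  # softened teacher
q_soft = softmax(z_bad, tau_d)                # softened (suboptimal) student
q_hard = softmax(z_bad, 1.0)                  # student at temperature 1 for the hard-label loss
hard_ce = -np.log(q_hard[hard_label])
soft_kl = kl_divergence(p_soft, q_soft)
total_loss = alpha * hard_ce + (1 - alpha) * (tau_d**2) * soft_kl
print(f'hard CE = {hard_ce:.4f}, soft KL = {soft_kl:.4f}, combined loss = {total_loss:.4f}')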
Exercise 9: Temperature Distillation KL
Compare teacher and student distributions under a distillation temperature and compute the KL penalty used to train soft targets.
Code cell 29
# Your Solution
print("Compute KL between softened teacher and student probabilities.")
Code cell 30
# Solution
header("Exercise 9: Distillation KL")
teacher_logits = np.array([4.0, 1.5, -1.0])
student_logits = np.array([2.5, 1.0, 0.0])
T = 2.0
p_t = softmax(teacher_logits, tau=T)
p_s = softmax(student_logits, tau=T)
kl_ts = kl_divergence(p_t, p_s)
print("teacher:", np.round(p_t, 4))
print("student:", np.round(p_s, 4))
print("KL teacher||student:", round(kl_ts, 6))
check_true("KL is nonnegative", kl_ts >= -1e-12)
print("Takeaway: distillation transfers dark knowledge through the full softened distribution.")
Exercise 10: KL-Regularized Policy Update
Given a reference policy and rewards, compute the closed-form KL-regularized optimal policy $\pi^*(a) \propto \pi_{\mathrm{ref}}(a)\, e^{r(a)/\beta}$.
Code cell 32
# Your Solution
print("Compute the KL-regularized policy update.")
Code cell 33
# Solution
header("Exercise 10: KL-Regularized Policy")
pi_ref = np.array([0.5, 0.3, 0.2])
r = np.array([0.0, 1.0, 1.8])
beta = 0.7
unnorm = pi_ref * np.exp(r / beta)
pi_star = unnorm / unnorm.sum()
print("pi_ref:", np.round(pi_ref, 4))
print("pi_star:", np.round(pi_star, 4))
check_close("policy normalizes", pi_star.sum(), 1.0)
check_true("rewarded action gains mass", pi_star[-1] > pi_ref[-1])
print("Takeaway: RLHF-style KL control tilts a reference model toward high reward without abandoning it.")
Summary
| Exercise | Topic | Key Result |
|---|---|---|
| 1 ★ | KL definition | $H(p, q) = H(p) + D_{\mathrm{KL}}(p \,\Vert\, q)$; the model wastes $D_{\mathrm{KL}}$ extra nats per token |
| 2 ★ | Gibbs' inequality | $D_{\mathrm{KL}}(p \,\Vert\, q) \ge 0$; equality iff $p = q$ |
| 3 ★ | Forward vs reverse KL | Fwd = mass-covering; rev = mode-seeking |
| 4 ★★ | Chain rule | $D_{\mathrm{KL}}(p_{XY} \,\Vert\, q_{XY}) = D_{\mathrm{KL}}(p_X \,\Vert\, q_X) + \mathbb{E}_{p_X}\big[D_{\mathrm{KL}}(p_{Y \mid X} \,\Vert\, q_{Y \mid X})\big]$ |
| 5 ★★ | Gaussian KL + VAE | $D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\Vert\, \mathcal{N}(0, 1)\big) = \tfrac{1}{2}(\mu^2 + \sigma^2 - \ln\sigma^2 - 1)$; gradient w.r.t. $\mu$ is $\mu$ |
| 6 ★★ | f-Divergences + Pinsker | $\mathrm{TV} \le \sqrt{\tfrac{1}{2} D_{\mathrm{KL}}}$; $\sqrt{2\,\mathrm{JS}}$ is tighter here |
| 7 ★★★ | RLHF + DPO | $\pi^* \propto \pi_{\mathrm{ref}}\, e^{r/\beta}$; DPO implicit reward $\beta \ln\tfrac{\pi^*}{\pi_{\mathrm{ref}}}$ |
| 8 ★★★ | Knowledge distillation | Forward KL preserves dark knowledge; higher $\tau$ softens the teacher |
Next section: Mutual Information →