KL Divergence, Part 1: Intuition through KL for Specific Distributions

1. Intuition

1.1 What Is KL Divergence?

Suppose you are a weather forecaster. Every day, you predict a probability distribution over tomorrow's weather: 70% sunny, 20% cloudy, 10% rain. But the true distribution - what nature actually produces - is 50% sunny, 30% cloudy, 20% rain. How far off is your forecast? Shannon entropy measures the uncertainty in a single distribution. KL divergence measures how much one distribution deviates from another.

Formally, $D_{\mathrm{KL}}(p \| q)$ is the expected extra number of bits (or nats, if using natural logarithms) you waste when you encode data actually generated from $p$ using a code optimized for $q$. If you design a Huffman code for $q$, the optimal code assigns short codewords to events that $q$ considers probable. But if the true distribution is $p$, your code will be suboptimal whenever $p(x) \ne q(x)$. The KL divergence precisely quantifies the average penalty:

$$D_{\mathrm{KL}}(p \| q) = \sum_x p(x) \log \frac{p(x)}{q(x)}$$

The ratio $p(x)/q(x)$ captures the mismatch event by event: events that $p$ considers much more likely than $q$ does (large $p/q$) waste the most bits. The expectation under $p$ reflects that the penalty is averaged over what actually happens.

Three equivalent ways to read $D_{\mathrm{KL}}(p \| q)$:

  1. Coding view: Expected excess code length when coding $p$-data with a $q$-optimal code.
  2. Surprise view: How much more surprised you are, on average, when you model the data with $q$ instead of the true $p$.
  3. Testing view: The expected log-likelihood ratio $\mathbb{E}_p[\log(p(X)/q(X))]$ - the average evidence per sample that data came from $p$ rather than $q$ in a hypothesis test.

All three views are the same formula. The richness of KL divergence comes from this triple identity: it simultaneously measures compression inefficiency, distributional mismatch, and statistical distinguishability.

For AI: When a language model $p_{\boldsymbol{\theta}}$ is trained on data drawn from $p_{\mathrm{data}}$, the negative log-likelihood loss $-\mathbb{E}_{x \sim p_{\mathrm{data}}}[\log p_{\boldsymbol{\theta}}(x)]$ equals $H(p_{\mathrm{data}}) + D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{\boldsymbol{\theta}})$. Since $H(p_{\mathrm{data}})$ is a constant w.r.t. $\boldsymbol{\theta}$, minimizing cross-entropy loss is identical to minimizing $D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{\boldsymbol{\theta}})$. Every gradient descent step that improves a language model is a step toward closing the KL gap between model and data distribution.
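
The coding interpretation is easy to check numerically. Below is a minimal NumPy sketch that computes the forecaster's excess code length both as cross-entropy minus entropy and directly from the KL formula:

```python
import numpy as np

# True weather distribution and the forecast from the example above.
p = np.array([0.5, 0.3, 0.2])   # sunny, cloudy, rain (nature)
q = np.array([0.7, 0.2, 0.1])   # forecast

# Expected code length for p-data under the q-optimal code (cross-entropy)
# minus the optimal code length (entropy of p) equals the KL divergence.
cross_entropy = -(p * np.log(q)).sum()
entropy = -(p * np.log(p)).sum()
print(cross_entropy - entropy)        # ~0.0924 nats of wasted code length
print((p * np.log(p / q)).sum())      # same number, direct KL formula
```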

1.2 The Asymmetry Insight

The defining feature that most surprises newcomers is asymmetry: $D_{\mathrm{KL}}(p \| q)$ and $D_{\mathrm{KL}}(q \| p)$ are generally different, and this is not a flaw - it is a deeply meaningful distinction that determines which algorithm you should use.

To see why asymmetry arises, consider what each direction penalizes:

  • $D_{\mathrm{KL}}(p \| q)$: the expectation is under $p$. Events where $p(x) > 0$ but $q(x) \approx 0$ contribute hugely (the ratio $p/q$ blows up). This is called forward KL or inclusive KL: if $p$ puts mass somewhere, $q$ must also put mass there. $q$ must cover all of $p$'s support. Minimizing this forces $q$ to be "mean-seeking" - to match the whole spread of $p$, even when $p$ is multimodal.

  • $D_{\mathrm{KL}}(q \| p)$: the expectation is now under $q$. Events where $q(x) > 0$ but $p(x) \approx 0$ contribute hugely. This is called reverse KL or exclusive KL: if $q$ puts mass somewhere, $p$ must also put mass there - so $q$ avoids regions where $p$ is small. Minimizing this forces $q$ to be "mode-seeking" - to concentrate on one mode of $p$ and ignore the others.

Concrete example: Let $p$ be a bimodal mixture $0.5\,\mathcal{N}(-3,1) + 0.5\,\mathcal{N}(3,1)$ and let $q = \mathcal{N}(\mu, \sigma^2)$ be a unimodal Gaussian we fit to $p$:

  • Minimizing $D_{\mathrm{KL}}(p \| q)$: $q$ must cover both modes, so $q^* \approx \mathcal{N}(0, 10)$ - wide, centered between the modes, placing mass in low-probability regions.
  • Minimizing $D_{\mathrm{KL}}(q \| p)$: $q$ collapses onto one mode, e.g., $q^* \approx \mathcal{N}(3, 1)$, ignoring the other mode entirely.

Neither approximation is "wrong" - they answer different questions. Which direction to use is one of the most important design decisions in probabilistic ML.
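
Both behaviors are easy to see numerically. The sketch below (self-contained NumPy; the grid bounds and resolution are arbitrary choices) evaluates both KL directions for the wide fit and the single-mode fit:

```python
import numpy as np

x = np.linspace(-15, 15, 20001)
dx = x[1] - x[0]

def normal_pdf(x, mu, var):
    return np.exp(-(x - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

p = 0.5 * normal_pdf(x, -3, 1) + 0.5 * normal_pdf(x, 3, 1)  # bimodal target
q_wide = normal_pdf(x, 0, 10)    # forward-KL minimizer (moment-matched)
q_mode = normal_pdf(x, 3, 1)     # reverse-KL minimizer (one mode)

def kl(a, b):
    # Grid approximation of the integral of a*log(a/b); skip where a ~ 0.
    mask = a > 1e-300
    return np.sum(a[mask] * np.log(a[mask] / b[mask])) * dx

print("forward KL  wide:", kl(p, q_wide), " mode:", kl(p, q_mode))
print("reverse KL  wide:", kl(q_wide, p), " mode:", kl(q_mode, p))
# Forward KL strongly prefers the wide fit (the mode fit misses half of p's
# mass); reverse KL prefers the single-mode fit (the wide fit sits in the valley).
```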

For AI: Variational inference for Bayesian neural networks minimizes reverse KL $D_{\mathrm{KL}}(q\|p)$ - the variational posterior $q$ collapses to one mode of the intractable true posterior $p$. This is why Bayesian deep learning with mean-field VI tends to underestimate uncertainty (it ignores alternative modes). Maximum likelihood training of generative models minimizes forward KL $D_{\mathrm{KL}}(p_{\mathrm{data}}\|p_{\boldsymbol{\theta}})$ - forcing the model to cover the entire data distribution, at the cost of sometimes generating blurry or average-looking samples.

1.3 Why KL Divergence Matters for AI

KL divergence is not one tool among many - it is the unifying framework behind most of the training objectives and alignment techniques in modern AI:

| Application | Which KL | How |
| --- | --- | --- |
| Language model training (GPT, LLaMA, Claude) | $D_{\mathrm{KL}}(p_{\mathrm{data}} \Vert p_{\boldsymbol{\theta}})$ | Minimizing cross-entropy = minimizing forward KL |
| VAE regularizer (Kingma & Welling, 2014) | $D_{\mathrm{KL}}(q_\phi(\mathbf{z}\mid\mathbf{x}) \Vert p(\mathbf{z}))$ | Reverse KL keeps encoder close to Gaussian prior |
| RLHF KL penalty (Christiano et al., 2017) | $D_{\mathrm{KL}}(\pi_{\boldsymbol{\theta}} \Vert \pi_{\mathrm{ref}})$ | Prevents reward hacking; preserves language coherence |
| PPO trust region (Schulman et al., 2017) | Approx. $D_{\mathrm{KL}}(\pi_{\mathrm{old}} \Vert \pi_{\mathrm{new}})$ | Clip objective approximates KL constraint |
| DPO (Rafailov et al., 2023) | Implicit KL | Closed-form solution has same fixed point as RLHF |
| Knowledge distillation (Hinton et al., 2015) | $D_{\mathrm{KL}}(p_{\mathrm{teacher}} \Vert p_{\mathrm{student}})$ | Forward KL to match teacher's full soft distribution |
| Variational inference (VI) | $D_{\mathrm{KL}}(q \Vert p)$ | Reverse KL for tractable posterior approximation |
| Normalizing flows (Rezende & Mohamed, 2015) | $D_{\mathrm{KL}}(p_{\mathrm{data}} \Vert p_{\boldsymbol{\theta}})$ | Forward KL via exact log-likelihood |
| Diffusion models (Ho et al., 2020) | KL at each reverse step | ELBO decomposes as sum of per-step KL terms |
| Differential privacy (Rényi DP) | Rényi divergence | Rényi divergence bounds privacy loss under composition |

The breadth of this table reflects a deep fact: KL divergence is the natural measure of distributional distance when you care about expected log-probability ratios, which is exactly what gradient-based optimization of probabilistic models does.

1.4 Historical Timeline

KL DIVERGENCE - KEY MILESTONES


  1951  Kullback & Leibler, "On Information and Sufficiency"
        Introduce D_KL as "information for discrimination" I(1:2)
        Prove non-negativity; connect to sufficiency of statistics

  1959  Kullback, "Information Theory and Statistics"
        Systematic treatment; connection to likelihood ratio tests

  1967  Csiszár, "Information-Type Measures of Difference"
        Generalizes to f-divergences; unifies KL, Hellinger, chi-squared

  1972  Chernoff; Blahut, Arimoto
        Exponential decay of error rates in hypothesis testing
        Connection to Rényi divergence

  1985  Amari, "Differential-Geometric Methods in Statistics"
        Information geometry: KL as non-symmetric Bregman divergence
        Exponential and mixture geodesics; natural gradient

  1993  Hinton & van Camp, "Keeping the Neural Networks Simple"
        Minimum Description Length -> KL regularization

  2013  Kingma & Welling, "Auto-Encoding Variational Bayes"
        VAE: ELBO = reconstruction - D_KL(q_phi || p)
        Reparameterization trick for gradient through KL

  2015  Hinton, Vinyals & Dean, "Distilling the Knowledge"
        Forward KL with temperature-softened teacher distribution

  2017  Christiano et al., "Deep Reinforcement Learning from Human Preferences"
        KL penalty in RLHF; controls deviation from reference policy

  2017  Schulman et al., "Proximal Policy Optimization"
        PPO-clip approximates KL trust region constraint

  2023  Rafailov et al., "Direct Preference Optimization"
        DPO: closed-form KL-constrained policy optimization

  2024+  All frontier LLMs (GPT-4, Claude, Gemini, LLaMA-3)
         Trained with cross-entropy (= forward KL) + RLHF KL penalty


2. Formal Definitions

2.1 Discrete KL Divergence

Definition (Kullback-Leibler Divergence - discrete). Let $p$ and $q$ be probability mass functions on a countable alphabet $\mathcal{X}$. The KL divergence from $q$ to $p$ (also called the relative entropy of $p$ with respect to $q$) is:

$$D_{\mathrm{KL}}(p \| q) = \sum_{x \in \mathcal{X}} p(x) \log \frac{p(x)}{q(x)}$$

where:

  • The logarithm is natural (nats) by convention in this curriculum, matching the NOTATION_GUIDE. To convert to bits, divide by $\ln 2$.
  • Boundary conventions: $0 \log(0/q) = 0$ for any $q \ge 0$ (by continuity: $\lim_{p \to 0^+} p \log p = 0$); $p \log(p/0) = +\infty$ for $p > 0$ (code length is infinite if you cannot encode a symbol at all).
  • $D_{\mathrm{KL}}(p \| q) = +\infty$ whenever there exists $x$ with $p(x) > 0$ and $q(x) = 0$.

The notation convention in this repository (following Cover & Thomas 2006 and NOTATION_GUIDE Section 6): $D_{\mathrm{KL}}(p \| q)$ reads as "KL divergence from $q$ to $p$." That is, $p$ is the reference (true) distribution and $q$ is the approximation. Some sources use the opposite convention; when reading papers, always check which argument is which.
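
A minimal implementation of the discrete definition with the boundary conventions above (a sketch, not a library function):

```python
import numpy as np

def kl_discrete(p, q):
    """D_KL(p || q) in nats, using the conventions above:
    0*log(0/q) = 0, and p*log(p/0) = +inf for p > 0."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    if np.any((p > 0) & (q == 0)):
        return np.inf                  # support of p not covered by q
    mask = p > 0                       # 0 log 0 terms contribute nothing
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

print(kl_discrete([0.5, 0.3, 0.2], [0.7, 0.2, 0.1]))  # ~0.0924 nats
print(kl_discrete([1.0, 0.0], [0.0, 1.0]))            # inf (mismatched support)
```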

Standard examples:

Example 2.1 (Binary distributions): Let $p = (\theta, 1-\theta)$ and $q = (q_0, 1-q_0)$ on $\{0,1\}$. Then:

$$D_{\mathrm{KL}}(p \| q) = \theta \ln\frac{\theta}{q_0} + (1-\theta)\ln\frac{1-\theta}{1-q_0}$$

This is the log-likelihood ratio test statistic for a Bernoulli parameter hypothesis test.

Example 2.2 (Weather forecast): True distribution $p = (0.5, 0.3, 0.2)$ over {sunny, cloudy, rain}; forecast $q = (0.7, 0.2, 0.1)$:

$$D_{\mathrm{KL}}(p \| q) = 0.5\ln\frac{0.5}{0.7} + 0.3\ln\frac{0.3}{0.2} + 0.2\ln\frac{0.2}{0.1} \approx -0.168 + 0.122 + 0.139 \approx 0.093 \text{ nats}$$

The model wastes approximately 0.093 nats per observation. In the reverse direction: $D_{\mathrm{KL}}(q \| p) \approx 0.085$ nats - different but close, because the two distributions are not very far apart.

Example 2.3 (Categorical logits): In a classification model with softmax output $q(k) = \mathrm{softmax}(\mathbf{z})_k$ and one-hot label $p(k) = \mathbb{1}[k = y]$:

$$D_{\mathrm{KL}}(p \| q) = -\log q(y) = -z_y + \log\sum_k e^{z_k}$$

This is exactly the cross-entropy loss. (The entropy $H(p) = 0$ for one-hot $p$, so $H(p,q) = H(p) + D_{\mathrm{KL}}(p\|q) = D_{\mathrm{KL}}(p\|q)$.)
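
The identity $-\log q(y) = -z_y + \log\sum_k e^{z_k}$ is how the loss is computed stably from logits; a small sketch with hypothetical logits:

```python
import numpy as np

def cross_entropy_from_logits(z, y):
    # -log softmax(z)[y] = -z_y + logsumexp(z), computed stably.
    z = z - z.max()                       # shift for numerical stability
    return -z[y] + np.log(np.exp(z).sum())

z = np.array([2.0, -1.0, 0.5])            # hypothetical logits
print(cross_entropy_from_logits(z, y=0))  # = D_KL(one_hot(0) || softmax(z))
```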

2.2 Continuous KL Divergence

Definition (KL divergence - continuous). Let $p$ and $q$ be probability density functions on $\mathbb{R}^d$ with $p$ absolutely continuous with respect to $q$ (written $p \ll q$, meaning $q(x) = 0 \Rightarrow p(x) = 0$ a.e.). Then:

$$D_{\mathrm{KL}}(p \| q) = \int_{\mathbb{R}^d} p(\mathbf{x}) \log \frac{p(\mathbf{x})}{q(\mathbf{x})} \, d\mathbf{x} = \mathbb{E}_{\mathbf{x} \sim p}\!\left[\log \frac{p(\mathbf{x})}{q(\mathbf{x})}\right]$$

The absolute continuity condition $p \ll q$ is essential: it ensures the Radon-Nikodym derivative $dp/dq$ exists, making the ratio $p(\mathbf{x})/q(\mathbf{x})$ well-defined a.e. under $p$. When $p \not\ll q$, $D_{\mathrm{KL}}(p\|q) = +\infty$ by convention.

Measure-theoretic form. Using the Radon-Nikodym derivative $\frac{dp}{dq}$, the general definition that works for both discrete and continuous cases is:

$$D_{\mathrm{KL}}(p \| q) = \int \log \frac{dp}{dq} \, dp = \mathbb{E}_p\!\left[\log \frac{dp}{dq}\right]$$

This unified form shows that the discrete and continuous cases differ only in the choice of dominating measure (counting measure vs Lebesgue measure).
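
In the continuous case the integral is rarely available in closed form, but the expectation form suggests a Monte Carlo estimator: sample from $p$ and average the log-ratio. A sketch comparing the estimate against the closed-form Gaussian KL derived in Section 5.1:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal_pdf(x, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

# Monte Carlo estimate of D_KL(p || q) = E_{x~p}[log p(x) - log q(x)]
x = rng.normal(0.0, 1.0, size=1_000_000)           # samples from p = N(0, 1)
est = np.mean(log_normal_pdf(x, 0, 1) - log_normal_pdf(x, 1, 2))  # q = N(1, 2)

# Closed form (Section 5.1): log(s2/s1) + (s1^2 + (m1-m2)^2)/(2 s2^2) - 1/2
exact = 0.5 * np.log(2) + (1 + 1) / (2 * 2) - 0.5
print(est, exact)                                   # both ~0.3466 nats
```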

2.3 Relative Entropy Interpretation

The phrase "relative entropy" is justified by a beautiful coding theorem. Suppose $X_1, X_2, \ldots$ are i.i.d. from $p$. You observe $n$ samples and design a code using distribution $q$. The optimal code assigns length $-\log q(x)$ to symbol $x$ (by Shannon's source coding theorem). The expected code length for one symbol is $\mathbb{E}_p[-\log q(X)] = H(p, q)$ - the cross-entropy. But the optimal code for $p$ has expected length $H(p)$. The excess code length per symbol is:

$$H(p, q) - H(p) = -\sum_x p(x)\log q(x) + \sum_x p(x)\log p(x) = \sum_x p(x)\log\frac{p(x)}{q(x)} = D_{\mathrm{KL}}(p \| q)$$

This is exactly KL divergence. The relative entropy $D_{\mathrm{KL}}(p\|q)$ measures how many extra nats/bits per symbol you pay for using the wrong code $q$ when the truth is $p$. It is "relative" because it measures the entropy (code length) of $p$ relative to the reference distribution $q$.

Information gain interpretation. In Bayesian inference, if $p$ is the posterior and $q$ is the prior, then $D_{\mathrm{KL}}(p\|q)$ is the information gain from the prior to the posterior - the reduction in uncertainty after observing data. The Kullback-Leibler 1951 paper used the term "information for discrimination": how much information the data provides for discriminating between $p$ and $q$.

2.4 Non-Examples and Edge Cases

Understanding when KL divergence is not well-behaved clarifies the definition.

Non-example 2.1 (Mismatched support, $D_{\mathrm{KL}} = +\infty$): Let $p = \mathcal{N}(0, 1)$ and $q = \mathcal{U}(-1, 1)$ (uniform on $[-1,1]$). Since $q(x) = 0$ for $|x| > 1$ but $p(x) > 0$ for all $x$, we have $p \not\ll q$ and $D_{\mathrm{KL}}(p\|q) = +\infty$. However $D_{\mathrm{KL}}(q\|p) < \infty$ since $p(x) > 0$ everywhere, so $q \ll p$. This illustrates the extreme asymmetry that can occur when supports differ.

Non-example 2.2 (Zero KL does not mean identical densities pointwise): $D_{\mathrm{KL}}(p\|q) = 0$ implies $p = q$ $p$-almost everywhere. In the continuous case the densities can still differ on a set of measure zero, but such differences do not change the distribution. In practice, if $D_{\mathrm{KL}} = 0$ numerically, $p$ and $q$ are identical where it matters.

Non-example 2.3 (Equal KL values do not imply related distributions): $D_{\mathrm{KL}}(p\|q) = D_{\mathrm{KL}}(r\|s) = 0.5$ nats gives no useful information about the relationship between $(p,q)$ and $(r,s)$. KL is an absolute quantity for a specific pair, not a geometry with reference points.

Non-example 2.4 (KL is not symmetric by default): For $p = (0.9, 0.1)$ and $q = (0.1, 0.9)$:

$$D_{\mathrm{KL}}(p\|q) = 0.9\ln\frac{0.9}{0.1} + 0.1\ln\frac{0.1}{0.9} = 0.9\ln 9 + 0.1\ln(1/9) \approx 1.758 \text{ nats}$$

By symmetry of this example, $D_{\mathrm{KL}}(q\|p) = D_{\mathrm{KL}}(p\|q) \approx 1.758$ nats - equal only because $p$ and $q$ are exact reverses of each other. For generic asymmetric distributions, the two values differ.

3. Properties of KL Divergence

3.1 Non-Negativity: Gibbs' Inequality

Theorem (Gibbs' Inequality). For any two probability distributions $p$ and $q$ on the same alphabet:

$$D_{\mathrm{KL}}(p \| q) \ge 0$$

with equality if and only if $p(x) = q(x)$ for all $x$ (or $p$-almost all $x$ in the continuous case).

Proof via Jensen's inequality. The function $f(t) = -\ln t$ is strictly convex on $(0, \infty)$ (its second derivative $1/t^2 > 0$). Equivalently, $\ln t$ is strictly concave. By Jensen's inequality applied to the concave function $\ln$:

$$-D_{\mathrm{KL}}(p \| q) = \sum_x p(x)\ln\frac{q(x)}{p(x)} \le \ln\!\left(\sum_x p(x) \cdot \frac{q(x)}{p(x)}\right) = \ln\!\left(\sum_x q(x)\right) = \ln 1 = 0$$

Therefore $D_{\mathrm{KL}}(p\|q) \ge 0$. Equality holds in Jensen's inequality for a strictly concave function iff the argument is constant, i.e., $q(x)/p(x) = c$ for all $x$ with $p(x) > 0$. Since both $p$ and $q$ sum to 1, this forces $c = 1$ and hence $p = q$ everywhere. $\square$

Alternative proof via $\ln t \le t - 1$. The inequality $\ln t \le t - 1$ (with equality iff $t = 1$) implies:

$$-D_{\mathrm{KL}}(p \| q) = \sum_x p(x)\ln\frac{q(x)}{p(x)} \le \sum_x p(x)\left(\frac{q(x)}{p(x)} - 1\right) = \sum_x q(x) - \sum_x p(x) = 1 - 1 = 0$$

This proof is elementary (it needs no Jensen's inequality) and directly shows the equality condition.

Corollary (Entropy upper bound). Taking $q = u$ (uniform on $\mathcal{X}$ with $|\mathcal{X}| = n$):

$$0 \le D_{\mathrm{KL}}(p \| u) = \sum_x p(x)\ln\frac{p(x)}{1/n} = \ln n - H(p)$$

Therefore $H(p) \le \ln n = H(u)$. Gibbs' inequality is the reason entropy is maximized at the uniform distribution (proven rigorously in Section 09-01).

For AI: Non-negativity of KL underpins the ELBO. The VAE lower bound comes from:

$$\log p(\mathbf{x}) = \mathcal{L}(\boldsymbol{\phi}, \boldsymbol{\theta}; \mathbf{x}) + D_{\mathrm{KL}}(q_\phi(\mathbf{z}\mid\mathbf{x}) \| p_\theta(\mathbf{z}\mid\mathbf{x}))$$

Since $D_{\mathrm{KL}} \ge 0$, we have $\log p(\mathbf{x}) \ge \mathcal{L}$ - the ELBO is indeed a lower bound on log-evidence. This is the fundamental inequality that makes variational inference tractable.

3.2 Asymmetry

$D_{\mathrm{KL}}(p \| q) \ne D_{\mathrm{KL}}(q \| p)$ in general. KL divergence is not a distance in the metric sense. We have already seen this intuitively; here we quantify it.

Example 3.2: $p = (0.8, 0.1, 0.1)$, $q = (0.1, 0.8, 0.1)$:

$$D_{\mathrm{KL}}(p \| q) = 0.8\ln 8 + 0.1\ln(1/8) + 0.1\ln 1 = 0.8(2.079) + 0.1(-2.079) + 0 \approx 1.455 \text{ nats}$$
$$D_{\mathrm{KL}}(q \| p) = 0.1\ln(0.1/0.8) + 0.8\ln 8 + 0.1\ln 1 = 0.1(-2.079) + 0.8(2.079) + 0 \approx 1.455 \text{ nats}$$

Here the two directions are equal - but only because $q$ is a permutation of $p$, so both directions sum exactly the same terms. The same happens for $p = (0.9, 0.05, 0.05)$, $q = (0.05, 0.9, 0.05)$:

$$D_{\mathrm{KL}}(p \| q) = 0.9\ln 18 + 0.05\ln(0.05/0.9) + 0.05\ln 1 \approx 2.601 - 0.145 + 0 \approx 2.457 \text{ nats}$$

with $D_{\mathrm{KL}}(q \| p) \approx 2.457$ nats as well, and even for the extreme pair $p = (0.99, 0.01)$, $q = (0.01, 0.99)$:

$$D_{\mathrm{KL}}(p\|q) = 0.99\ln 99 + 0.01\ln(0.01/0.99) \approx 0.99(4.595) - 0.01(4.595) \approx 4.503 \text{ nats}$$

with $D_{\mathrm{KL}}(q\|p) \approx 4.503$ nats by the same permutation symmetry. A genuinely asymmetric example requires a $q$ that is not a permutation of $p$: take $p = (0.9, 0.1)$, $q = (0.5, 0.5)$:

$$D_{\mathrm{KL}}(p\|q) = 0.9\ln(1.8) + 0.1\ln(0.2) \approx 0.9(0.588) - 0.1(1.609) \approx 0.368 \text{ nats}$$
$$D_{\mathrm{KL}}(q\|p) = 0.5\ln(5/9) + 0.5\ln 5 \approx -0.5(0.588) + 0.5(1.609) \approx 0.511 \text{ nats}$$

So $D_{\mathrm{KL}}(p\|q) \approx 0.368 \ne 0.511 \approx D_{\mathrm{KL}}(q\|p)$.

Jeffreys' symmetrized KL. Harold Jeffreys (1946) proposed the symmetric version:

$$J(p, q) = \frac{1}{2}\left[D_{\mathrm{KL}}(p \| q) + D_{\mathrm{KL}}(q \| p)\right]$$

This satisfies $J(p,q) = J(q,p) \ge 0$ and $J(p,q) = 0 \Leftrightarrow p = q$, but still fails the triangle inequality. Jeffreys divergence is used in some applications where symmetry is important. The Jensen-Shannon divergence (Section 7.4) is a better-behaved symmetric variant that is also bounded.

3.3 Failure of Triangle Inequality

KL divergence is not a metric. Specifically, the triangle inequality $D_{\mathrm{KL}}(p\|r) \le D_{\mathrm{KL}}(p\|q) + D_{\mathrm{KL}}(q\|r)$ does not hold in general.

Counterexample. Let $p = (1, 0)$, $q = (0.5, 0.5)$, $r = (0, 1)$ on $\{0, 1\}$:

$$D_{\mathrm{KL}}(p\|r) = 1 \cdot \ln(1/0) = +\infty$$
$$D_{\mathrm{KL}}(p\|q) = 1 \cdot \ln(1/0.5) = \ln 2 \approx 0.693 \text{ nats}$$
$$D_{\mathrm{KL}}(q\|r) = 0.5\ln(0.5/0) + 0.5\ln(0.5/1) = +\infty$$

In this degenerate case $D_{\mathrm{KL}}(p\|r) = +\infty$ and $D_{\mathrm{KL}}(p\|q) + D_{\mathrm{KL}}(q\|r) = +\infty$, so the triangle inequality "holds" trivially. For a cleaner counterexample with all finite values: $p = (0.9, 0.1)$, $q = (0.5, 0.5)$, $r = (0.1, 0.9)$:

$$D_{\mathrm{KL}}(p\|r) = 0.9\ln 9 + 0.1\ln(1/9) \approx 1.758 \text{ nats}$$
$$D_{\mathrm{KL}}(p\|q) + D_{\mathrm{KL}}(q\|r) \approx 0.368 + 0.511 = 0.879 \text{ nats}$$

Since $1.758 > 0.879$, the triangle inequality is violated. The reason: KL "skips" over intermediate distributions in a fundamentally non-Euclidean way.
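
The violation takes two lines of NumPy to confirm:

```python
import numpy as np

def kl(p, q):
    p, q = np.asarray(p, float), np.asarray(q, float)
    return np.sum(p * np.log(p / q))

p, q, r = [0.9, 0.1], [0.5, 0.5], [0.1, 0.9]
print(kl(p, r), "vs", kl(p, q) + kl(q, r))  # ~1.758 vs ~0.879: violated
```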

3.4 Joint Convexity

Theorem. $D_{\mathrm{KL}}(p \| q)$ is jointly convex in the pair $(p, q)$: for any $\lambda \in [0,1]$,

$$D_{\mathrm{KL}}(\lambda p_1 + (1-\lambda)p_2 \,\|\, \lambda q_1 + (1-\lambda)q_2) \le \lambda D_{\mathrm{KL}}(p_1 \| q_1) + (1-\lambda) D_{\mathrm{KL}}(p_2 \| q_2)$$

Proof sketch. The perspective of a convex function is jointly convex: if $f$ is convex, then $g(x, t) = tf(x/t)$ is jointly convex in $(x, t)$ for $t > 0$. Taking $f(u) = u\log u$ (which is convex since $(u\log u)'' = 1/u > 0$), the perspective evaluated at $(p(x), q(x))$ is $q(x) \cdot \frac{p(x)}{q(x)}\log\frac{p(x)}{q(x)} = p(x)\log\frac{p(x)}{q(x)}$. Summing over $x$ gives $D_{\mathrm{KL}}(p\|q) = \sum_x p(x)\log(p(x)/q(x))$, which is jointly convex in $(p, q)$ as a sum of perspective functions.

Consequence for the EM algorithm. The EM algorithm alternates between computing the expected complete-data log-likelihood (E-step) and maximizing it (M-step). Each step can be viewed as minimizing a KL divergence in one argument of a shared objective. Convexity in each argument makes both subproblems well-posed, and the alternation monotonically improves the likelihood bound toward a stationary point.

Convexity in $p$ alone. $D_{\mathrm{KL}}(p\|q)$ is also convex in $p$ for fixed $q$ (since it equals a sum of convex functions $p(x)\log p(x) - p(x)\log q(x)$).

Convexity in $q$ alone. $D_{\mathrm{KL}}(p\|q) = -\sum_x p(x)\log q(x) + \text{const}$ for fixed $p$. This is convex in $q$ because $-\log q(x)$ is convex and $p(x) \ge 0$. This means: for fixed $p$, minimizing over $q$ is a convex optimization problem with a unique minimum at $q = p$.
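
Joint convexity can be spot-checked on random distributions (a sketch; any random draw satisfies the inequality):

```python
import numpy as np

rng = np.random.default_rng(1)

def kl(p, q):
    return np.sum(p * np.log(p / q))

def rand_dist(k):
    w = rng.random(k) + 1e-3
    return w / w.sum()

p1, p2, q1, q2 = (rand_dist(5) for _ in range(4))
lam = 0.3
lhs = kl(lam * p1 + (1 - lam) * p2, lam * q1 + (1 - lam) * q2)
rhs = lam * kl(p1, q1) + (1 - lam) * kl(p2, q2)
print(lhs <= rhs + 1e-12)   # True: mixing distributions cannot increase KL
```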

3.5 Chain Rule for KL Divergence

Theorem (Chain Rule). For joint distributions $P(X,Y)$ and $Q(X,Y)$:

$$D_{\mathrm{KL}}(P(X,Y) \| Q(X,Y)) = D_{\mathrm{KL}}(P_X \| Q_X) + \mathbb{E}_{x \sim P_X}\!\left[D_{\mathrm{KL}}(P_{Y|X=x} \| Q_{Y|X=x})\right]$$

where $P_X$ is the marginal of $P$ over $X$, and $P_{Y|X}$, $Q_{Y|X}$ are the conditional distributions.

Proof.

$$D_{\mathrm{KL}}(P \| Q) = \sum_{x,y} P(x,y)\log\frac{P(x,y)}{Q(x,y)}$$

Using $P(x,y) = P(x)P(y|x)$ and $Q(x,y) = Q(x)Q(y|x)$:

$$= \sum_{x,y} P(x,y)\log\frac{P(x)P(y|x)}{Q(x)Q(y|x)} = \sum_{x,y} P(x,y)\log\frac{P(x)}{Q(x)} + \sum_{x,y} P(x,y)\log\frac{P(y|x)}{Q(y|x)}$$
$$= \sum_x P(x)\log\frac{P(x)}{Q(x)} + \sum_x P(x)\sum_y P(y|x)\log\frac{P(y|x)}{Q(y|x)}$$
$$= D_{\mathrm{KL}}(P_X \| Q_X) + \mathbb{E}_{x \sim P_X}\!\left[D_{\mathrm{KL}}(P_{Y|X=x} \| Q_{Y|X=x})\right] \quad \square$$
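
A numerical sanity check of the chain rule on a random joint distribution:

```python
import numpy as np

rng = np.random.default_rng(2)

def kl(p, q):
    return np.sum(p * np.log(p / q))

def rand_joint(nx, ny):
    w = rng.random((nx, ny)) + 1e-3
    return w / w.sum()

P, Q = rand_joint(3, 4), rand_joint(3, 4)
Px, Qx = P.sum(axis=1), Q.sum(axis=1)        # marginals over X
P_y_x = P / Px[:, None]                      # conditionals P(y|x)
Q_y_x = Q / Qx[:, None]

lhs = kl(P, Q)
rhs = kl(Px, Qx) + sum(Px[i] * kl(P_y_x[i], Q_y_x[i]) for i in range(3))
print(np.isclose(lhs, rhs))   # True: joint KL = marginal KL + expected conditional KL
```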

For AI: In a hierarchical generative model $p(\mathbf{x}, \mathbf{z}) = p(\mathbf{z}) p(\mathbf{x}|\mathbf{z})$, the chain rule decomposes the KL between the variational distribution and true posterior into a prior KL plus an expected conditional KL. This is exactly the ELBO decomposition:

$$D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}|\mathbf{x})) = D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z})) - \mathbb{E}_{q}\left[\log p(\mathbf{x}|\mathbf{z})\right] + \log p(\mathbf{x})$$

Rearranging: $\log p(\mathbf{x}) = \underbrace{\mathbb{E}_q[\log p(\mathbf{x}|\mathbf{z})] - D_{\mathrm{KL}}(q\|p)}_{\text{ELBO}} + D_{\mathrm{KL}}(q(\mathbf{z}|\mathbf{x})\|p(\mathbf{z}|\mathbf{x})) \ge \text{ELBO}$.

3.6 Data Processing Inequality

Theorem (Data Processing Inequality for KL). Let $T: \mathcal{X} \to \mathcal{Y}$ be any (possibly randomized) function. Let $p_T$ and $q_T$ be the distributions of $T(X)$ under $p$ and $q$ respectively. Then:

$$D_{\mathrm{KL}}(p_T \| q_T) \le D_{\mathrm{KL}}(p \| q)$$

Proof sketch. By the chain rule applied to the joint $(X, T(X))$:

$$D_{\mathrm{KL}}(p(X, T(X)) \| q(X, T(X))) = D_{\mathrm{KL}}(p_X \| q_X) + \mathbb{E}\left[D_{\mathrm{KL}}(p_{T|X}\|q_{T|X})\right]$$

For a deterministic function $T$, the conditional $D_{\mathrm{KL}}(p_{T|X}\|q_{T|X}) = 0$. Decomposing the same joint via $T$ first, $D_{\mathrm{KL}}(p_T\|q_T) + \mathbb{E}[D_{\mathrm{KL}}(p_{X|T}\|q_{X|T})] = D_{\mathrm{KL}}(p_X\|q_X)$. Since the second term is $\ge 0$, $D_{\mathrm{KL}}(p_T\|q_T) \le D_{\mathrm{KL}}(p_X\|q_X)$. $\square$

Intuition: Processing data can only destroy information, not create it. Two distributions that are $D$ nats apart look at most $D$ nats apart after transformation. If you run a neural network encoder on samples from two distributions, the KL in embedding space is at most the KL in input space.

For AI: This is why representation learning cannot increase discriminability between classes beyond what is in the raw data. It also explains why the KL penalty in RLHF acts on the token distribution: any post-processing of the output tokens (sampling, filtering) can only reduce the KL from the reference policy, not increase it.
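
A small demonstration with a deterministic coarse-graining $T$ that merges symbols (hypothetical numbers):

```python
import numpy as np

def kl(p, q):
    return np.sum(p * np.log(p / q))

p = np.array([0.5, 0.3, 0.15, 0.05])
q = np.array([0.25, 0.25, 0.25, 0.25])

# Deterministic processing T: merge symbols {0,1} -> A and {2,3} -> B.
pT = np.array([p[0] + p[1], p[2] + p[3]])
qT = np.array([q[0] + q[1], q[2] + q[3]])

print(kl(pT, qT), "<=", kl(p, q))   # ~0.193 <= ~0.244: coarse-graining shrinks KL
```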

4. Forward KL vs Reverse KL

This section develops the most practically important distinction in applied KL theory: which direction to minimize, and what kind of approximation each direction produces.

4.1 Forward KL: Mean-Seeking Behavior

Forward KL is $D_{\mathrm{KL}}(p \| q)$ where $p$ is the target distribution (the truth) and $q$ is the approximating distribution we optimize. We minimize over $q$ given fixed $p$.

The zero-avoiding (mass-covering) property: Because the expectation is under $p$, regions where $p(x) > 0$ but $q(x) = 0$ contribute $+\infty$ to $D_{\mathrm{KL}}(p\|q)$. Therefore, any minimizer $q^*$ must satisfy $q^*(x) > 0$ wherever $p(x) > 0$ - the approximation must cover all the mass of the target. For this reason, forward KL is also called inclusive KL or zero-avoiding KL.

Fitting a Gaussian to a bimodal distribution. Let $p(x) = 0.5\,\mathcal{N}(x;-3,1) + 0.5\,\mathcal{N}(x;3,1)$ and $q_\mu(x) = \mathcal{N}(x;\mu,\sigma^2)$. Minimizing $D_{\mathrm{KL}}(p\|q_\mu)$ over $\mu$ and $\sigma^2$:

$$\frac{\partial}{\partial \mu} D_{\mathrm{KL}}(p \| q_\mu) = 0 \implies \mu^* = \mathbb{E}_p[X] = 0.5(-3) + 0.5(3) = 0$$
$$\sigma^{*2} = \mathbb{E}_p[(X - \mu^*)^2] = 0.5(9 + 1) + 0.5(9 + 1) = 10$$

So $q^* = \mathcal{N}(0, 10)$ - a wide Gaussian centered between the modes, placing substantial mass in the valley between them where $p$ is near zero. This is the mean-seeking behavior: $q$ matches the mean and variance of $p$ (moment matching).

General result. For any exponential family $q$ with sufficient statistics $\mathbf{t}(x)$, minimizing forward KL performs moment matching: $\mathbb{E}_q[\mathbf{t}(X)] = \mathbb{E}_p[\mathbf{t}(X)]$. The approximation matches the moments of $p$, which averages across modes.

For AI: Maximum likelihood estimation minimizes forward KL. This is why MLE on multimodal data can produce "blurry" generative models that average over modes - GANs were introduced partly to address this by using a minimax objective related to Jensen-Shannon divergence (Section 7.4) instead of direct MLE.

4.2 Reverse KL: Mode-Seeking Behavior

Reverse KL is $D_{\mathrm{KL}}(q \| p)$ where $q$ is the approximation we optimize and $p$ is fixed. Now the expectation is under $q$ - we penalize regions where $q(x) > 0$ but $p(x) \approx 0$.

The zero-forcing (mass-concentrating) property: Regions where $p(x) = 0$ but $q(x) > 0$ contribute $+\infty$ to $D_{\mathrm{KL}}(q\|p)$. Therefore $q^*$ must avoid regions where $p$ is zero - it must concentrate its mass on high-probability regions of $p$. This is called exclusive KL or zero-forcing KL.

Fitting a Gaussian to a bimodal distribution (reverse direction). Now minimizing $D_{\mathrm{KL}}(q_\mu \| p)$:

$$D_{\mathrm{KL}}(q_\mu \| p) = \mathbb{E}_{q_\mu}\!\left[\log\frac{q_\mu(X)}{p(X)}\right]$$

This has two local minima: $q^* \approx \mathcal{N}(-3, 1)$ or $q^* \approx \mathcal{N}(3, 1)$ - the approximation collapses onto one mode of $p$. The wide, mean-covering solution $\mathcal{N}(0, 10)$ from forward KL is not a minimizer here: it places mass in the valley where $p$ is near zero, incurring a large penalty.

General result. Reverse KL drives the approximation $q$ to concentrate on a single mode of $p$; which mode it finds depends on initialization, since gradient descent on reverse KL pulls $q$'s mass toward a locally high-probability region of $p$ rather than toward the global spread.

4.3 Geometric Interpretation

FORWARD KL vs REVERSE KL: THE KEY DIFFERENCE

[Figure: the bimodal target $p$ with modes at $-3$ and $3$. The forward-KL minimizer is the wide, mean-seeking $q^* = \mathcal{N}(0, 10)$, which covers both modes and places mass in the valley between them. The reverse-KL minimizer is the mode-seeking $q^* = \mathcal{N}(-3, 1)$ or $\mathcal{N}(3, 1)$, which picks one mode and ignores the other.]

Intuition: forward KL treats regions where $p > 0$ as "must cover"; reverse KL treats regions where $p \approx 0$ as "must avoid".

The information-geometric view formalizes this with the concept of projections:

  • I-projection (information projection): $q^* = \arg\min_{q \in \mathcal{Q}} D_{\mathrm{KL}}(q \| p)$ - projects $p$ onto the constraint family $\mathcal{Q}$ using reverse KL. Selects the $q \in \mathcal{Q}$ that is closest to $p$ in the reverse KL sense.
  • M-projection (moment projection): $q^* = \arg\min_{q \in \mathcal{Q}} D_{\mathrm{KL}}(p \| q)$ - projects $p$ onto $\mathcal{Q}$ using forward KL. Produces moment-matched approximations.

For exponential families: the M-projection always produces a unique solution (the moment-matched distribution). The I-projection may have multiple local minima (one per mode).

4.4 Consequences for Variational Inference

Variational Bayes minimizes $D_{\mathrm{KL}}(q(\mathbf{z}) \| p(\mathbf{z} \mid \mathbf{x}))$ - reverse KL. This is computationally convenient (the ELBO is a tractable lower bound) but has important consequences:

Posterior collapse in VAEs. In a Variational Autoencoder, each latent dimension $z_i$ has approximate posterior $q_\phi(z_i \mid \mathbf{x}) = \mathcal{N}(\mu_i, \sigma_i^2)$ and prior $p(z_i) = \mathcal{N}(0, 1)$. If the decoder is powerful enough to reconstruct $\mathbf{x}$ using only a subset of dimensions, the ELBO optimization drives the unused dimensions' KL term to zero - meaning $q_\phi(z_i\mid\mathbf{x}) \to p(z_i) = \mathcal{N}(0,1)$, a constant independent of the input. This posterior collapse means the latent representation ignores the input for those dimensions.

Why reverse KL allows collapse: Reverse KL $D_{\mathrm{KL}}(q\|p_{\mathrm{prior}})$ is minimized ($= 0$) when $q = p_{\mathrm{prior}}$. Forward KL would penalize this collapse heavily because $D_{\mathrm{KL}}(p_{\mathrm{posterior}}\|q)$ would be large if $q$ ignores the data. The zero-forcing property of reverse KL allows collapse; the zero-avoiding property of forward KL would prevent it.

β-VAE (Higgins et al., 2017). Weighting the KL term with a hyperparameter $\beta$:

$$\mathcal{L}_\beta = \mathbb{E}_q[\log p(\mathbf{x}|\mathbf{z})] - \beta\, D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z}))$$

With $\beta > 1$ the information bottleneck through the latent code is tightened, which encourages disentangled representations (at some cost in reconstruction quality). This is used in modern VAE-based image models and speech representations.

KL annealing. Starting with $\beta = 0$ and gradually increasing it during training prevents early collapse. Used in language model VAEs (Bowman et al., 2016) to avoid the degenerate solution where the encoder is ignored.
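
A minimal sketch of a linear annealing schedule and the $\beta$-weighted ELBO (the function and parameter names here are hypothetical; real implementations also use cyclical or sigmoid ramps):

```python
def beta_schedule(step, warmup_steps=10_000, beta_max=1.0):
    """Linear KL annealing: beta ramps from 0 to beta_max over warmup_steps."""
    return beta_max * min(1.0, step / warmup_steps)

def beta_elbo(recon_log_lik, kl_term, step):
    # L_beta = E_q[log p(x|z)] - beta * D_KL(q(z|x) || p(z))
    return recon_log_lik - beta_schedule(step) * kl_term
```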

5. KL for Specific Distributions

5.1 KL Between Gaussians

The KL divergence between Gaussian distributions has a closed form used in dozens of ML algorithms.

Scalar Gaussians. For $p = \mathcal{N}(\mu_1, \sigma_1^2)$ and $q = \mathcal{N}(\mu_2, \sigma_2^2)$:

$$D_{\mathrm{KL}}(p \| q) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

Derivation. The log-ratio is:

$$\log\frac{p(x)}{q(x)} = \log\frac{\sigma_2}{\sigma_1} - \frac{(x-\mu_1)^2}{2\sigma_1^2} + \frac{(x-\mu_2)^2}{2\sigma_2^2}$$

Taking the expectation under $p$, using $\mathbb{E}_p[(X-\mu_1)^2] = \sigma_1^2$ and $\mathbb{E}_p[(X-\mu_2)^2] = \sigma_1^2 + (\mu_1-\mu_2)^2$:

$$D_{\mathrm{KL}}(p\|q) = \log\frac{\sigma_2}{\sigma_1} - \frac{\sigma_1^2}{2\sigma_1^2} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

Intuition on terms:

  • $\log(\sigma_2/\sigma_1)$: penalty for scale mismatch
  • $\sigma_1^2/(2\sigma_2^2)$: penalty for $p$ being wider than $q$
  • $(\mu_1 - \mu_2)^2/(2\sigma_2^2)$: penalty for mean mismatch (like Mahalanobis distance)
  • $-1/2$: normalization constant

Multivariate Gaussians. For $p = \mathcal{N}(\boldsymbol{\mu}_1, \Sigma_1)$ and $q = \mathcal{N}(\boldsymbol{\mu}_2, \Sigma_2)$:

$$D_{\mathrm{KL}}(p \| q) = \frac{1}{2}\left[\operatorname{tr}(\Sigma_2^{-1}\Sigma_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \Sigma_2^{-1}(\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - d + \log\frac{\det\Sigma_2}{\det\Sigma_1}\right]$$

where $d$ is the dimension. The four terms correspond to: trace ratio (variance mismatch), Mahalanobis distance (mean mismatch), dimension offset, log-determinant ratio (volume mismatch).
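
A direct NumPy implementation of the closed form (a sketch; production code would prefer Cholesky-based solves over explicit inversion):

```python
import numpy as np

def gaussian_kl(mu1, S1, mu2, S2):
    """D_KL(N(mu1, S1) || N(mu2, S2)) via the closed form above."""
    d = mu1.shape[0]
    S2_inv = np.linalg.inv(S2)
    diff = mu2 - mu1
    _, logdet1 = np.linalg.slogdet(S1)
    _, logdet2 = np.linalg.slogdet(S2)
    return 0.5 * (np.trace(S2_inv @ S1) + diff @ S2_inv @ diff
                  - d + logdet2 - logdet1)

mu1, S1 = np.zeros(2), np.eye(2)
mu2, S2 = np.ones(2), 2 * np.eye(2)
print(gaussian_kl(mu1, S1, mu2, S2))   # ~0.6931 nats
```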

VAE application. In a standard VAE, the encoder outputs $q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\boldsymbol{\mu}_\phi(\mathbf{x}), \operatorname{diag}(\boldsymbol{\sigma}_\phi^2(\mathbf{x})))$ and the prior is $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, I)$. The KL term simplifies because $\Sigma_2 = I$:

$$D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x}) \| p(\mathbf{z})) = \frac{1}{2}\sum_{j=1}^d \left(\mu_j^2 + \sigma_j^2 - \ln\sigma_j^2 - 1\right)$$

This is differentiable in $\boldsymbol{\mu}_\phi$ and $\boldsymbol{\sigma}_\phi$, enabling direct gradient descent. The reparameterization trick $\mathbf{z} = \boldsymbol{\mu}_\phi + \boldsymbol{\sigma}_\phi \odot \boldsymbol{\epsilon}$, $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0},I)$ makes the sampling step differentiable.

5.2 KL Between Categorical Distributions

For discrete distributions $p = (p_1, \ldots, p_K)$ and $q = (q_1, \ldots, q_K)$:

$$D_{\mathrm{KL}}(p \| q) = \sum_{k=1}^K p_k \ln\frac{p_k}{q_k}$$

This equals $H(p, q) - H(p)$ where $H(p,q) = -\sum_k p_k \ln q_k$ is the cross-entropy. In particular, when $p$ is a one-hot vector (empirical distribution on a single label $y$), $H(p) = 0$ and:

$$D_{\mathrm{KL}}(p \| q) = H(p, q) = -\ln q_y = \text{cross-entropy loss}$$

Temperature scaling. In knowledge distillation, the teacher distribution is softened using temperature $\tau > 1$:

$$p_{\mathrm{soft}}(k) = \frac{e^{z_k^T/\tau}}{\sum_j e^{z_j^T/\tau}}$$

where $z^T$ denotes the teacher's logits. The student is trained to minimize $D_{\mathrm{KL}}(p_{\mathrm{soft}} \| q_{\boldsymbol{\theta}})$ rather than the one-hot cross-entropy. Soft targets carry more information (their entropy grows with $\tau$), allowing the student to learn the teacher's "dark knowledge" about relative class similarities.
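
A sketch of the distillation objective with hypothetical teacher and student logits (real implementations typically also scale the loss by $\tau^2$ to keep gradient magnitudes comparable across temperatures, per Hinton et al.):

```python
import numpy as np

def softmax(z, tau=1.0):
    z = z / tau
    z = z - z.max()              # stabilize before exponentiating
    e = np.exp(z)
    return e / e.sum()

def distill_kl(teacher_logits, student_logits, tau=2.0):
    """Forward KL D_KL(p_soft || q_soft) between temperature-softened
    teacher and student distributions."""
    p = softmax(teacher_logits, tau)
    q = softmax(student_logits, tau)
    return np.sum(p * np.log(p / q))

zt = np.array([5.0, 2.0, 0.1])   # hypothetical teacher logits
zs = np.array([4.0, 1.0, 0.5])   # hypothetical student logits
print(distill_kl(zt, zs, tau=2.0))
```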

5.3 KL Within the Exponential Family

Members of the exponential family have a remarkably elegant KL formula. An exponential family distribution has density:

$$p_{\boldsymbol{\eta}}(\mathbf{x}) = h(\mathbf{x}) \exp\!\left(\boldsymbol{\eta}^\top \mathbf{t}(\mathbf{x}) - A(\boldsymbol{\eta})\right)$$

where $\boldsymbol{\eta}$ are natural parameters, $\mathbf{t}(\mathbf{x})$ are sufficient statistics, and $A(\boldsymbol{\eta}) = \log \int h(\mathbf{x})e^{\boldsymbol{\eta}^\top \mathbf{t}(\mathbf{x})} d\mathbf{x}$ is the log-partition function (which is convex in $\boldsymbol{\eta}$).

KL between exponential family members:

$$D_{\mathrm{KL}}(p_{\boldsymbol{\eta}_1} \| p_{\boldsymbol{\eta}_2}) = A(\boldsymbol{\eta}_2) - A(\boldsymbol{\eta}_1) - \nabla A(\boldsymbol{\eta}_1)^\top(\boldsymbol{\eta}_2 - \boldsymbol{\eta}_1)$$

This is the Bregman divergence generated by the convex function $A(\boldsymbol{\eta})$: $B_A(\boldsymbol{\eta}_2, \boldsymbol{\eta}_1) = A(\boldsymbol{\eta}_2) - A(\boldsymbol{\eta}_1) - \nabla A(\boldsymbol{\eta}_1)^\top(\boldsymbol{\eta}_2 - \boldsymbol{\eta}_1)$, the error of the first-order Taylor approximation of $A$ around $\boldsymbol{\eta}_1$ - which is always $\ge 0$ by convexity of $A$, recovering Gibbs' inequality.

Examples (with the natural parameter made explicit):

  • Gaussian: in terms of $(\mu, \sigma^2)$, $A = \mu^2/(2\sigma^2) + \frac{1}{2}\ln\sigma^2$ (natural parameters $\eta_1 = \mu/\sigma^2$, $\eta_2 = -1/(2\sigma^2)$)
  • Bernoulli: $\eta = \ln\frac{p}{1-p}$ (the logit), $A(\eta) = \ln(1 + e^\eta)$; the resulting Bregman divergence is exactly binary KL
  • Poisson: $\eta = \ln\lambda$, $A(\eta) = e^\eta = \lambda$
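
The Bregman identity is easy to verify numerically for the Bernoulli case (a sketch using the natural parameterization above):

```python
import numpy as np

# Bernoulli in natural parameters: eta = log(p/(1-p)), A(eta) = log(1+e^eta).
A = lambda eta: np.log1p(np.exp(eta))
A_grad = lambda eta: 1 / (1 + np.exp(-eta))    # A'(eta) = mean parameter p

def bregman(eta2, eta1):
    return A(eta2) - A(eta1) - A_grad(eta1) * (eta2 - eta1)

def binary_kl(p, q):
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

p1, p2 = 0.8, 0.3
eta1, eta2 = np.log(p1 / (1 - p1)), np.log(p2 / (1 - p2))
print(bregman(eta2, eta1), binary_kl(p1, p2))  # both ~0.534 nats
```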

5.4 VAE Closed-Form KL and the Reparameterization Trick

The VAE's KL term per latent dimension is:

$$D_{\mathrm{KL}}(\mathcal{N}(\mu, \sigma^2) \| \mathcal{N}(0, 1)) = \frac{1}{2}(\mu^2 + \sigma^2 - \ln\sigma^2 - 1)$$

Derivation. Using the scalar Gaussian formula with $\mu_2 = 0$, $\sigma_2^2 = 1$:

$$D_{\mathrm{KL}} = \ln\frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} = -\frac{1}{2}\ln\sigma^2 + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} = \frac{1}{2}(\mu^2 + \sigma^2 - \ln\sigma^2 - 1)$$

Properties: This quantity is 0 when $\mu = 0, \sigma = 1$ (the posterior equals the prior); strictly positive otherwise. Gradient w.r.t. $\mu$: $\mu$. Gradient w.r.t. $\sigma^2$: $\frac{1}{2}(1 - 1/\sigma^2)$. Both are simple, enabling efficient backpropagation through the KL term.

Reparameterization trick. The challenge: we need $\mathbb{E}_{z \sim q_\phi(z|x)}[\log p_\theta(x|z)]$, but $z$ depends on $\phi$, so we cannot backpropagate through the sampling. The trick: write $z = \mu_\phi(x) + \sigma_\phi(x) \cdot \varepsilon$ where $\varepsilon \sim \mathcal{N}(0,1)$ is a separate random variable independent of $\phi$. Now $\frac{\partial z}{\partial \phi} = \frac{\partial \mu_\phi}{\partial \phi} + \varepsilon \frac{\partial \sigma_\phi}{\partial \phi}$ is well-defined, and gradients flow through. This is the key insight of Kingma & Welling (2014).
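
The pathwise gradient that the trick enables can be demonstrated without any autograd framework: for $f(z) = z^2$ the reparameterized Monte Carlo gradient matches the exact derivative $\partial_\mu \mathbb{E}[z^2] = 2\mu$ (a sketch with arbitrary $\mu$, $\sigma$):

```python
import numpy as np

rng = np.random.default_rng(3)

# Pathwise (reparameterized) gradient of E_{z~N(mu, sigma^2)}[f(z)] w.r.t. mu,
# for f(z) = z^2: since z = mu + sigma*eps, d f(z)/d mu = f'(z) = 2z.
mu, sigma = 0.5, 1.2
eps = rng.normal(size=1_000_000)
z = mu + sigma * eps                 # reparameterized sample
grad_mu = np.mean(2 * z)             # Monte Carlo pathwise gradient estimate
print(grad_mu, "exact:", 2 * mu)     # E[z^2] = mu^2 + sigma^2, so d/dmu = 2mu
```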
