Theory Notebook
Converted from theory.ipynb for web reading.
Generative Models: Theory Notebook
This notebook makes generative-model objectives executable: autoregressive likelihood, VAE reparameterization, GAN losses, flow change of variables, diffusion noising, score updates, and FID intuition.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Autoregressive log likelihood
Code cell 4
conditional_probs = np.array([0.8, 0.5, 0.25, 0.9])
logp = np.log(conditional_probs).sum()
print("log probability:", logp)
print("probability:", np.exp(logp))
2. VAE reparameterization
Code cell 6
rng = np.random.default_rng(0)
mu = np.array([0.5, -0.2])
logvar = np.array([0.0, -1.0])
eps = rng.normal(size=mu.shape)
z = mu + np.exp(0.5 * logvar) * eps  # z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
print("epsilon:", np.round(eps, 3))
print("z:", np.round(z, 3))
3. Gaussian KL to standard normal
Code cell 8
# closed-form KL between q(z|x) = N(mu, diag(exp(logvar))) and the standard normal prior
kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1 - logvar)
print("KL(q(z|x) || N(0,I)):", kl)
4. ELBO pieces
Code cell 10
reconstruction_logprob = -1.25
elbo = reconstruction_logprob - kl
negative_elbo = -elbo
print("ELBO:", elbo)
print("training loss -ELBO:", negative_elbo)
5. GAN losses
Code cell 12
D_real = np.array([0.9, 0.8, 0.7])
D_fake = np.array([0.2, 0.3, 0.4])
D_loss = -(np.log(D_real).mean() + np.log(1 - D_fake).mean())  # discriminator cross-entropy
G_loss_saturating = np.log(1 - D_fake).mean()  # original minimax generator objective
G_loss_nonsat = -np.log(D_fake).mean()         # non-saturating alternative used in practice
print("D loss:", D_loss)
print("G saturating objective:", G_loss_saturating)
print("G non-saturating loss:", G_loss_nonsat)
6. 1D flow change of variables
Code cell 14
# x = a z + b, z=(x-b)/a, p_x(x)=p_z(z)*|dz/dx|
a, b = 2.0, -1.0
x = 1.0
z = (x - b) / a
log_pz = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
log_px = log_pz - np.log(abs(a))
print("z:", z)
print("log p_x:", log_px)
7. Diffusion noising step
Code cell 16
x0 = np.array([1.0, -0.5, 0.25])
alpha_bar = 0.7
eps = rng.normal(size=x0.shape)
xt = np.sqrt(alpha_bar) * x0 + np.sqrt(1 - alpha_bar) * eps
print("x_t:", np.round(xt, 3))
8. Denoising MSE
Code cell 18
eps_pred = eps + np.array([0.1, -0.2, 0.05])
loss = np.mean((eps - eps_pred) ** 2)
print("noise prediction MSE:", loss)
9. Score-based Langevin update
Code cell 20
x = np.array([2.0, -1.0])
score = -x # score of standard normal
eta = 0.1
noise = rng.normal(size=x.shape)
x_next = x + eta * score + np.sqrt(2 * eta) * noise
print("x_next:", np.round(x_next, 3))
10. Simplified FID-style distance
Code cell 22
mu_r = np.array([0.0, 0.0])
mu_g = np.array([0.5, -0.25])
cov_r = np.eye(2)
cov_g = np.array([[1.2, 0.0], [0.0, 0.8]])
mean_term = np.sum((mu_r - mu_g) ** 2)
# covariance term assuming diagonal covariances, where the matrix square root is elementwise
cov_term = np.trace(
    cov_r + cov_g - 2 * np.diag(np.sqrt(np.diag(cov_r) * np.diag(cov_g)))
)
print("FID-style diagonal distance:", mean_term + cov_term)
11. Diversity check
Code cell 24
samples = rng.normal(size=(100, 2))
collapsed = np.repeat(samples[:1], repeats=100, axis=0)
print("sample variance normal:", samples.var(axis=0))
print("sample variance collapsed:", collapsed.var(axis=0))
12. Denoising loss by timestep
Code cell 26
timesteps = np.arange(1, 21)
loss_by_t = 0.1 + 0.02 * timesteps + 0.05 * np.sin(timesteps / 2)
plt.plot(timesteps, loss_by_t, marker="o")
plt.title("Toy denoising loss by timestep")
plt.xlabel("timestep")
plt.ylabel("loss")
plt.tight_layout()
plt.show()
13. Final checklist
Code cell 28
checks = [
    "state whether the objective is likelihood, ELBO, adversarial loss, or denoising loss",
    "evaluate sample quality and diversity separately",
    "track likelihood only when it is meaningful for the model family",
    "inspect latent traversals or conditioning controls",
    "report sampling cost and number of generation steps",
    "check for mode collapse or timestep-specific failures",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")