Theory Notebook — Math for LLMs

Sampling Methods

ML Specific Math / Sampling Methods

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Sampling Methods - Theory Notebook

Executable companion to notes.md.

Code cell 2

# Plot setup cell: prefer seaborn styling when it is installed, otherwise fall
# back to the matplotlib-bundled seaborn theme.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
else:
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True

# Shared figure defaults for every plot in the notebook.
_RC_DEFAULTS = {
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
}
mpl.rcParams.update(_RC_DEFAULTS)
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

# Duplicate setup cell (kept so the notebook can be resumed from here):
# styling, rcParams, seed, plus the shared color palette.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
else:
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True

mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")

# Colorblind-safe palette used by all later plotting cells.
COLORS = {
    "primary": "#0077BB",
    "secondary": "#EE7733",
    "tertiary": "#009988",
    "error": "#CC3311",
    "neutral": "#555555",
    "highlight": "#EE3377",
}

Code cell 4

def check_close(name, value, expected, tol=1e-2):
    """Print PASS/FAIL for a numeric closeness check and return the verdict."""
    ok = np.allclose(value, expected, atol=tol, rtol=tol)
    status = "PASS" if ok else "FAIL"
    print(f"{status} - {name}: value={value}, expected={expected}")
    return ok


def check_true(name, condition):
    """Print PASS/FAIL for a boolean condition and return it as a bool."""
    ok = bool(condition)
    status = "PASS" if ok else "FAIL"
    print(f"{status} - {name}")
    return ok


def softmax(z):
    """Numerically stable softmax of a 1-D logit array (shift by the max)."""
    logits = np.asarray(z, dtype=float)
    exp_shifted = np.exp(logits - logits.max())
    return exp_shifted / exp_shifted.sum()


def categorical_sample(p, n=1):
    """Draw n category indices from probabilities p via the inverse CDF."""
    cdf = np.cumsum(p)
    return np.searchsorted(cdf, np.random.rand(n))


def ess(weights):
    """Effective sample size (sum w)^2 / sum(w^2) of importance weights."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / (w * w).sum()


def gumbel(shape):
    """Standard Gumbel(0,1) noise of the given shape via inverse transform."""
    u = np.random.uniform(1e-12, 1 - 1e-12, size=shape)
    return -np.log(-np.log(u))


print("Sampling helpers ready.")

1. Inverse transform exponential

Code cell 6

# Inverse transform: if U ~ Uniform(0,1) then -log(1-U)/lam ~ Exponential(lam).
u = np.random.rand(10000)
lam = 2.0
x = -np.log(1 - u) / lam
print("sample mean", x.mean(), "theory", 1 / lam)
check_close("exponential mean", x.mean(), 1 / lam, tol=0.02)

2. Exponential histogram

Code cell 8

# Overlay the sample histogram with the analytic exponential pdf.
fig, ax = plt.subplots()
ax.hist(x, bins=50, density=True, alpha=0.7, color=COLORS["primary"], label="samples")
grid = np.linspace(0, 4, 300)
ax.plot(grid, lam * np.exp(-lam * grid), color=COLORS["secondary"], label="pdf")
ax.set_title("Inverse-transform exponential sampling")
ax.set_xlabel("x")
ax.set_ylabel("density")
ax.legend()
fig.tight_layout()
plt.show()
print("Histogram plotted.")

3. Categorical sampling

Code cell 10

# Empirical frequencies of inverse-CDF categorical draws should match p.
p = np.array([0.1, 0.2, 0.3, 0.4])
samples = categorical_sample(p, 20000)
freq = np.bincount(samples, minlength=4) / len(samples)
print("frequencies", np.round(freq, 3))
check_close("categorical frequencies", freq, p, tol=0.02)

4. Categorical CDF

Code cell 12

# The cumulative distribution of a valid pmf must end at exactly 1.
cdf = np.cumsum(p)
print("cdf", cdf)
check_close("last CDF one", cdf[-1], 1.0, tol=1e-8)

5. Rejection sampling triangle

Code cell 14

def target_pdf(x):
    # Density of Beta(2, 1) on [0, 1].  # Beta(2,1) on [0,1]
    return 2 * x


M = 2.0  # envelope constant: target_pdf(x) <= M for x in [0, 1]
accepted = []
trials = 0
while len(accepted) < 5000:
    # Draw proposal first, then the acceptance uniform (order matters for RNG).
    prop = np.random.rand()
    u = np.random.rand()
    trials += 1
    if u <= target_pdf(prop) / M:
        accepted.append(prop)

accepted = np.array(accepted)
print("acceptance rate", len(accepted) / trials)
print("mean", accepted.mean(), "theory", 2 / 3)
check_close("Beta(2,1) mean", accepted.mean(), 2 / 3, tol=0.02)

6. Rejection visualization

Code cell 16

# Accepted samples should trace the triangular Beta(2,1) density.
fig, ax = plt.subplots()
ax.hist(accepted, bins=40, density=True, alpha=0.7, color=COLORS["primary"], label="accepted")
grid = np.linspace(0, 1, 200)
ax.plot(grid, target_pdf(grid), color=COLORS["secondary"], label="target")
ax.set_title("Rejection sampling target")
ax.set_xlabel("x")
ax.set_ylabel("density")
ax.legend()
fig.tight_layout()
plt.show()
print("Rejection plot done.")

7. Stratified sampling

Code cell 18

def f(x):
    # Integrand: E[f(U)] over Uniform(0,1) equals 1/3.
    return x ** 2


n = 1000
plain = f(np.random.rand(n)).mean()
# One uniform jitter per equal-width stratum.
strata = (np.arange(n) + np.random.rand(n)) / n
strat = f(strata).mean()
print("plain", plain, "stratified", strat, "true", 1/3)
check_true("stratified close", abs(strat - 1/3) < abs(plain - 1/3) + 0.01)

8. Monte Carlo convergence

Code cell 20

# E[X^2] = 1 for X ~ N(0,1); the estimate should tighten as n grows.
ns = np.array([10, 30, 100, 300, 1000, 3000, 10000])
estimates = [np.mean(np.random.normal(size=k) ** 2) for k in ns]
print(list(zip(ns, np.round(estimates, 4))))
check_true("large n close to 1", abs(estimates[-1] - 1) < 0.05)

9. MC convergence plot

Code cell 22

# Log-x plot of the running estimates against the true value 1.
fig, ax = plt.subplots()
ax.plot(ns, estimates, marker="o", color=COLORS["primary"], label="estimate")
ax.axhline(1, color=COLORS["neutral"], linestyle="--", label="true")
ax.set_xscale("log")
ax.set_title("Monte Carlo estimate of E[X^2]")
ax.set_xlabel("samples")
ax.set_ylabel("estimate")
ax.legend()
fig.tight_layout()
plt.show()
print("MC convergence plotted.")

10. Confidence interval

Code cell 24

# 95% normal-approximation CI for the mean of N(2, 3^2) samples.
vals = np.random.normal(loc=2, scale=3, size=500)
mean = vals.mean()
se = vals.std(ddof=1) / np.sqrt(len(vals))  # standard error with Bessel correction
ci = (mean - 1.96 * se, mean + 1.96 * se)
print("mean", mean, "CI", ci)
check_true("CI contains true mean", ci[0] < 2 < ci[1])

11. Antithetic variance

Code cell 26

def estimate_plain(n):
    # Plain Monte Carlo estimate of E[exp(U)] from n independent uniforms.
    u = np.random.rand(n)
    return np.exp(u).mean()


def estimate_anti(n):
    # Antithetic pairs (U, 1-U): each uniform is reused with its mirror.
    u = np.random.rand(n // 2)
    return np.r_[np.exp(u), np.exp(1 - u)].mean()


plain = np.array([estimate_plain(200) for _ in range(300)])
anti = np.array([estimate_anti(200) for _ in range(300)])
print("plain var", plain.var(), "antithetic var", anti.var())
check_true("antithetic lower variance", anti.var() < plain.var())

12. Importance sampling normal tail

Code cell 28

# Estimate P(Z > 3) for Z ~ N(0,1) with importance proposal N(3, 1),
# which puts most of its mass in the rare-event region.
n = 20000
x = np.random.normal(loc=3, scale=1, size=n)
p_density = np.exp(-0.5 * x ** 2) / np.sqrt(2 * np.pi)
q_density = np.exp(-0.5 * (x - 3) ** 2) / np.sqrt(2 * np.pi)
w = p_density / q_density  # importance weights p/q
est = np.mean((x > 3) * w)
print("tail estimate", est)
check_true("rare tail positive", 0.0005 < est < 0.003)

13. ESS

Code cell 30

# ESS of the first 1000 importance weights is strictly below the sample count.
weights = w[:1000]
print("ESS", ess(weights), "of", len(weights))
check_true("ESS less than n", ess(weights) < len(weights))

14. Weight degeneracy

Code cell 32

# A single dominant weight collapses the effective sample size to ~1.
bad_w = np.array([999.0] + [1.0] * 999)
print("bad ESS", ess(bad_w))
check_true("one giant weight lowers ESS", ess(bad_w) < 5)

15. Self-normalized IS

Code cell 34

# Self-normalized estimator: weighted average with normalized weights.
# NOTE: this rebinds the name f (previously the stratified-cell function).
f = x[:5000] ** 2
ww = w[:5000]
sn = np.sum(ww * f) / np.sum(ww)
print("self-normalized E[Z^2]", sn)
check_true("finite SN estimate", np.isfinite(sn))

16. Metropolis target

Code cell 36

def logp(x):
    # Unnormalized log-density of a standard normal target.
    return -0.5 * x * x


chain = []
cur = 0.0
acc = 0
for t in range(15000):
    # Symmetric random-walk proposal; draw proposal before acceptance uniform.
    prop = cur + np.random.normal(scale=1.0)
    if np.log(np.random.rand()) < logp(prop) - logp(cur):
        cur = prop
        acc += 1
    if t >= 1000:  # record only after burn-in
        chain.append(cur)

chain = np.array(chain)
print("acceptance", acc / 15000, "mean", chain.mean(), "var", chain.var())
check_close("MH variance", chain.var(), 1.0, tol=0.08)

17. MH trace plot

Code cell 38

# Trace of the first 500 recorded Metropolis states.
fig, ax = plt.subplots()
ax.plot(chain[:500], color=COLORS["primary"])
ax.set_title("Random-walk Metropolis trace")
ax.set_xlabel("step")
ax.set_ylabel("x")
fig.tight_layout()
plt.show()
print("Trace plotted.")

18. Gibbs-style bivariate normal

Code cell 40

# Gibbs-style alternation on a bivariate normal with correlation rho:
# each coordinate is drawn from its conditional N(rho*other, 1 - rho^2).
rho = 0.8
cond_sd = np.sqrt(1 - rho ** 2)  # conditional standard deviation
xg = 0.0
yg = 0.0
draws = []
for _ in range(6000):
    xg = np.random.normal(rho * yg, cond_sd)
    yg = np.random.normal(rho * xg, cond_sd)
    draws.append([xg, yg])

draws = np.array(draws[1000:])  # drop burn-in
corr = np.corrcoef(draws.T)[0, 1]
print("corr", corr)
check_close("Gibbs corr", corr, rho, tol=0.05)

19. Langevin toy

Code cell 42

# Unadjusted Langevin dynamics on a standard normal, started far from the mode.
x = 5.0
eta = 0.05  # step size
path = []
for _ in range(500):
    grad = -x  # gradient of the log-density -x^2/2
    x = x + 0.5 * eta * grad + np.sqrt(eta) * np.random.normal()
    path.append(x)

print("final", x, "path mean tail", np.mean(path[-100:]))
check_true("Langevin moved from start", abs(path[-1]) < 5)

20. Reparameterization Gaussian

Code cell 44

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
mu = 2.0
sigma = 0.5
eps = np.random.normal(size=10000)
z = mu + sigma * eps
print("sample mean", z.mean(), "std", z.std())
check_close("reparam mean", z.mean(), mu, tol=0.02)
check_close("reparam std", z.std(), sigma, tol=0.02)

21. Gumbel-Max

Code cell 46

# Gumbel-Max trick: argmax(logits + Gumbel noise) ~ Categorical(p).
p = np.array([0.1, 0.3, 0.6])
logits = np.log(p)
g = gumbel((20000, 3))
samp = np.argmax(logits + g, axis=1)
freq = np.bincount(samp, minlength=3) / len(samp)
print("freq", np.round(freq, 3))
check_close("Gumbel-Max categorical", freq, p, tol=0.02)

22. Gumbel-Softmax

Code cell 48

# Gumbel-Softmax relaxation: lower temperature pushes samples toward one-hot.
for tau in [1.0, 0.5, 0.1]:
    y = softmax((logits + gumbel(3)) / tau)
    print(f"tau={tau}: {np.round(y,3)}, sum={y.sum():.3f}")

23. Minibatch gradient estimator

Code cell 50

# Minibatch vs. full-data gradient of a quadratic objective evaluated at a
# fixed point theta. Fixes two readability defects of the original: the magic
# "data - 0" is replaced by a named evaluation point, and the batch estimate
# now uses the same formula as the full gradient (the original dropped the
# "- 0", which was value-identical but inconsistent).
# NOTE(review): 2*mean(data - theta) is the gradient magnitude demonstrated
# here; the sign convention (d/dx vs. d/dtheta) is immaterial for this demo.
theta = 0.0  # evaluation point (was the hard-coded "- 0")
data = np.random.normal(loc=1.0, size=10000)
full_grad = 2 * np.mean(data - theta)
# Uniform subsample without replacement gives an unbiased minibatch estimate.
batch = data[np.random.choice(len(data), size=128, replace=False)]
batch_grad = 2 * np.mean(batch - theta)
print("full grad", full_grad, "batch grad", batch_grad)
check_true("batch grad finite", np.isfinite(batch_grad))

24. Negative sampling hardness

Code cell 52

# Softmax cross-entropy loss with the positive plus three sampled negatives:
# negatives with logits near the positive ("hard") raise the loss.
pos = 2.0
easy = np.array([-3, -2, -1])   # well-separated negative logits
hard = np.array([1.8, 1.5, 1.2])  # negatives close to the positive
loss_easy = -pos + np.log(np.exp(pos) + np.exp(easy).sum())
loss_hard = -pos + np.log(np.exp(pos) + np.exp(hard).sum())
print("easy", loss_easy, "hard", loss_hard)
check_true("hard negatives higher loss", loss_hard > loss_easy)

25. Temperature top-k top-p

Code cell 54

# Decoding controls: temperature scaling, top-k truncation, nucleus (top-p).
logits = np.array([5, 4, 3, 2, 1, 0], dtype=float)
for tau in [0.5, 1, 2]:
    print("tau", tau, "probs", np.round(softmax(logits / tau), 3))

probs = softmax(logits)
topk_idx = np.argsort(probs)[-3:][::-1]  # indices of the 3 largest probs, descending
print("top-k idx", topk_idx)

order = np.argsort(probs)[::-1]
cum = np.cumsum(probs[order])
topp = order[cum <= 0.9]
if len(topp) < len(order):
    # include the first token whose mass pushes the cumulative sum past 0.9
    topp = np.r_[topp, order[len(topp)]]
print("top-p idx", topp)
check_true("top-k selects 3", len(topk_idx) == 3)