Theory Notebook — Math for LLMs

Quantization and Distillation

Math for LLMs / Quantization and Distillation

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

Quantization and Distillation: Theory Notebook

This notebook makes compression math concrete: uniform (affine and symmetric) quantization, clipping, per-channel (group-wise) scales, error versus bit width, distillation temperature, KL loss, and memory accounting.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn's colour-blind-friendly theme; if seaborn is unavailable,
# fall back to the seaborn-like style bundled with matplotlib.
HAS_SNS = False
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")

# Notebook-wide figure defaults, applied after the theme so they take precedence.
_PLOT_DEFAULTS = {
    # figure canvas
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    # text sizes and legend
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    # lines and frame
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    # saved output
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
}
mpl.rcParams.update(_PLOT_DEFAULTS)

# Seed the legacy global RNG that later np.random.* calls draw from.
np.random.seed(42)
print("Plot setup complete.")

1. Affine quantization

Code cell 4

def quantize_affine(x, qmin, qmax):
    """Affine (asymmetric) quantization of ``x`` onto the integer grid [qmin, qmax].

    The real interval ``[x.min(), x.max()]`` is mapped linearly onto
    ``[qmin, qmax]``: ``q = clip(round(x / scale + zero), qmin, qmax)`` and the
    dequantized value is ``x_hat = scale * (q - zero)``.

    Parameters
    ----------
    x : array_like
        Values to quantize; must be non-empty.
    qmin, qmax : int
        Inclusive integer range, e.g. ``0, 15`` for unsigned 4-bit.

    Returns
    -------
    q : ndarray
        Integer codes (stored as floats).
    x_hat : ndarray
        Dequantized reconstruction of ``x``.
    scale : float
        Step size between adjacent integer levels (always positive).
    zero : float
        Zero point (integer-valued, stored as float).

    Raises
    ------
    ValueError
        If ``x`` is empty or ``qmax <= qmin``.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        raise ValueError("quantize_affine needs a non-empty input")
    if qmax <= qmin:
        raise ValueError("qmax must be greater than qmin")
    lo, hi = x.min(), x.max()
    if hi == lo:
        # Constant input would give scale == 0 and divide by zero below.
        # Widening the range to include 0 still reconstructs the single
        # value exactly (it lands on qmin or qmax).
        lo, hi = min(lo, 0.0), max(hi, 0.0)
    span = hi - lo
    # span is still 0 only for all-zero input; any positive scale is exact then.
    scale = span / (qmax - qmin) if span != 0 else 1.0
    # Integer offset chosen so that lo maps to qmin. For data that does not
    # straddle 0 this zero point can lie outside [qmin, qmax].
    zero = qmin - np.round(lo / scale)
    q = np.clip(np.round(x / scale + zero), qmin, qmax)
    x_hat = scale * (q - zero)
    return q, x_hat, scale, zero

# Quantize four values onto the unsigned 4-bit grid 0..15 and inspect the result.
x = np.array([-1.0, -0.3, 0.2, 0.9])
q, x_hat, s, z = quantize_affine(x, 0, 15)
codes = q.astype(int)
recon = np.round(x_hat, 4)
worst = np.abs(x - x_hat).max()
print("q:", codes)
print("x_hat:", recon)
print("scale:", s, "zero:", z)
print("max error:", worst)

2. Symmetric INT4 quantization

Code cell 6

def quantize_symmetric(x, bits):
    """Symmetric (zero-point-free) quantization of ``x`` to signed ``bits``-bit codes.

    The scale maps ``max|x|`` onto the largest positive code
    ``qmax = 2**(bits-1) - 1``. Because ``|x / scale| <= qmax``, codes stay in
    ``[-qmax, qmax]``; the extra negative level ``-2**(bits-1)`` is permitted
    by the clip but not reached for finite input.

    Parameters
    ----------
    x : array_like
        Values to quantize (any shape).
    bits : int
        Bit width; must be at least 2.

    Returns
    -------
    q : ndarray
        Integer codes (stored as floats), same shape as ``x``.
    x_hat : ndarray
        Dequantized reconstruction ``q * scale``.
    scale : float
        Step size between adjacent integer levels.

    Raises
    ------
    ValueError
        If ``bits < 2`` (then ``qmax == 0`` and no finite scale exists).
    """
    if bits < 2:
        raise ValueError("bits must be >= 2 for signed symmetric quantization")
    x = np.asarray(x)
    qmax = 2 ** (bits - 1) - 1
    qmin = -(2 ** (bits - 1))
    amax = np.max(np.abs(x))
    # All-zero input would give scale == 0 and 0/0 = NaN codes; any positive
    # scale reconstructs zeros exactly, so fall back to 1.0.
    scale = amax / qmax if amax != 0 else 1.0
    q = np.clip(np.round(x / scale), qmin, qmax)
    return q, q * scale, scale

# Sixteen standard-normal weights from a dedicated, reproducible generator.
rng = np.random.default_rng(2)
w = rng.normal(size=16)
q, w_hat, scale = quantize_symmetric(w, 4)
mse_int4 = np.mean((w - w_hat) ** 2)
levels = np.unique(q).astype(int)
print("scale:", scale)
print("MSE:", mse_int4)
print("unique integer values:", levels)

3. Error versus bit width

Code cell 8

# Sweep bit widths 2..8 on the same weight vector ``w`` and record reconstruction MSE.
bits_list = np.arange(2, 9)
errors = []
for bits in bits_list:
    _, w_hat, _ = quantize_symmetric(w, bits)
    errors.append(np.mean((w - w_hat) ** 2))
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(bits_list, errors, marker="o")
# Each extra bit roughly halves the step size, so MSE falls ~4x per bit; a
# log y-axis keeps the 5-8 bit points from collapsing visually onto zero.
ax.set_yscale("log")
ax.set_title("Quantization error decreases with bit width")
ax.set_xlabel("bits")
ax.set_ylabel("MSE")
fig.tight_layout()
plt.show()
print("errors:", np.round(errors, 6))

4. Per-channel quantization

Code cell 10

# Three rows whose standard deviations span 30x: one shared scale is set by
# the largest row, while per-row scales adapt to each row's own range.
W = np.vstack([np.random.normal(scale=sigma, size=8) for sigma in (0.1, 1.0, 3.0)])
_, global_hat, _ = quantize_symmetric(W, 4)
per_hat = np.vstack([quantize_symmetric(row, 4)[1] for row in W])
print("global MSE:", np.mean((W - global_hat) ** 2))
print("per-channel MSE:", np.mean((W - per_hat) ** 2))

5. Clipping tradeoff

Code cell 12

# 1000 well-behaved values plus two large outliers.
x = np.r_[np.random.normal(scale=0.5, size=1000), np.array([5.0, -4.5])]
clips = np.linspace(0.5, 5.0, 20)
# Clip to [-c, c], quantize to INT4, and measure error against the UNclipped data.
mse = [
    np.mean((x - quantize_symmetric(np.clip(x, -c, c), 4)[1]) ** 2)
    for c in clips
]
best = int(np.argmin(mse))

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(clips, mse, marker="o")
ax.scatter([clips[best]], [mse[best]], color="red")  # highlight the best clip
ax.set(title="Clipping range tradeoff", xlabel="clip range", ylabel="MSE to original")
fig.tight_layout()
plt.show()
print("best clip:", clips[best], "MSE:", mse[best])

6. Weight memory by precision

Code cell 14

# Raw weight storage for 7e9 parameters in decimal GB (1e9 bytes). Only the
# weights are counted: no quantization scales/zero points or runtime buffers.
params = 7e9
bytes_per_gb = 1e9
for bits in (16, 8, 4, 3, 2):
    gb = params * bits / 8 / bytes_per_gb
    label = f"{bits:>2}-bit weights:"
    print(label, f"{gb:.2f} GB")

7. SmoothQuant-style scale shifting intuition

Code cell 16

# Toy activations X (2 rows x 3 input channels, channel 0 is an outlier) and
# weights W (2 outputs x 3 inputs), so Y = X @ W.T.
X = np.array([[10.0, 0.2, 0.1], [8.0, -0.1, 0.2]])
W = np.array([[0.1, 2.0, -1.5], [0.2, -1.0, 1.0]])
# Per-input-channel factor: divided out of X and folded into W's columns,
# which leaves the product mathematically unchanged.
scale = np.array([4.0, 1.0, 1.0])
X_scaled = X / scale[np.newaxis, :]
W_scaled = W * scale[np.newaxis, :]
Y_original = X @ W.T
Y_scaled = X_scaled @ W_scaled.T
gap = np.abs(Y_original - Y_scaled).max()
print("max matmul difference after exact scale shift:", gap)
print("activation max before:", np.abs(X).max(axis=0))
print("activation max after: ", np.abs(X_scaled).max(axis=0))

8. Temperature softens teacher probabilities

Code cell 18

def softmax(z, axis=None):
    """Numerically stable softmax.

    Parameters
    ----------
    z : array_like
        Logits.
    axis : int or None, optional
        Axis to normalize over. ``None`` (the default, matching the original
        behaviour) normalizes over all elements, which suits a single 1-D
        logit vector; pass ``axis=-1`` for a batch of logit rows.

    Returns
    -------
    ndarray
        Probabilities with the same shape as ``z``, summing to 1 along ``axis``.
    """
    z = np.asarray(z, dtype=float)
    # Subtracting the max leaves the result unchanged but prevents exp() overflow.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Higher temperature flattens the teacher distribution and raises its entropy.
teacher_logits = np.array([6.0, 3.0, 1.0, -1.0])
for tau in [1.0, 2.0, 4.0]:
    p = softmax(teacher_logits / tau)
    # Shannon entropy in nats; the 1e-12 guards against log(0).
    entropy = -np.sum(p * np.log(p + 1e-12))
    print(f"tau={tau}: probs={np.round(p, 3)}, entropy={entropy:.3f}")

9. KL distillation loss

Code cell 20

# Distillation temperature shared by teacher and student. Defined first so a
# single change updates both softened distributions and the tau**2 factor.
tau = 2.0
teacher = softmax(np.array([3.0, 1.0, 0.0]) / tau)
student = softmax(np.array([2.2, 1.4, -0.1]) / tau)
# KL(teacher || student); the 1e-12 guards against log(0).
kl = np.sum(teacher * (np.log(teacher + 1e-12) - np.log(student + 1e-12)))
# tau**2 rescaling keeps soft-target gradient magnitudes comparable across
# temperatures (Hinton et al., 2015).
kd_loss = tau**2 * kl
print("teacher:", np.round(teacher, 3))
print("student:", np.round(student, 3))
print("KD loss:", kd_loss)

10. Combine hard and soft losses

Code cell 22

# Illustrative loss values (not computed here) blended with weight alpha.
hard_ce = 0.8   # hard-label cross-entropy term
kd = 0.35       # soft-target distillation term
alpha = 0.4     # weight on the hard-label term; (1 - alpha) goes to kd
hard_term = alpha * hard_ce
soft_term = (1 - alpha) * kd
combined = hard_term + soft_term
print("combined distillation objective:", combined)

11. QLoRA memory intuition

Code cell 24

def _gigabytes(n_params, bits_per_param):
    """Convert a parameter count and per-parameter bit width to decimal GB."""
    return n_params * bits_per_param / 8 / 1e9


# Frozen 4-bit base model plus trainable 16-bit LoRA adapters. Not counted
# here: gradients, activations, and quantization constants for the base.
base_params = 7e9
lora_params = 40e6
base_bits = 4
lora_bits = 16
# 96 = 3 x 32 bits per LoRA parameter; presumably fp32 Adam first/second
# moments plus an fp32 master copy — confirm against the intended setup.
adam_bits_per_param = 96
base_gb = _gigabytes(base_params, base_bits)
lora_gb = _gigabytes(lora_params, lora_bits)
adam_gb = _gigabytes(lora_params, adam_bits_per_param)
print("quantized base GB:", base_gb)
print("LoRA weights GB:", lora_gb)
print("LoRA Adam states GB:", adam_gb)

12. Compression checklist

Code cell 26

# Questions to answer before shipping a compressed model.
checks = [
    "state what is quantized: weights, activations, or KV cache",
    "state granularity: tensor, channel, or group",
    "calibration data matches deployment prompts",
    "compare logits, held-out loss, task score, calibration, memory, and latency",
    "verify serving kernels support the selected format",
]
# One numbered line per item, starting at 1.
print("\n".join(f"{n}. {item}" for n, item in enumerate(checks, 1)))