Theory Notebook
Theory Notebook
Converted from
theory.ipynb for web reading.
Quantization and Distillation: Theory Notebook
This notebook makes compression math concrete: uniform (affine and symmetric) quantization, clipping, per-channel and group-wise scales, error versus bit width, distillation temperature, KL loss, and memory accounting.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
# Plot styling: use seaborn's theme when seaborn is installed, otherwise a
# matplotlib style sheet.  HAS_SNS records which path was taken.
try:
    import seaborn as sns

    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    # The "seaborn-v0_8-*" style names were introduced in matplotlib 3.6;
    # older releases ship the same sheet as "seaborn-whitegrid".  Passing an
    # unknown name to plt.style.use raises, so apply only a sheet this
    # matplotlib actually provides and otherwise keep the default look
    # (styling is cosmetic and should never stop the notebook).
    for _style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break
    HAS_SNS = False

# Shared figure defaults for every plot in the notebook.
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})

# Seed NumPy's legacy global RNG so the later cells that draw from
# np.random.normal produce reproducible numbers.
np.random.seed(42)
print("Plot setup complete.")
1. Affine quantization
Code cell 4
def quantize_affine(x, qmin, qmax):
    """Quantize ``x`` onto the integer range [qmin, qmax] with min/max affine scaling.

    The observed real range [x.min(), x.max()] is mapped linearly onto the
    integer grid::

        q     = clip(round(x / scale + zero), qmin, qmax)
        x_hat = scale * (q - zero)

    Parameters
    ----------
    x : array_like
        Values to quantize; converted to a float array.
    qmin, qmax : int
        Inclusive integer range of the target format, e.g. ``0, 15`` for
        unsigned 4-bit.  Must satisfy ``qmax > qmin``.

    Returns
    -------
    q : ndarray
        Integer codes (stored with a float dtype).
    x_hat : ndarray
        Dequantized reconstruction of ``x``.
    scale : float
        Real-valued step between adjacent integer codes.
    zero : float
        Zero point.  It is not clipped, so when the data range excludes 0
        it can lie outside [qmin, qmax].

    Raises
    ------
    ValueError
        If ``x`` is empty or ``qmax <= qmin``.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        raise ValueError("quantize_affine() requires a non-empty input")
    if qmax <= qmin:
        raise ValueError("qmax must be greater than qmin")

    lo, hi = x.min(), x.max()
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0:
        # Constant input: the min/max range is empty, so the usual scale is 0
        # and x / scale would be NaN/inf.  A step equal to |value| (or 1 for
        # all zeros) makes value / scale exactly +-1, so the constant is
        # reconstructed exactly.
        scale = abs(lo) if lo != 0 else 1.0

    # Place the real minimum on (approximately) qmin.
    zero = qmin - np.round(lo / scale)
    q = np.clip(np.round(x / scale + zero), qmin, qmax)
    x_hat = scale * (q - zero)
    return q, x_hat, scale, zero
# Four values spanning [-1.0, 0.9], mapped onto the unsigned 4-bit grid [0, 15].
x = np.array([-1.0, -0.3, 0.2, 0.9])
q, x_hat, s, z = quantize_affine(x, 0, 15)

# Report the codes and the reconstruction before the grid parameters.
for label, shown in (("q:", q.astype(int)), ("x_hat:", np.round(x_hat, 4))):
    print(label, shown)
print("scale:", s, "zero:", z)

abs_error = np.abs(x - x_hat)
print("max error:", np.max(abs_error))
2. Symmetric INT4 quantization
Code cell 6
def quantize_symmetric(x, bits):
    """Quantize ``x`` symmetrically (absmax scaling, no zero point) to signed ``bits``-bit codes.

    The largest magnitude in ``x`` is mapped to ``2**(bits-1) - 1``.  Codes are
    clipped to the signed range ``[-2**(bits-1), 2**(bits-1) - 1]``.

    Parameters
    ----------
    x : array_like
        Values to quantize; converted to a float array.  Any shape; a single
        scale is shared by every element.
    bits : int
        Bit width of the signed integer format; must be at least 2.

    Returns
    -------
    q : ndarray
        Integer codes (stored with a float dtype), same shape as ``x``.
    x_hat : ndarray
        Dequantized values ``q * scale``.
    scale : float
        Real-valued step between adjacent codes.

    Raises
    ------
    ValueError
        If ``bits < 2`` (the positive range would be empty and the scale
        would divide by zero).
    """
    if bits < 2:
        raise ValueError("quantize_symmetric() needs bits >= 2")
    x = np.asarray(x, dtype=float)
    qmax = 2 ** (bits - 1) - 1
    qmin = -(2 ** (bits - 1))
    scale = np.max(np.abs(x)) / qmax
    if scale == 0:
        # All-zero input: any positive step reproduces it exactly, and this
        # avoids 0 / 0 = NaN in the division below.
        scale = 1.0
    q = np.clip(np.round(x / scale), qmin, qmax)
    return q, q * scale, scale
# Sixteen standard-normal weights from a dedicated, seeded Generator,
# quantized to signed 4-bit with a single shared scale.
rng = np.random.default_rng(2)
w = rng.normal(size=16)
q, w_hat, scale = quantize_symmetric(w, 4)

mse_int4 = np.mean((w - w_hat) ** 2)
distinct_codes = np.unique(q).astype(int)

print("scale:", scale)
print("MSE:", mse_int4)
print("unique integer values:", distinct_codes)
3. Error versus bit width
Code cell 8
# Quantize the same weights `w` at every width from 2 to 8 bits and record
# the reconstruction MSE at each width.
bits_list = np.arange(2, 9)
errors = [
    np.mean((w - quantize_symmetric(w, width)[1]) ** 2)
    for width in bits_list
]

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(bits_list, errors, marker="o")
ax.set_xlabel("bits")
ax.set_ylabel("MSE")
ax.set_title("Quantization error decreases with bit width")
fig.tight_layout()
plt.show()

print("errors:", np.round(errors, 6))
4. Per-channel quantization
Code cell 10
# Three rows ("channels") whose spreads differ by 30x; drawn in the same
# order (0.1, 1.0, 3.0) from the global RNG.
W = np.vstack([np.random.normal(scale=sigma, size=8) for sigma in (0.1, 1.0, 3.0)])

# One scale for the whole matrix: the widest row sets the step for all rows.
_, global_hat, _ = quantize_symmetric(W, 4)
# One scale per row: each channel gets a step matched to its own range.
per_hat = np.vstack([quantize_symmetric(row, 4)[1] for row in W])

print("global MSE:", np.mean((W - global_hat) ** 2))
print("per-channel MSE:", np.mean((W - per_hat) ** 2))
5. Clipping tradeoff
Code cell 12
# 1000 well-behaved values plus two large outliers.  Clipping shrinks the
# quantization step for the bulk of the data but distorts the outliers; the
# MSE is always measured against the unclipped values.
x = np.r_[np.random.normal(scale=0.5, size=1000), np.array([5.0, -4.5])]
clips = np.linspace(0.5, 5.0, 20)
mse = [
    np.mean((x - quantize_symmetric(np.clip(x, -limit, limit), 4)[1]) ** 2)
    for limit in clips
]
best = int(np.argmin(mse))
best_clip, best_mse = clips[best], mse[best]

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(clips, mse, marker="o")
ax.scatter([best_clip], [best_mse], color="red")
ax.set_title("Clipping range tradeoff")
ax.set_xlabel("clip range")
ax.set_ylabel("MSE to original")
fig.tight_layout()
plt.show()

print("best clip:", best_clip, "MSE:", best_mse)
6. Weight memory by precision
Code cell 14
# Weight-only storage for a 7e9-parameter model at several precisions,
# in decimal gigabytes (1 GB = 1e9 bytes).
params = 7e9
footprint_gb = {bits: params * bits / 8 / 1e9 for bits in (16, 8, 4, 3, 2)}
for bits, gb in footprint_gb.items():
    print(f"{bits:>2}-bit weights: {gb:.2f} GB")
7. SmoothQuant-style scale shifting intuition
Code cell 16
# Two rows of activations over three input channels; channel 0 is large.
X = np.array([[10.0, 0.2, 0.1], [8.0, -0.1, 0.2]])
W = np.array([[0.1, 2.0, -1.5], [0.2, -1.0, 1.0]])
# Per-input-channel factor: dividing activation column j by scale[j] and
# multiplying weight column j by the same factor leaves X @ W.T unchanged.
scale = np.array([4.0, 1.0, 1.0])

X_scaled, W_scaled = X / scale, W * scale
Y_original, Y_scaled = X @ W.T, X_scaled @ W_scaled.T
max_diff = np.max(np.abs(Y_original - Y_scaled))

print("max matmul difference after exact scale shift:", max_diff)
print("activation max before:", np.abs(X).max(axis=0))
print("activation max after: ", np.abs(X_scaled).max(axis=0))
8. Temperature softens teacher probabilities
Code cell 18
def softmax(z, axis=None):
    """Return the numerically stable softmax of the logits ``z``.

    Parameters
    ----------
    z : array_like
        Logits; converted to a float array.
    axis : int or None, optional
        Axis to normalize over.  ``None`` (the default, matching the original
        behaviour) normalizes over all elements, which for 1-D logits is the
        usual softmax.  Pass ``axis=-1`` to get one distribution per row of
        a 2-D batch of logits.

    Returns
    -------
    ndarray
        Non-negative values of the same shape as ``z`` that sum to 1 along
        ``axis`` (over everything when ``axis`` is None).
    """
    z = np.asarray(z, dtype=float)
    # Subtracting the max leaves the result unchanged but keeps exp() from
    # overflowing for large logits.  keepdims keeps the reductions
    # broadcastable against z for any axis.
    shifted = z - z.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)
# Dividing the same teacher logits by a growing temperature flattens the
# distribution; the entropy (in nats) rises with tau.
teacher_logits = np.array([6.0, 3.0, 1.0, -1.0])
softened = [(tau, softmax(teacher_logits / tau)) for tau in (1.0, 2.0, 4.0)]
for tau, p in softened:
    print(f"tau={tau}: probs={np.round(p, 3)}, entropy={-np.sum(p*np.log(p+1e-12)):.3f}")
9. KL distillation loss
Code cell 20
# Temperature shared by both softened distributions and the tau^2 rescaling.
# It is defined first and used everywhere (instead of hard-coding 2.0 inside
# each softmax call), so changing tau updates all three uses together.
tau = 2.0
teacher = softmax(np.array([3.0, 1.0, 0.0]) / tau)
student = softmax(np.array([2.2, 1.4, -0.1]) / tau)
# KL(teacher || student) = sum_i p_i * (log p_i - log q_i); the 1e-12 keeps
# log() finite if a probability underflows to 0.
kl = np.sum(teacher * (np.log(teacher + 1e-12) - np.log(student + 1e-12)))
# Scale by tau^2 so the soft-target gradients keep a magnitude comparable to
# the hard-label loss as tau changes (Hinton et al., 2015).
kd_loss = tau**2 * kl
print("teacher:", np.round(teacher, 3))
print("student:", np.round(student, 3))
print("KD loss:", kd_loss)
10. Combine hard and soft losses
Code cell 22
# Blend the two training signals: weight alpha on the hard-label
# cross-entropy and (1 - alpha) on the distillation term.
hard_ce, kd, alpha = 0.8, 0.35, 0.4
combined = (1 - alpha) * kd + alpha * hard_ce
print("combined distillation objective:", combined)
11. QLoRA memory intuition
Code cell 24
# Rough QLoRA memory budget in decimal GB: a 4-bit frozen base model plus
# 16-bit LoRA adapters and their optimizer state.  Gradients, activations and
# quantization constants are not counted here.
base_params, lora_params = 7e9, 40e6
base_bits, lora_bits = 4, 16
# 96 bits = three 32-bit values per trainable parameter; presumably Adam's two
# moment estimates plus one more fp32 buffer -- confirm against the optimizer.
adam_bits_per_param = 96

base_gb, lora_gb, adam_gb = (
    count * width / 8 / 1e9
    for count, width in (
        (base_params, base_bits),
        (lora_params, lora_bits),
        (lora_params, adam_bits_per_param),
    )
)

print("quantized base GB:", base_gb)
print("LoRA weights GB:", lora_gb)
print("LoRA Adam states GB:", adam_gb)
12. Compression checklist
Code cell 26
# Questions to answer before shipping a compressed model, printed 1-based.
checks = [
    "state what is quantized: weights, activations, or KV cache",
    "state granularity: tensor, channel, or group",
    "calibration data matches deployment prompts",
    "compare logits, held-out loss, task score, calibration, memory, and latency",
    "verify serving kernels support the selected format",
]
print("\n".join(f"{i}. {check}" for i, check in enumerate(checks, 1)))