Math for LLMs / Quantization and Distillation
Converted from exercises.ipynb for web reading.

Quantization and Distillation: Exercises

Ten exercises cover the compression math used in quantization and distillation: scales, dequantization, memory savings, clipping, KL loss, and deployment checks.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Exercise 1: Affine quantization

Quantize the scalar 0.3 with scale 0.1 and zero point 0, then dequantize it.

Code cell 4

# Your Solution
x, s, z = 0.3, 0.1, 0
print("Starter: q=round(x/s)+z, x_hat=s*(q-z).")

Code cell 5

# Solution
x, s, z = 0.3, 0.1, 0
q = round(x / s) + z
x_hat = s * (q - z)
print("q:", q, "x_hat:", x_hat)
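The same formulas extend element-wise to arrays. As a sketch (the clip range [-128, 127] assumes signed INT8; the helper name is illustrative), round-and-clip each element, then dequantize:

```python
import numpy as np

def affine_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """q = clip(round(x/scale) + zero_point, qmin, qmax); x_hat = scale * (q - zero_point)."""
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype(np.int32)
    x_hat = scale * (q - zero_point)
    return q, x_hat

x = np.array([0.3, -1.2, 0.07])
q, x_hat = affine_quantize(x, scale=0.1, zero_point=0)
print("q:", q)          # integer codes
print("x_hat:", x_hat)  # dequantized approximation of x
```

Values outside the representable range saturate at qmin or qmax; that saturation is the clipping error studied in Exercise 5.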

Exercise 2: INT4 scale

Find the symmetric INT4 scale for a maximum absolute value of 2.1 (integer codes in [-7, 7]).

Code cell 7

# Your Solution
max_abs = 2.1
qmax = 7
print("Starter: scale=max_abs/qmax.")

Code cell 8

# Solution
max_abs = 2.1
qmax = 7                  # symmetric INT4: integer codes in [-7, 7]
scale = max_abs / qmax
print("scale:", scale)
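To see how the scale is used, a sketch that quantizes an assumed weight vector into the symmetric INT4 code range [-7, 7] and reconstructs it:

```python
import numpy as np

w = np.array([2.1, -0.9, 0.5, -2.1])   # example weights (assumed)
scale = np.abs(w).max() / 7            # symmetric INT4 scale
q = np.clip(np.round(w / scale), -7, 7).astype(np.int32)
w_hat = q * scale
print("scale:", scale)
print("q:", q)
print("max abs error:", np.abs(w - w_hat).max())
```

For unclipped values, the absolute reconstruction error is at most scale / 2.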

Exercise 3: Memory savings

Compare bf16 and 4-bit memory for 13B parameters.

Code cell 10

# Your Solution
P = 13e9
print("Starter: memory GB = P*bits/8/1e9.")

Code cell 11

# Solution
P = 13e9
bf16 = P * 16 / 8 / 1e9
int4 = P * 4 / 8 / 1e9
print("bf16 GB:", bf16)
print("int4 GB:", int4)
print("reduction:", bf16 / int4)
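The same bytes-per-parameter arithmetic generalizes to any bit width. A small helper (the name is illustrative) makes the comparison table:

```python
def param_memory_gb(n_params, bits):
    """Parameter memory in GB: n_params * bits / 8 bytes, reported in 1e9-byte GB."""
    return n_params * bits / 8 / 1e9

P = 13e9
for bits in (32, 16, 8, 4):
    print(f"{bits:2d}-bit: {param_memory_gb(P, bits):5.1f} GB")
```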

Exercise 4: Quantization error

Compute MSE between weights and dequantized weights.

Code cell 13

# Your Solution
w = np.array([0.1, 0.4, -0.2])
w_hat = np.array([0.0, 0.5, -0.25])
print("Starter: mean((w-w_hat)^2).")

Code cell 14

# Solution
w = np.array([0.1, 0.4, -0.2])
w_hat = np.array([0.0, 0.5, -0.25])
mse = np.mean((w - w_hat) ** 2)
print("MSE:", mse)
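The w_hat above is simply given. As a sketch, it can instead be produced by an actual quantizer (here an assumed symmetric INT4 scheme) so the MSE reflects real rounding error:

```python
import numpy as np

w = np.array([0.1, 0.4, -0.2])
scale = np.abs(w).max() / 7        # symmetric INT4 scale
q = np.round(w / scale)            # integer codes (all within [-7, 7] here)
w_hat = q * scale                  # dequantized weights
mse = np.mean((w - w_hat) ** 2)
print("w_hat:", w_hat)
print("MSE:", mse)
```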

Exercise 5: Best clipping

Choose the clipping value with the lowest MSE.

Code cell 16

# Your Solution
mse = np.array([0.05, 0.02, 0.03])
clips = np.array([1.0, 2.0, 3.0])
print("Starter: choose clips[argmin(mse)].")

Code cell 17

# Solution
mse = np.array([0.05, 0.02, 0.03])
clips = np.array([1.0, 2.0, 3.0])
print("best clip:", clips[np.argmin(mse)])
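The three MSE values above are given. As a sketch, such a table can be generated by sweeping clip thresholds over synthetic Gaussian weights (an assumed distribution) with symmetric INT4 quantization:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(0.0, 1.0, size=10_000)   # synthetic weights, std 1 (assumed)

def clipped_int4_mse(w, clip):
    """MSE of symmetric INT4 quantization with the given clip threshold."""
    scale = clip / 7
    q = np.clip(np.round(w / scale), -7, 7)
    return np.mean((w - q * scale) ** 2)

clips = np.array([1.0, 2.0, 3.0, 4.0])
mses = np.array([clipped_int4_mse(w, c) for c in clips])
for c, m in zip(clips, mses):
    print(f"clip={c}: MSE={m:.5f}")
print("best clip:", clips[np.argmin(mses)])
```

A small clip trades rounding error for clipping error, and a large clip does the reverse, which is why the minimum sits at an intermediate threshold.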

Exercise 6: Temperature softmax

Compute teacher probabilities at temperature 2.

Code cell 19

# Your Solution
logits = np.array([2.0, 0.0])
tau = 2.0
print("Starter: softmax(logits/tau).")

Code cell 20

# Solution
logits = np.array([2.0, 0.0])
tau = 2.0
z = logits / tau
e = np.exp(z - z.max())
p = e / e.sum()
print("p:", p)
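To see why distillation raises the temperature, a small sweep (the tau values are illustrative) shows the distribution softening as tau grows:

```python
import numpy as np

def softmax_t(logits, tau):
    """Temperature-scaled softmax: softmax(logits / tau)."""
    z = logits / tau
    e = np.exp(z - z.max())   # subtract max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 0.0])
for tau in (0.5, 1.0, 2.0, 4.0):
    print(f"tau={tau}: p={softmax_t(logits, tau)}")
```

As tau grows, the distribution approaches uniform; as tau shrinks toward 0, it approaches one-hot.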

Exercise 7: KL distillation

Compute KL(teacher || student).

Code cell 22

# Your Solution
t = np.array([0.7, 0.3])
s = np.array([0.6, 0.4])
print("Starter: sum t*(log t - log s).")

Code cell 23

# Solution
t = np.array([0.7, 0.3])
s = np.array([0.6, 0.4])
kl = np.sum(t * (np.log(t) - np.log(s)))
print("KL:", kl)
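In practice the distillation term is usually computed from logits at temperature tau, scaled by tau squared so gradient magnitudes stay comparable across temperatures. A sketch with assumed logit values:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kd_loss(teacher_logits, student_logits, tau):
    """tau**2 * KL(softmax(teacher/tau) || softmax(student/tau))."""
    t = softmax(teacher_logits / tau)
    s = softmax(student_logits / tau)
    return tau**2 * np.sum(t * (np.log(t) - np.log(s)))

t_logits = np.array([2.0, 0.0])   # assumed teacher logits
s_logits = np.array([1.0, 0.0])   # assumed student logits
for tau in (1.0, 2.0, 4.0):
    print(f"tau={tau}: KD loss={kd_loss(t_logits, s_logits, tau):.4f}")
```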

Exercise 8: Combined loss

Combine hard CE and KD loss.

Code cell 25

# Your Solution
ce = 1.0
kd = 0.4
alpha = 0.25
print("Starter: alpha*ce + (1-alpha)*kd.")

Code cell 26

# Solution
ce = 1.0
kd = 0.4
alpha = 0.25
loss = alpha * ce + (1 - alpha) * kd
print("combined:", loss)

Exercise 9: Adapter optimizer memory

Estimate optimizer memory for 20M trainable params with Adam fp32 moments.

Code cell 28

# Your Solution
P = 20e6
print("Starter: two fp32 moments = 8 bytes per param.")

Code cell 29

# Solution
P = 20e6
gb = P * 8 / 1e9   # two fp32 moments = 2 * 4 = 8 bytes per trainable param
print("Adam moments GB:", gb)
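The point of adapters is the contrast with full fine-tuning. A sketch comparing the two, reusing the 13B parameter count from Exercise 3 and the 20M trainable count from this exercise:

```python
def adam_moment_gb(n_params):
    """Adam keeps two fp32 moments: 2 * 4 = 8 bytes per trainable parameter."""
    return n_params * 8 / 1e9

full = adam_moment_gb(13e9)     # every parameter trainable
adapter = adam_moment_gb(20e6)  # adapter parameters only
print(f"full: {full:.1f} GB, adapter: {adapter:.2f} GB, ratio: {full / adapter:.0f}x")
```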

Exercise 10: Compression checklist

List four checks to run before shipping a quantized model.

Code cell 31

# Your Solution
print("Starter: include calibration, loss, latency, and task quality.")

Code cell 32

# Solution
checks = [
    "calibration data matches deployment prompts",
    "held-out loss shift is acceptable",
    "latency improves on the target kernel",
    "task and safety quality gates pass",
]
for check in checks:
    print("-", check)

Closing Reflection

Compression is successful only when quality, memory, latency, and hardware support improve together for the deployment target.