Theory Notebook
Converted from theory.ipynb for web reading.
Fine-Tuning Math: Theory Notebook
This notebook turns fine-tuning into numerical bookkeeping: masked SFT loss, KL movement from a base model, LoRA ranks and merges, adapter and prefix parameter counts, DPO preference loss, and forgetting diagnostics.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Base versus tuned distributions
Code cell 4
def softmax(z):
    z = np.asarray(z, dtype=float)
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()
base_logits = np.array([2.0, 1.5, 0.2, -0.5])
tuned_logits = base_logits + np.array([0.1, -0.4, 0.8, -0.2])
base = softmax(base_logits)
tuned = softmax(tuned_logits)
kl = np.sum(tuned * (np.log(tuned + 1e-12) - np.log(base + 1e-12)))
print("base:", np.round(base, 3))
print("tuned:", np.round(tuned, 3))
print("KL(tuned || base):", round(kl, 4))
2. Answer-only SFT loss
Code cell 6
neg_log_probs = np.array([[0.2, 0.3, 1.2, 0.8, 0.0]])
mask = np.array([[0, 0, 1, 1, 0]])  # 1 marks answer tokens; prompt and padding stay 0
answer_loss = (neg_log_probs * mask).sum() / mask.sum()
bad_loss = neg_log_probs.mean()
print("answer-only loss:", answer_loss)
print("unmasked loss:", bad_loss)
3. Full fine-tuning versus LoRA parameter count
Code cell 8
d_in = 4096
d_out = 4096
ranks = np.array([4, 8, 16, 32, 64])
full = d_in * d_out
lora = ranks * (d_in + d_out)
for r, count in zip(ranks, lora):
    print(f"rank {r:>2}: {count:,} trainable ({100*count/full:.3f}% of full matrix)")
4. Low-rank approximation intuition
Code cell 10
rng = np.random.default_rng(3)
delta = rng.normal(size=(32, 32))
U, S, Vt = np.linalg.svd(delta, full_matrices=False)
errors = []
for r in [1, 2, 4, 8, 16, 32]:
    approx = (U[:, :r] * S[:r]) @ Vt[:r, :]
    rel_error = np.linalg.norm(delta - approx) / np.linalg.norm(delta)
    errors.append(rel_error)
    print(f"rank {r:>2}: relative error {rel_error:.4f}")
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot([1, 2, 4, 8, 16, 32], errors, marker="o")
ax.set_title("Truncated SVD approximation error")
ax.set_xlabel("rank")
ax.set_ylabel("relative Frobenius error")
fig.tight_layout()
plt.show()
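A random Gaussian delta is close to full rank, so its singular values decay slowly and truncation hurts. A delta with genuine low-rank structure compresses far better, which is the regime LoRA bets on. A small contrast, using a synthetic rank-4 matrix plus noise as a stand-in for a structured update:
# Sketch: rank-8 truncation error for a random delta versus a near-low-rank one.
rng = np.random.default_rng(7)
structured = rng.normal(size=(32, 4)) @ rng.normal(size=(4, 32))  # true rank 4
structured += 0.05 * rng.normal(size=(32, 32))                    # plus small noise
for name, M in [("random", rng.normal(size=(32, 32))), ("near-low-rank", structured)]:
    U, S, Vt = np.linalg.svd(M, full_matrices=False)
    approx = (U[:, :8] * S[:8]) @ Vt[:8, :]
    err = np.linalg.norm(M - approx) / np.linalg.norm(M)
    print(f"{name:>14s}: rank-8 relative error {err:.4f}")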
5. LoRA update and merge
Code cell 12
d_in, d_out, r, alpha = 6, 5, 2, 8
W = np.random.normal(size=(d_out, d_in))
A = np.random.normal(scale=0.02, size=(r, d_in))
B = np.random.normal(scale=0.02, size=(d_out, r))
delta_W = (alpha / r) * (B @ A)  # LoRA scaling: alpha / r
W_merged = W + delta_W
x = np.random.normal(size=(d_in,))
before = W @ x
after_adapter = W @ x + delta_W @ x
after_merged = W_merged @ x
print("delta_W shape:", delta_W.shape)
print("merge error:", np.max(np.abs(after_adapter - after_merged)))
6. Adapter bottleneck count
Code cell 14
d_model = 4096
bottleneck = 64
layers = 32
per_layer = d_model * bottleneck + bottleneck * d_model  # down-projection + up-projection weights (biases ignored)
total = layers * per_layer
print("adapter params per layer:", per_layer)
print("adapter params total:", total)
7. Prefix-tuning parameter count
Code cell 16
layers = 32
prefix_len = 16
heads = 32
head_dim = 128
kv = 2  # one key and one value vector per prefix position
prefix_params = layers * prefix_len * heads * head_dim * kv
print("prefix parameters:", prefix_params)
print("million:", prefix_params / 1e6)
8. Memory effect of PEFT
Code cell 18
base_params = 7e9
lora_params = 40e6
adam_bytes_per_trainable = 12  # e.g. 4 bytes each for Adam's two moments and an fp32 master weight
base_optimizer_gb = base_params * adam_bytes_per_trainable / 1e9
lora_optimizer_gb = lora_params * adam_bytes_per_trainable / 1e9
print("full Adam state GB:", base_optimizer_gb)
print("LoRA Adam state GB:", lora_optimizer_gb)
print("reduction factor:", base_optimizer_gb / lora_optimizer_gb)
9. DPO loss for one preference pair
Code cell 20
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
beta = 0.1
logp_theta_chosen = -12.0
logp_theta_rejected = -15.0
logp_ref_chosen = -13.0
logp_ref_rejected = -14.0
margin = (logp_theta_chosen - logp_theta_rejected) - (logp_ref_chosen - logp_ref_rejected)
loss = -np.log(sigmoid(beta * margin))
print("preference margin:", margin)
print("DPO loss:", loss)
10. KL penalty intuition
Code cell 22
def softmax(z):
    z = np.asarray(z, dtype=float)
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()
base_logits = np.array([2.0, 1.0, 0.0])
direction = np.array([-0.5, 0.2, 0.8])
strengths = np.linspace(0, 3, 20)
kls = []
for s in strengths:
    p0 = softmax(base_logits)
    p1 = softmax(base_logits + s * direction)
    kls.append(np.sum(p1 * (np.log(p1 + 1e-12) - np.log(p0 + 1e-12))))
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(strengths, kls, marker="o")
ax.set_title("Function-space shift grows with update strength")
ax.set_xlabel("update strength")
ax.set_ylabel("KL(tuned || base)")
fig.tight_layout()
plt.show()
print("final KL:", kls[-1])
11. Layer-wise learning-rate decay
Code cell 24
layers = np.arange(1, 13)
eta_top = 2e-5
gamma = 0.85
lrs = eta_top * gamma ** (layers.max() - layers)
for layer, lr in zip(layers, lrs):
    print(f"layer {layer:>2}: lr={lr:.8f}")
12. Task quality versus retention
Code cell 26
methods = ["base", "prompt", "lora-r8", "lora-r32", "full"]
task = np.array([0.45, 0.58, 0.72, 0.78, 0.82])
retain = np.array([0.90, 0.89, 0.86, 0.82, 0.70])
score = 0.7 * task + 0.3 * retain
for m, t, r, s in zip(methods, task, retain, score):
    print(f"{m:>8s}: task={t:.2f} retain={r:.2f} combined={s:.3f}")
13. Trainable-parameter audit
Code cell 28
params = {
    "base.embed": False,
    "base.block1.q_proj": False,
    "base.block1.q_proj.lora_A": True,
    "base.block1.q_proj.lora_B": True,
    "lm_head": False,
}
for name, trainable in params.items():
    print(f"{name:30s} trainable={trainable}")
print("trainable tensors:", sum(params.values()))
14. Final fine-tuning checklist
Code cell 30
checks = [
    "prompt and answer masks match the objective",
    "only intended tensors require gradients",
    "trainable parameter count matches expectation",
    "base, prompt-only, and tuned model are compared",
    "retention and task quality are both measured",
    "adapter can be disabled or merged for an ablation",
]
for i, check in enumerate(checks, 1):
    print(f"{i}. {check}")