Theory Notebook
Converted from theory.ipynb for web reading.
RNN and LSTM Math: Theory Notebook
This notebook makes recurrent sequence math concrete: hidden-state updates, BPTT products, clipping, LSTM and GRU gates, masking, teacher forcing, and attention over encoder states.
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
1. Vanilla RNN update
Code cell 4
rng = np.random.default_rng(1)
T, d_x, d_h = 5, 3, 4
X = rng.normal(size=(T, d_x))
W_xh = rng.normal(scale=0.3, size=(d_h, d_x))
W_hh = rng.normal(scale=0.3, size=(d_h, d_h))
b = np.zeros(d_h)
h = np.zeros(d_h)
states = []
for t in range(T):
    h = np.tanh(W_xh @ X[t] + W_hh @ h + b)  # same W_xh, W_hh, b at every step
    states.append(h.copy())
H = np.stack(states)
print("hidden sequence shape:", H.shape)
print("last hidden state:", np.round(H[-1], 3))
2. Parameter sharing across time
Code cell 6
params_once = W_xh.size + W_hh.size + b.size
params_unshared = T * params_once
print("shared RNN parameters:", params_once)
print("if unshared over time:", params_unshared)
3. Gradient product intuition
Code cell 8
steps = np.arange(1, 51)
for scale in [0.8, 1.0, 1.2]:
    plt.plot(steps, scale ** steps, label=f"scale={scale}")
plt.yscale("log")
plt.title("Repeated Jacobian scale over time")
plt.xlabel("time distance")
plt.ylabel("gradient scale proxy")
plt.legend()
plt.tight_layout()
plt.show()
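A minimal sketch of the real quantity behind that scale proxy, assuming the vanilla RNN from cell 4: for tanh, the Jacobian of h_t with respect to h_{t-1} is diag(1 - h_t**2) @ W_hh, and the norm of the running product of these Jacobians is what vanishes or explodes in BPTT.
h = np.zeros(d_h)
J_prod = np.eye(d_h)
norms = []
for t in range(len(X)):
    h = np.tanh(W_xh @ X[t] + W_hh @ h + b)
    J = np.diag(1 - h ** 2) @ W_hh  # Jacobian d h_t / d h_{t-1}
    J_prod = J @ J_prod
    norms.append(np.linalg.norm(J_prod, 2))  # spectral norm of the product
print("spectral norm of Jacobian product per step:", np.round(norms, 4))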
4. Gradient clipping
Code cell 10
g = np.array([3.0, 4.0, 12.0])
clip = 5.0
norm = np.linalg.norm(g)
scale = min(1.0, clip / norm)
g_clip = g * scale
print("original norm:", norm)
print("clipped norm:", np.linalg.norm(g_clip))
print("scale:", scale)
5. Truncated BPTT chunks
Code cell 12
T = 23
window = 6
chunks = [(start, min(start + window, T)) for start in range(0, T, window)]
print("chunks:", chunks)
print("Detach hidden state between chunks in an autodiff framework.")
6. One LSTM cell step
Code cell 14
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
d_x, d_h = 3, 4
x = rng.normal(size=d_x)
h_prev = rng.normal(size=d_h)
c_prev = rng.normal(size=d_h)
concat = np.r_[x, h_prev]
W = rng.normal(scale=0.2, size=(4 * d_h, d_x + d_h))
gates = W @ concat
f = sigmoid(gates[:d_h])
i = sigmoid(gates[d_h:2*d_h])
o = sigmoid(gates[2*d_h:3*d_h])
cand = np.tanh(gates[3*d_h:])
c = f * c_prev + i * cand
h = o * np.tanh(c)
print("forget gate:", np.round(f, 3))
print("input gate:", np.round(i, 3))
print("cell shape:", c.shape, "hidden shape:", h.shape)
7. Forget gate memory timescale
Code cell 16
forget_values = [0.5, 0.8, 0.95, 0.99]
steps = np.arange(0, 80)
for f in forget_values:
    plt.plot(steps, f ** steps, label=f"f={f}")
plt.title("Forget gate controls memory decay")
plt.xlabel("steps")
plt.ylabel("remaining cell contribution")
plt.legend()
plt.tight_layout()
plt.show()
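The same decay expressed as a half-life: with a constant forget gate f, a stored value's contribution halves after log(0.5) / log(f) steps.
for f in forget_values:
    half_life = np.log(0.5) / np.log(f)
    print(f"f={f}: stored value halves after ~{half_life:.1f} steps")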
8. One GRU step
Code cell 18
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
d_x, d_h = 3, 4
x = rng.normal(size=d_x)
h_prev = rng.normal(size=d_h)
concat = np.r_[x, h_prev]
Wz = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
Wr = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
z = sigmoid(Wz @ concat)
r = sigmoid(Wr @ concat)
Wh = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
h_tilde = np.tanh(Wh @ np.r_[x, r * h_prev])
h = (1 - z) * h_prev + z * h_tilde
print("update gate:", np.round(z, 3))
print("reset gate:", np.round(r, 3))
print("hidden:", np.round(h, 3))
9. Masked sequence loss
Code cell 20
losses = np.array([[0.4, 0.5, 0.8, 0.0], [0.7, 0.6, 0.0, 0.0]])
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
masked = (losses * mask).sum() / mask.sum()
bad = losses.mean()
print("masked loss:", masked)
print("unmasked loss:", bad)
10. Teacher forcing versus free running
Code cell 22
gold = ["I", "like", "math", "<EOS>"]
pred = ["I", "like", "pizza", "<EOS>"]
for t in range(1, len(gold)):
    teacher_input = gold[t-1]
    free_input = pred[t-1]
    print(f"step {t}: teacher forcing input={teacher_input:>5s}, free-running input={free_input:>5s}")
11. Attention over encoder states
Code cell 24
encoder = rng.normal(size=(5, 4))
decoder_state = rng.normal(size=4)
scores = encoder @ decoder_state
weights = np.exp(scores - scores.max())
weights = weights / weights.sum()
context = weights @ encoder
print("attention weights:", np.round(weights, 3))
print("context shape:", context.shape)
12. Bidirectional shape check
Code cell 26
B, T, hidden = 2, 6, 4
forward = np.zeros((B, T, hidden))
backward = np.zeros((B, T, hidden))
bi = np.concatenate([forward, backward], axis=-1)
print("bidirectional output shape:", bi.shape)
13. Gate saturation diagnostic
Code cell 28
gate_samples = sigmoid(rng.normal(loc=0.0, scale=3.0, size=1000))
saturated = np.mean((gate_samples < 0.05) | (gate_samples > 0.95))
plt.hist(gate_samples, bins=30)
plt.title("Gate activation histogram")
plt.xlabel("gate value")
plt.ylabel("count")
plt.tight_layout()
plt.show()
print("saturated fraction:", saturated)
14. Final RNN checklist
Code cell 30
checks = [
"batch, time, and feature axes are explicit",
"padding mask is applied before averaging loss",
"hidden state is detached between truncated BPTT chunks",
"gradient norms are tracked and clipped if needed",
"gate statistics are monitored for saturation",
"short and long sequences are evaluated separately",
]
for i, check in enumerate(checks, 1):
print(f"{i}. {check}")