Theory Notebook · Math for LLMs

RNN and LSTM Math

Math for Specific Models / RNN and LSTM Math

Run notebook
Theory Notebook

Theory Notebook

Converted from theory.ipynb for web reading.

RNN and LSTM Math: Theory Notebook

This notebook makes recurrent sequence math concrete: hidden-state updates, BPTT products, clipping, LSTM and GRU gates, masking, teacher forcing, and attention over encoder states.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Use seaborn's theme when the package is installed; otherwise fall back to
# matplotlib's bundled seaborn-like style sheet. HAS_SNS records which path ran.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False
else:
    HAS_SNS = True

# Notebook-wide matplotlib defaults, grouped by what they affect.
mpl.rcParams.update(
    {
        # Figure geometry and resolution (interactive and saved).
        "figure.figsize": (10, 6),
        "figure.dpi": 120,
        "savefig.dpi": 150,
        "savefig.bbox": "tight",
        # Text sizes for body text, titles, axis labels, ticks and legends.
        "font.size": 13,
        "axes.titlesize": 15,
        "axes.labelsize": 13,
        "xtick.labelsize": 11,
        "ytick.labelsize": 11,
        "legend.fontsize": 11,
        "legend.framealpha": 0.85,
        # Line weight and a lighter frame (no top/right spines).
        "lines.linewidth": 2.0,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
)

# Seeds NumPy's legacy global RNG; the cells below draw from explicit
# ``np.random.default_rng`` generators rather than this global state.
np.random.seed(42)
print("Plot setup complete.")

1. Vanilla RNN update

Code cell 4

# Unroll one vanilla RNN layer over T steps:
#   h_t = tanh(W_xh @ x_t + W_hh @ h_{t-1} + b),  starting from h_0 = 0.
rng = np.random.default_rng(1)
T, d_x, d_h = 5, 3, 4
X = rng.normal(size=(T, d_x))                  # one input row per time step
W_xh = rng.normal(scale=0.3, size=(d_h, d_x))  # input -> hidden
W_hh = rng.normal(scale=0.3, size=(d_h, d_h))  # hidden -> hidden (recurrent)
b = np.zeros(d_h)
h = np.zeros(d_h)                              # initial state h_0

states = []
for x_t in X:
    # The same weights are applied at every step; only h carries history.
    pre_activation = W_xh @ x_t + W_hh @ h + b
    h = np.tanh(pre_activation)
    states.append(h.copy())

H = np.stack(states)  # shape (T, d_h)
print("hidden sequence shape:", H.shape)
print("last hidden state:", np.round(H[-1], 3))

2. Parameter sharing across time

Code cell 6

# A recurrent layer reuses the same (W_xh, W_hh, b) at every time step, so its
# parameter count is independent of T; untying weights per step multiplies it by T.
params_once = sum(p.size for p in (W_xh, W_hh, b))
params_unshared = params_once * T
print("shared RNN parameters:", params_once)
print("if unshared over time:", params_unshared)

3. Gradient product intuition

Code cell 8

# Backprop through k steps multiplies k Jacobians; a scalar proxy s**k shows
# vanishing (s < 1), stable (s = 1) and exploding (s > 1) regimes.
steps = np.arange(1, 51)
ax = plt.gca()
for jacobian_scale in (0.8, 1.0, 1.2):
    ax.plot(steps, jacobian_scale ** steps, label=f"scale={jacobian_scale}")
ax.set_yscale("log")
ax.set_title("Repeated Jacobian scale over time")
ax.set_xlabel("time distance")
ax.set_ylabel("gradient scale proxy")
ax.legend()
plt.tight_layout()
plt.show()

4. Gradient clipping

Code cell 10

def clip_by_global_norm(grad, max_norm):
    """Rescale ``grad`` so that its L2 norm is at most ``max_norm``.

    Args:
        grad: gradient array of any shape; the norm is taken over all entries.
        max_norm: clipping threshold for the L2 norm.

    Returns:
        ``(clipped, scale)`` with ``clipped = grad * scale`` and
        ``scale = min(1, max_norm / ||grad||)``. When ``||grad|| <= max_norm``
        (including an all-zero gradient) ``scale`` is 1.0 and no division is
        performed, so a zero gradient no longer triggers a divide-by-zero warning.
    """
    grad_norm = np.linalg.norm(grad)
    if grad_norm <= max_norm:
        return grad * 1.0, 1.0
    # Shrink the whole vector uniformly: the direction is kept, the length capped.
    scale = max_norm / grad_norm
    return grad * scale, scale


g = np.array([3.0, 4.0, 12.0])
clip = 5.0
norm = np.linalg.norm(g)
g_clip, scale = clip_by_global_norm(g, clip)
print("original norm:", norm)
print("clipped norm:", np.linalg.norm(g_clip))
print("scale:", scale)

5. Truncated BPTT chunks

Code cell 12

T = 23
window = 6
# Half-open [start, stop) windows of length `window`; the last one may be shorter.
chunks = []
start = 0
while start < T:
    chunks.append((start, min(start + window, T)))
    start += window
print("chunks:", chunks)
print("Detach hidden state between chunks in an autodiff framework.")

6. One LSTM cell step

Code cell 14

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise.

    Evaluated as exp(-log(1 + exp(-x))) using ``np.logaddexp``, so very
    negative inputs no longer overflow ``np.exp(-x)``. The naive form emits an
    overflow RuntimeWarning for x below roughly -709.
    """
    return np.exp(-np.logaddexp(0.0, -x))


d_x, d_h = 3, 4
# Draw order from the notebook-level ``rng`` is kept so results are reproducible.
x = rng.normal(size=d_x)
h_prev = rng.normal(size=d_h)
c_prev = rng.normal(size=d_h)
concat = np.r_[x, h_prev]  # [x_t ; h_{t-1}]
# W stacks the four gate weight blocks row-wise in the order
# forget, input, output, candidate (bias terms are omitted for brevity).
W = rng.normal(scale=0.2, size=(4 * d_h, d_x + d_h))
gates = W @ concat
f = sigmoid(gates[:d_h])               # forget gate: how much of c_prev to keep
i = sigmoid(gates[d_h:2 * d_h])        # input gate: how much new content to write
o = sigmoid(gates[2 * d_h:3 * d_h])    # output gate: how much of the cell to expose
cand = np.tanh(gates[3 * d_h:])        # candidate cell content in (-1, 1)
# Additive cell update: gradients can flow through f * c_prev without a squashing nonlinearity.
c = f * c_prev + i * cand
h = o * np.tanh(c)
print("forget gate:", np.round(f, 3))
print("input gate:", np.round(i, 3))
print("cell shape:", c.shape, "hidden shape:", h.shape)

7. Forget gate memory timescale

Code cell 16

# With a constant forget gate value, a cell contribution written at step 0
# is multiplied by that value once per step, leaving value ** k after k steps.
forget_values = [0.5, 0.8, 0.95, 0.99]
steps = np.arange(0, 80)
ax = plt.gca()
for keep in forget_values:
    ax.plot(steps, keep ** steps, label=f"f={keep}")
ax.set_title("Forget gate controls memory decay")
ax.set_xlabel("steps")
ax.set_ylabel("remaining cell contribution")
ax.legend()
plt.tight_layout()
plt.show()

8. One GRU step

Code cell 18

def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), elementwise.

    Evaluated as exp(-log(1 + exp(-x))) using ``np.logaddexp``, so very
    negative inputs no longer overflow ``np.exp(-x)``. The naive form emits an
    overflow RuntimeWarning for x below roughly -709.
    """
    return np.exp(-np.logaddexp(0.0, -x))


d_x, d_h = 3, 4
# Draw order from the notebook-level ``rng`` is kept so results are reproducible.
x = rng.normal(size=d_x)
h_prev = rng.normal(size=d_h)
concat = np.r_[x, h_prev]  # [x_t ; h_{t-1}]
# Update (z) and reset (r) gates both read the input and full previous state.
# Bias terms are omitted for brevity.
Wz = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
Wr = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
z = sigmoid(Wz @ concat)
r = sigmoid(Wr @ concat)
# The candidate sees the previous state only through the reset gate:
# r near 0 lets the cell ignore its history when proposing h_tilde.
Wh = rng.normal(scale=0.2, size=(d_h, d_x + d_h))
h_tilde = np.tanh(Wh @ np.r_[x, r * h_prev])
# Convex blend: z near 1 adopts the candidate, z near 0 copies h_prev forward.
# (PyTorch's nn.GRU uses the mirrored convention, weighting h_prev by z.)
h = (1 - z) * h_prev + z * h_tilde
print("update gate:", np.round(z, 3))
print("reset gate:", np.round(r, 3))
print("hidden:", np.round(h, 3))

9. Masked sequence loss

Code cell 20

# Per-token losses for two padded sequences; mask == 1 marks real tokens.
losses = np.array([[0.4, 0.5, 0.8, 0.0], [0.7, 0.6, 0.0, 0.0]])
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
# Average over real tokens only; the naive mean also counts padding positions
# and so is biased toward zero.
token_count = mask.sum()
masked = np.sum(losses * mask) / token_count
bad = np.mean(losses)
print("masked loss:", masked)
print("unmasked loss:", bad)

10. Teacher forcing versus free running

Code cell 22

gold = ["I", "like", "math", "<EOS>"]
pred = ["I", "like", "pizza", "<EOS>"]
# At step t the decoder consumes the token from step t-1: the gold token under
# teacher forcing, or its own previous prediction when free running.
for t, (teacher_input, free_input) in enumerate(zip(gold[:-1], pred[:-1]), start=1):
    print(f"step {t}: teacher forcing input={teacher_input:>5s}, free-running input={free_input:>5s}")

11. Attention over encoder states

Code cell 24

encoder = rng.normal(size=(5, 4))      # 5 encoder states of width 4
decoder_state = rng.normal(size=4)
# Dot-product score of the decoder state against each encoder state.
scores = encoder @ decoder_state
# Softmax; subtracting the max first keeps exp() from overflowing.
unnormalized = np.exp(scores - scores.max())
weights = unnormalized / unnormalized.sum()
# Context vector = attention-weighted average of the encoder states.
context = weights @ encoder
print("attention weights:", np.round(weights, 3))
print("context shape:", context.shape)

12. Bidirectional shape check

Code cell 26

B, T, hidden = 2, 6, 4
# Placeholder outputs of the forward and backward passes, each (B, T, hidden).
forward = np.zeros((B, T, hidden))
backward = np.zeros_like(forward)
# Joining along the feature axis doubles the hidden width: (B, T, 2 * hidden).
bi = np.concatenate((forward, backward), axis=2)
print("bidirectional output shape:", bi.shape)

13. Gate saturation diagnostic

Code cell 28

# Wide pre-activations (std 3) push many sigmoid gates toward 0 or 1.
logits = rng.normal(loc=0.0, scale=3.0, size=1000)
gate_samples = sigmoid(logits)
# Fraction of gates within 0.05 of either end of (0, 1).
is_saturated = (gate_samples < 0.05) | (gate_samples > 0.95)
saturated = is_saturated.mean()
ax = plt.gca()
ax.hist(gate_samples, bins=30)
ax.set_title("Gate activation histogram")
ax.set_xlabel("gate value")
ax.set_ylabel("count")
plt.tight_layout()
plt.show()
print("saturated fraction:", saturated)

14. Final RNN checklist

Code cell 30

checks = [
    "batch, time, and feature axes are explicit",
    "padding mask is applied before averaging loss",
    "hidden state is detached between truncated BPTT chunks",
    "gradient norms are tracked and clipped if needed",
    "gate statistics are monitored for saturation",
    "short and long sequences are evaluated separately",
]
# Emit the checklist as a 1-based numbered list, one item per line.
print("\n".join(f"{number}. {item}" for number, item in enumerate(checks, 1)))