Theory Notebook
Converted from theory.ipynb for web reading.
Discrete Fourier Transform and FFT
"The FFT is the most important numerical algorithm of our lifetime." -- Gilbert Strang, MIT Mathematics
This notebook provides interactive derivations and visualizations covering:
- DFT definition, matrix form, and unitarity
- Cooley-Tukey FFT algorithm from scratch
- Spectral leakage and window functions
- STFT spectrograms and time-frequency analysis
- Whisper mel spectrogram pipeline
- Fourier Neural Operator spectral convolution layer
- Monarch butterfly matrix factorization
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Prefer seaborn styling when available; otherwise fall back to matplotlib's
# bundled seaborn-v0_8 style sheet.
try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

# Global figure defaults for every plot in this notebook
mpl.rcParams.update({
    "figure.figsize": (10, 6),
    "figure.dpi": 120,
    "font.size": 13,
    "axes.titlesize": 15,
    "axes.labelsize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    "legend.framealpha": 0.85,
    "lines.linewidth": 2.0,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.bbox": "tight",
    "savefig.dpi": 150,
})
np.random.seed(42)  # reproducible random signals throughout
print("Plot setup complete.")
Code cell 3
import numpy as np
import scipy.fft as sp_fft
import scipy.signal as sp_sig
from scipy.special import i0 as bessel_i0
import matplotlib.pyplot as plt
import matplotlib as mpl

# NOTE(review): this cell repeats the styling setup from the previous cell so
# the notebook can be re-run starting here; keep the two cells in sync.
try:
    import seaborn as sns
    sns.set_theme(style='whitegrid', palette='colorblind')
    HAS_SNS = True
except ImportError:
    plt.style.use('seaborn-v0_8-whitegrid')
    HAS_SNS = False

mpl.rcParams.update({
    'figure.figsize': (10, 6),
    'figure.dpi': 120,
    'font.size': 13,
    'axes.titlesize': 15,
    'axes.labelsize': 13,
    'lines.linewidth': 2.0,
    'axes.spines.top': False,
    'axes.spines.right': False,
})
np.random.seed(42)

# Colorblind-safe named palette used by all later cells
COLORS = {
    'primary': '#0077BB',
    'secondary': '#EE7733',
    'tertiary': '#009988',
    'error': '#CC3311',
    'neutral': '#555555',
    'highlight': '#EE3377',
}
print('Setup complete. NumPy', np.__version__)
1. Intuition
1.1 The DFT as Change of Basis
The $N$-point DFT transforms a vector $x \in \mathbb{C}^N$ into its coordinates in the Fourier basis $\{f_k\}_{k=0}^{N-1}$, $f_k[n] = e^{2\pi i kn/N}$: $X[k] = \sum_{n=0}^{N-1} x[n]\, e^{-2\pi i kn/N}$.
Below we visualize the Fourier basis vectors as sampled complex sinusoids.
Code cell 5
# === 1.1 Fourier Basis Vectors ===
# Plot the real parts of selected DFT basis vectors f_k[n] = exp(2*pi*i*k*n/N).
N = 16
omega_N = np.exp(2j * np.pi / N)  # primitive N-th root of unity
n = np.arange(N)
fig, axes = plt.subplots(2, 4, figsize=(14, 6))
fig.suptitle(f'Fourier Basis Vectors for N={N} (real parts)', fontsize=14)
for idx, k in enumerate([0, 1, 2, 3, 4, 6, 7, 8]):
    ax = axes[idx // 4, idx % 4]
    f_k = np.exp(2j * np.pi * k * n / N)  # basis vector f_k
    ax.stem(n, f_k.real, linefmt='C0-',
            markerfmt='o', basefmt='k-')
    ax.set_title(f'k={k}, freq={k}/{N}')
    ax.set_xlabel('n')
    ax.set_ylabel('Re[f_k]')
    ax.set_ylim(-1.3, 1.3)
fig.tight_layout()
plt.show()
print(f'Fourier basis: {N} orthonormal vectors in C^{N}')
print(f'Each f_k oscillates at {1}/{N} to {N//2}/{N} cycles per sample')
Code cell 6
# === 1.2 Orthonormality of Fourier Basis ===
N = 8
# Build the normalized DFT matrix (1/sqrt(N) * F_N) from the outer product k*n
phase = np.outer(np.arange(N), np.arange(N))
F_N = np.exp(-2j * np.pi * phase / N) / np.sqrt(N)
# Unitarity check: F_N @ F_N.conj().T must equal the identity
product = F_N @ F_N.conj().T
print(f'F_N @ F_N* = (showing real part):')
print(np.round(product.real, 8))
print()
ok = np.allclose(product, np.eye(N), atol=1e-12)
print(f'PASS - Normalized DFT matrix is unitary: {ok}')
# Quantify the deviation with the Frobenius norm
err = np.linalg.norm(product - np.eye(N), 'fro')
print(f'Frobenius error ||F_N F_N* - I||_F = {err:.2e}')
2. Formal Definitions
2.1 The DFT by Hand: N=4 Example
Let $x = (1, 0, -1, 0)$. We have $\omega_4 = e^{2\pi i/4} = i$:
Code cell 8
# === 2.1 DFT by Hand Verification ===
x = np.array([1, 0, -1, 0], dtype=complex)
N = len(x)
omega_4 = np.exp(2j * np.pi / N)
# Manual DFT: X[k] = sum_n x[n] * omega^{-n k}
X_manual = np.array([
    sum(x[n_idx] * omega_4 ** (-n_idx * k) for n_idx in range(N))
    for k in range(N)
])
X_numpy = np.fft.fft(x)
print('Manual DFT:', np.round(X_manual, 6))
print('NumPy FFT: ', np.round(X_numpy, 6))
ok = np.allclose(X_manual, X_numpy)
print(f'PASS - Manual DFT matches NumPy: {ok}')
# Parseval check: energy is preserved up to a factor 1/N
energy_time = np.sum(np.abs(x)**2)
energy_freq = np.sum(np.abs(X_numpy)**2) / N
print(f'\nParseval: sum|x|^2 = {energy_time:.4f}, sum|X|^2/N = {energy_freq:.4f}')
print(f'PASS - Parseval holds: {np.isclose(energy_time, energy_freq)}')
Code cell 9
# === 2.3 DFT Matrix Visualization ===
# Heatmaps of the real and imaginary parts of the (unnormalized) 8x8 DFT matrix.
N = 8
k_idx = np.arange(N)[:, None]  # frequency index (rows)
n_idx = np.arange(N)[None, :]  # time index (columns)
F8 = np.exp(-2j * np.pi * k_idx * n_idx / N)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
im0 = axes[0].imshow(F8.real, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
axes[0].set_title('Re(F_8): real part of DFT matrix')
axes[0].set_xlabel('Column (time n)')
axes[0].set_ylabel('Row (frequency k)')
fig.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(F8.imag, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
axes[1].set_title('Im(F_8): imaginary part of DFT matrix')
axes[1].set_xlabel('Column (time n)')
axes[1].set_ylabel('Row (frequency k)')
fig.colorbar(im1, ax=axes[1])
fig.tight_layout()
plt.show()
# Verify: F8 @ F8.conj().T = 8 * I (the unnormalized matrix scales by N)
ok = np.allclose(F8 @ F8.conj().T, N * np.eye(N))
print(f'PASS - F_8 F_8* = {N}*I: {ok}')
print(f'Max entry of |F_8 F_8* - {N}I| = {np.max(np.abs(F8 @ F8.conj().T - N*np.eye(N))):.2e}')
2.5 Relation to the Continuous Fourier Transform
The DFT approximates the continuous FT of a sampled signal:
Below we compare the DFT of a Gaussian to its known continuous FT.
Code cell 11
# === 2.5 DFT vs Continuous FT ===
N = 256
fs = 100.0  # Hz
dt = 1.0 / fs
t = np.arange(N) * dt
df = fs / N  # DFT bin spacing (kept for reference)
# Gaussian: f(t) = exp(-pi*t^2)
# Continuous FT (xi-convention): F_hat(xi) = exp(-pi*xi^2)
sigma = 0.1
x = np.exp(-np.pi * (t - t[N//2])**2 / sigma**2)  # centered Gaussian
# DFT scaled by dt approximates the CTFT integral; ifftshift removes the
# linear phase caused by centering, fftshift orders bins by signed frequency
X = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(x))) * dt
freqs = np.fft.fftshift(np.fft.fftfreq(N, d=dt))
# Analytical FT of exp(-pi t^2 / sigma^2) is sigma * exp(-pi sigma^2 xi^2)
X_analytical = sigma * np.exp(-np.pi * sigma**2 * freqs**2)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(t, x, color=COLORS['primary'])
axes[0].set_title(f'Gaussian signal (sigma={sigma})')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('x(t)')
mask = np.abs(freqs) < 15  # zoom to the region where the FT is non-negligible
axes[1].plot(freqs[mask], np.abs(X[mask]),
             color=COLORS['primary'], label='|DFT|', lw=2)
axes[1].plot(freqs[mask], X_analytical[mask],
             '--', color=COLORS['error'], label='Analytical FT', lw=2)
axes[1].set_title('DFT magnitude vs analytical FT')
axes[1].set_xlabel('Frequency (Hz)')
axes[1].set_ylabel('|X(f)|')
axes[1].legend()
fig.tight_layout()
plt.show()
err = np.max(np.abs(np.abs(X[mask]) - X_analytical[mask]))
print(f'Max error |DFT| vs analytical FT: {err:.4f}')
print(f'PASS - DFT approximates continuous FT: {err < 0.01}')
3. Properties of the DFT
3.2 Circular Shift Property
If $y[n] = x[(n-m) \bmod N]$, then $Y[k] = \omega_N^{-mk}\, X[k]$. The circular shift multiplies each DFT coefficient by a phase factor.
Code cell 13
# === 3.2 Circular Shift Property ===
N = 32
x = np.zeros(N); x[:8] = 1.0  # rectangular pulse
m = 5  # shift by 5 samples
# Circular shift in time: y[n] = x[(n - m) mod N]
y = np.roll(x, m)
X = np.fft.fft(x)
Y = np.fft.fft(y)
# Expected: Y[k] = omega_N^{-m*k} * X[k]
k = np.arange(N)
omega_N = np.exp(2j * np.pi / N)
Y_predicted = omega_N**(-m * k) * X
ok = np.allclose(Y, Y_predicted, atol=1e-12)
print(f'PASS - Circular shift property: {ok}')
print(f'Max error: {np.max(np.abs(Y - Y_predicted)):.2e}')
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
n = np.arange(N)
axes[0].stem(n, x, linefmt='C0-', markerfmt='o',
             basefmt='k-', label='x[n]')
axes[0].stem(n, y, linefmt='C1-', markerfmt='s',
             basefmt='k-', label=f'y[n]=x[n-{m}]')
axes[0].set_title('Circular shift in time domain')
axes[0].set_xlabel('n'); axes[0].set_ylabel('Amplitude')
axes[0].legend()
# A shift only changes phases, so the magnitude spectra coincide
axes[1].plot(k, np.abs(X), color=COLORS['primary'], label='|X[k]|')
axes[1].plot(k, np.abs(Y), '--', color=COLORS['secondary'], label='|Y[k]|')
axes[1].set_title('Magnitudes unchanged by circular shift')
axes[1].set_xlabel('k'); axes[1].set_ylabel('|DFT|')
axes[1].legend()
fig.tight_layout(); plt.show()
Code cell 14
# === 3.4 Conjugate Symmetry for Real Inputs ===
N = 16
x_real = np.random.randn(N)  # real-valued signal
X = np.fft.fft(x_real)
print('Conjugate symmetry check: X[N-k] == X[k].conj()')
# Report any bin where the symmetry is violated beyond round-off
deviations = [(k, abs(X[N - k] - X[k].conj())) for k in range(1, N // 2)]
for k, diff in deviations:
    if diff > 1e-12:
        print(f' FAIL at k={k}: diff={diff:.2e}')
ok = np.allclose(X[1:N//2], X[N-1:N//2:-1].conj(), atol=1e-12)
print(f'PASS - Conjugate symmetry X[N-k]=X[k]*: {ok}')
# DC and Nyquist are always real
print(f'X[0] (DC): {X[0].real:.4f} + {X[0].imag:.4f}i (should be real)')
print(f'X[N/2] (Nyquist): {X[N//2].real:.4f} + {X[N//2].imag:.4f}i (should be real)')
print(f'Only {N//2+1} independent values out of {N} (rfft would return {N//2+1})')
Code cell 15
# === 3.5 Parseval's Theorem ===
np.random.seed(0)
sizes = (8, 64, 512, 4096)
for N in sizes:
    x = np.random.randn(N)
    X = np.fft.fft(x)
    lhs = np.sum(np.abs(x)**2)        # time-domain energy
    rhs = np.sum(np.abs(X)**2) / N    # frequency-domain energy
    ok = np.isclose(lhs, rhs, rtol=1e-10)
    verdict = "PASS" if ok else "FAIL"
    print(f'N={N:5d}: sum|x|^2={lhs:.4f}, sum|X|^2/N={rhs:.4f}, '
          f'error={abs(lhs-rhs):.2e} {verdict}')
4. The Fast Fourier Transform Algorithm
4.1-4.2 Cooley-Tukey: Recursive DFT
The key identity splits an $N$-point DFT into two $N/2$-point DFTs: $X[k] = E[k] + \omega_N^{-k} O[k]$ and $X[k + N/2] = E[k] - \omega_N^{-k} O[k]$ for $k = 0, \dots, N/2 - 1$,
where $E[k]$ = DFT of even-indexed samples, $O[k]$ = DFT of odd-indexed samples.
Code cell 17
# === 4.1 Naive DFT vs FFT Complexity ===
import time

def naive_dft(x):
    """O(N^2) DFT via an explicit N x N matrix-vector product."""
    N = len(x)
    k = np.arange(N)[:, None]
    n = np.arange(N)[None, :]
    W = np.exp(-2j * np.pi * k * n / N)  # full DFT matrix
    return W @ x

print('N Naive DFT (ms) NumPy FFT (ms) Speedup')
print('-' * 55)
for N in [64, 256, 1024, 4096]:
    x = np.random.randn(N).astype(complex)
    if N <= 1024:
        # Average over 10 runs; the naive DFT is skipped for large N
        t0 = time.perf_counter()
        for _ in range(10): naive_dft(x)
        t_naive = (time.perf_counter() - t0) / 10 * 1000
    else:
        t_naive = float('inf')
    t0 = time.perf_counter()
    for _ in range(100): np.fft.fft(x)
    t_fft = (time.perf_counter() - t0) / 100 * 1000
    speedup = t_naive/t_fft if t_naive < float('inf') else 'N/A'
    speedup_str = f'{speedup:.0f}x' if isinstance(speedup, float) else speedup
    print(f"{N:5d} {t_naive:10.3f} {t_fft:12.4f} {speedup_str:>8}")
Code cell 18
# === 4.2-4.6 Cooley-Tukey FFT from Scratch ===
def fft_recursive(x):
    """Radix-2 Cooley-Tukey FFT (recursive decimation in time)."""
    n_pts = len(x)
    if n_pts == 1:
        return x
    if n_pts % 2 != 0:
        raise ValueError(f'N must be power of 2, got {n_pts}')
    even = fft_recursive(x[0::2])  # DFT of even-indexed samples
    odd = fft_recursive(x[1::2])   # DFT of odd-indexed samples
    # Twiddle factors W^k = exp(-2*pi*i*k/N) for k = 0..N/2-1
    twiddle = np.exp(-2j * np.pi * np.arange(n_pts // 2) / n_pts)
    t_odd = twiddle * odd
    # Butterfly combine: X[k] = E[k] + W^k O[k], X[k+N/2] = E[k] - W^k O[k]
    return np.concatenate([even + t_odd, even - t_odd])

# Test on various sizes
print('Testing FFT from scratch vs NumPy:')
for size in [8, 64, 512, 4096]:
    sig = np.random.randn(size) + 1j * np.random.randn(size)
    err = np.max(np.abs(fft_recursive(sig) - np.fft.fft(sig)))
    verdict = "PASS" if err < 1e-10 else "FAIL"
    print(f' N={size:5d}: max error = {err:.2e} {verdict}')
Code cell 19
# === 4.3 Butterfly Diagram: 8-point DFT ===
# Visualize the 3 stages of the 8-point Cooley-Tukey FFT
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.5, 8.5)
ax.axis('off')
ax.set_title('8-point DIT-FFT Butterfly Diagram', fontsize=14)
N_pts = 8
stages = 3
bit_rev = [0, 4, 2, 6, 1, 5, 3, 7]  # bit-reversed order
# Draw one column of nodes per stage boundary
for stage in range(stages + 1):
    for pos in range(N_pts):
        ax.plot(stage, N_pts - 1 - pos, 'o',
                color=COLORS['primary'], ms=10, zorder=3)
# Input labels: decimation-in-time consumes inputs in bit-reversed order
for pos in range(N_pts):
    ax.text(-0.3, N_pts - 1 - pos, f'x[{bit_rev[pos]}]',
            ha='right', va='center', fontsize=10)
# Output labels: outputs emerge in natural order
for pos in range(N_pts):
    ax.text(stages + 0.3, N_pts - 1 - pos, f'X[{pos}]',
            ha='left', va='center', fontsize=10)
# Stage connections and twiddle factors
butterfly_groups = [
    [(0,4),(1,5),(2,6),(3,7)],  # stage 1: stride=4
    [(0,2),(1,3),(4,6),(5,7)],  # stage 2: stride=2
    [(0,1),(2,3),(4,5),(6,7)],  # stage 3: stride=1
]
# NOTE(review): these are schematic per-stage twiddle labels; confirm the
# exact exponents against a textbook DIT signal-flow graph if they matter.
twiddles = [
    ['W^0','W^0','W^0','W^0'],
    ['W^0','W^1','W^0','W^1'],
    ['W^0','W^1','W^2','W^3'],
]
for s, (butterflies, tws) in enumerate(zip(butterfly_groups, twiddles)):
    for bi, ((top, bot), tw) in enumerate(zip(butterflies, tws)):
        y_top = N_pts - 1 - top
        y_bot = N_pts - 1 - bot
        # Straight line (top)
        ax.plot([s, s+1], [y_top, y_top], '-',
                color=COLORS['primary'], lw=1.5)
        # Crossing line (bottom input to top output)
        ax.plot([s, s+1], [y_bot, y_top], '-',
                color=COLORS['secondary'], lw=1.0, alpha=0.7)
        # Straight line (bottom)
        ax.plot([s, s+1], [y_bot, y_bot], '-',
                color=COLORS['primary'], lw=1.5)
        # Twiddle label at the midpoint of the butterfly
        mid_x = s + 0.5
        mid_y = (y_top + y_bot) / 2
        ax.text(mid_x, mid_y, tw, ha='center', va='center',
                fontsize=8, color=COLORS['error'],
                bbox=dict(boxstyle='round,pad=0.2', fc='white', alpha=0.8))
# Stage labels
for s in range(stages):
    ax.text(s + 0.5, -0.3, f'Stage {s+1}', ha='center', va='top', fontsize=10)
fig.tight_layout()
plt.show()
print('8-point DIT-FFT: 3 stages, 4 butterflies per stage = 12 complex multiplications')
print(f'vs naive DFT: {8**2} = 64 complex multiplications')
print(f'Speedup: {64/12:.1f}x')
Code cell 20
# === 4.5 Bit-Reversal Permutation ===
def bit_reverse(n, num_bits):
    """Return n with its lowest num_bits bits reversed."""
    result = 0
    for _ in range(num_bits):
        result = (result << 1) | (n & 1)  # shift in the current low bit
        n >>= 1
    return result

N = 8
m = int(np.log2(N))  # number of bits
print(f'Bit-reversal permutation for N={N} ({m} bits):')
# BUG FIX: the original used nested double quotes inside a double-quoted
# f-string, which is a SyntaxError before Python 3.12; single quotes inside
# work on all supported versions.
print(f"{'Index':>8} {'Binary':>8} {'Bit-reversed':>14} {'Dec (br)':>10}")
print('-' * 44)
for i in range(N):
    br = bit_reverse(i, m)
    print(f"{i:>8} {bin(i)[2:].zfill(m):>8} {bin(br)[2:].zfill(m):>14} {br:>10}")
# Verify: bit reversal is an involution — applying it twice restores order.
# (Replaces dead code that computed an FFT of reordered input but never
# checked anything.)
perm = [bit_reverse(i, m) for i in range(N)]
assert [perm[p] for p in perm] == list(range(N))
print(f'\nBit-reversal permutation indices: {perm}')
Code cell 21
# === 4.6 O(N log N) Complexity Validation ===
import time

sizes = [2**k for k in range(6, 22)]  # 64 ... ~2M points
times = []
for N in sizes:
    x = np.random.randn(N).astype(np.float64)
    # Warm up (first call may include one-time setup/caching)
    np.fft.rfft(x)
    reps = max(3, 1000 // N) if N <= 10000 else 3  # more reps for small N
    t0 = time.perf_counter()
    for _ in range(reps):
        np.fft.rfft(x)
    times.append((time.perf_counter() - t0) / reps)
# Fit a power law in log-log space: a slope near 1 indicates near-linear
# scaling (the log N factor is invisible at this resolution)
log_N = np.log2(sizes)
log_T = np.log2(times)
coeffs = np.polyfit(log_N, log_T, 1)
print(f'Empirical power law exponent: {coeffs[0]:.3f} (expected ~1.0 for O(N log N))')
fig, ax = plt.subplots(figsize=(10, 5))
ax.loglog(sizes, times, 'o-', color=COLORS['primary'], label='NumPy rfft')
# Theoretical O(N log2 N) reference curve, anchored at sizes[4]
ref = np.array(sizes) * np.log2(sizes) * times[4] / (sizes[4] * np.log2(sizes[4]))
ax.loglog(sizes, ref, '--', color=COLORS['neutral'],
          label=r'$O(N\log_2 N)$ reference')
ax.set_xlabel('FFT size N')
ax.set_ylabel('Time (seconds)')
ax.set_title('NumPy FFT timing: verifying O(N log N) complexity')
ax.legend()
fig.tight_layout()
plt.show()
5. Spectral Leakage and Windowing
5.1 The Leakage Problem
When a sinusoid's frequency does not align exactly with a DFT bin, energy 'leaks' from the true frequency into all other bins. This occurs because the DFT implicitly assumes the signal is periodic with period $N$. Truncation is equivalent to multiplication by a rectangular window, whose DFT (the Dirichlet kernel) has large sidelobes.
Code cell 23
# === 5.1 Spectral Leakage Demo ===
N = 64
n = np.arange(N)
# Integer frequency (bin-aligned) vs non-integer frequency
f_int = 5.0   # exactly 5 cycles in N=64 -> no leakage
f_frac = 5.5  # 5.5 cycles -> leakage
x_int = np.cos(2 * np.pi * f_int * n / N)
x_frac = np.cos(2 * np.pi * f_frac * n / N)
X_int = np.fft.fft(x_int)
X_frac = np.fft.fft(x_frac)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
k = np.arange(N)
# Scale by 2/N so a unit-amplitude cosine reads as amplitude 1.0
axes[0].stem(k[:N//2+1], np.abs(X_int[:N//2+1]) / N * 2,
             linefmt='C0-', markerfmt='o', basefmt='k-')
axes[0].set_title(f'Bin-aligned: f={f_int} (no leakage)')
axes[0].set_xlabel('Frequency bin k')
axes[0].set_ylabel('Amplitude')
axes[0].axvline(f_int, color=COLORS['error'], ls='--', label='True freq')
axes[0].legend()
axes[1].stem(k[:N//2+1], np.abs(X_frac[:N//2+1]) / N * 2,
             linefmt='C1-', markerfmt='o', basefmt='k-')
axes[1].set_title(f'Non-bin-aligned: f={f_frac} (leakage visible)')
axes[1].set_xlabel('Frequency bin k')
axes[1].set_ylabel('Amplitude')
axes[1].axvline(f_frac, color=COLORS['error'], ls='--', label='True freq')
axes[1].legend()
fig.tight_layout(); plt.show()
# Quantify leakage: fraction of total energy outside the main-lobe bins.
# NOTE(review): only the positive-frequency bins 4-6 are counted as "main
# lobe"; the mirrored bins N-6..N-4 are therefore counted as leakage.
main_lobe_bins = {4, 5, 6}  # bins near f_frac=5.5
total_power = np.sum(np.abs(X_frac)**2)
main_power = sum(np.abs(X_frac[b])**2 for b in main_lobe_bins)
leakage_frac = 1 - main_power / total_power
print(f'Leakage fraction (energy outside bins 4-6): {leakage_frac:.1%}')
Code cell 24
# === 5.2 Window Functions ===
N = 64
n = np.arange(N)

def hann_window(N):
    """Symmetric Hann window of length N."""
    idx = np.arange(N)
    return 0.5 * (1 - np.cos(2*np.pi*idx/(N-1)))

def hamming_window(N):
    """Symmetric Hamming window of length N."""
    idx = np.arange(N)
    return 0.54 - 0.46*np.cos(2*np.pi*idx/(N-1))

def blackman_window(N):
    """Symmetric Blackman window of length N (0.42/0.5/0.08 coefficients)."""
    idx = np.arange(N)
    return 0.42 - 0.5*np.cos(2*np.pi*idx/(N-1)) + 0.08*np.cos(4*np.pi*idx/(N-1))

def kaiser_window(N, beta=8.6):
    """Symmetric Kaiser window of length N with shape parameter beta."""
    idx = np.arange(N)
    arg = beta * np.sqrt(1 - (2*idx/(N-1) - 1)**2)
    return bessel_i0(arg) / bessel_i0(beta)
windows = {
    'Rectangular': np.ones(N),
    'Hann': hann_window(N),
    'Hamming': hamming_window(N),
    'Blackman': blackman_window(N),
    'Kaiser(8.6)': kaiser_window(N, 8.6),
}
colors_list = [COLORS['neutral'], COLORS['primary'], COLORS['secondary'],
               COLORS['tertiary'], COLORS['highlight']]
# Compute frequency responses (zero-padded for smooth curve)
N_pad = 1024
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
print(f"{'Window':<15} {'Coherent gain':>15} {'Peak sidelobe (dB)':>20}")
print('-' * 52)
for (name, w), c in zip(windows.items(), colors_list):
    # Time domain
    axes[0].plot(np.arange(N), w, color=c, label=name, alpha=0.85)
    # Frequency domain (zero-padded, dB, centered; +1e-15 avoids log(0))
    W = np.fft.fft(w, n=N_pad)
    W_db = 20*np.log10(np.abs(np.fft.fftshift(W)) / np.max(np.abs(W)) + 1e-15)
    freqs_norm = np.fft.fftshift(np.fft.fftfreq(N_pad))
    axes[1].plot(freqs_norm, W_db, color=c, label=name, alpha=0.85)
    # Metrics
    cg = np.sum(w) / N  # coherent gain = DC response / N
    main_peak = np.max(np.abs(W))
    # Find max sidelobe (outside main lobe, which is first 4 bins)
    # NOTE(review): excluding only 4 padded bins per side is a rough
    # heuristic; wider windows (Blackman, Kaiser) have main lobes wider
    # than this, so the printed sidelobe level is approximate.
    W_side = np.abs(W).copy()
    W_side[:4] = 0; W_side[-4:] = 0
    peak_sl = 20*np.log10(np.max(W_side) / main_peak + 1e-15)
    print(f"{name:<15} {cg:>15.4f} {peak_sl:>20.1f}")
axes[0].set_title('Window functions (time domain)')
axes[0].set_xlabel('Sample n'); axes[0].set_ylabel('w[n]')
axes[0].legend(fontsize=9)
axes[1].set_xlim(-0.4, 0.4); axes[1].set_ylim(-100, 5)
axes[1].set_title('Window spectral responses (dB)')
axes[1].set_xlabel('Normalized frequency'); axes[1].set_ylabel('dB')
axes[1].axhline(-60, color='k', ls=':', lw=0.8)
axes[1].legend(fontsize=9)
fig.tight_layout(); plt.show()
7. Zero-Padding and Frequency Resolution
Zero-padding appends zeros to the signal before computing the DFT. It interpolates the spectrum — producing more DFT bins — but does not improve the fundamental frequency resolution $\Delta f = f_s/N_{\text{orig}}$.
Key insight: Resolution is determined by the observation window $T = N_{\text{orig}}/f_s$, via $\Delta f = 1/T$. More zeros = finer interpolation; longer observation = true resolution improvement.
Code cell 26
# Zero-padding: interpolation vs resolution
fs = 100.0  # Hz
N_orig = 32
t = np.arange(N_orig) / fs
f_true = 11.3  # Hz (deliberately non-bin-aligned)
x = np.cos(2 * np.pi * f_true * t)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('Zero-padding interpolates the spectrum but does not improve resolution',
             fontsize=14)
for ax, N_pad in zip(axes, [32, 128, 512]):
    # Zero-pad the 32-sample signal to N_pad points before the DFT
    xp = np.zeros(N_pad)
    xp[:N_orig] = x
    X = np.fft.rfft(xp)
    freqs = np.fft.rfftfreq(N_pad, 1/fs)
    ax.plot(freqs, np.abs(X), color=COLORS['primary'], linewidth=1.5)
    ax.axvline(f_true, color=COLORS['error'], linestyle='--', linewidth=1.2,
               label=f'True freq {f_true} Hz')
    ax.set_title(f'N_pad = {N_pad} (N_orig = {N_orig})')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('|X[k]|')
    ax.legend(fontsize=9)
    ax.set_xlim(0, 30)
fig.tight_layout()
plt.show()
print(f'Resolution with N_orig={N_orig}: Δf = {fs/N_orig:.2f} Hz (unchanged by zero-padding)')
8. Sampling, Nyquist, and Aliasing
Nyquist-Shannon Sampling Theorem: A signal bandlimited to $B$ Hz can be perfectly reconstructed from samples taken at $f_s > 2B$.
Aliasing occurs when $f_s < 2B$: high-frequency components fold into lower frequencies. The alias of a sinusoid at frequency $f$ sampled at $f_s$ appears at: $f_{\text{alias}} = f - \operatorname{round}(f/f_s)\, f_s$.
This follows from the Poisson summation formula:
Code cell 28
# Aliasing demonstration
fs_high = 1000.0  # adequate sampling rate
# BUG FIX: the original used fs_low = 30.0, but the Nyquist rate for an
# 11 Hz sinusoid is 2*11 = 22 Hz, so sampling at 30 Hz does NOT alias —
# the computed "alias" landed at 11 Hz, the true frequency, and the demo
# showed nothing. At 15 Hz the tone folds to 11 - 15 = -4 Hz.
fs_low = 15.0  # below the Nyquist rate (22 Hz) for the 11 Hz component
f_signal = 11.0  # Hz
t_cont = np.linspace(0, 1, 5000, endpoint=False)
x_cont = np.sin(2 * np.pi * f_signal * t_cont)
t_high = np.arange(0, 1, 1/fs_high)
t_low = np.arange(0, 1, 1/fs_low)
x_high = np.sin(2 * np.pi * f_signal * t_high)
x_low = np.sin(2 * np.pi * f_signal * t_low)
# Alias frequency folds into (-fs/2, fs/2]
f_alias = f_signal - round(f_signal / fs_low) * fs_low
print(f'Signal: {f_signal} Hz | fs_low: {fs_low} Hz | Alias at: {f_alias} Hz')
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Aliasing: high-frequency signal appears at wrong frequency', fontsize=14)
# Time-domain
ax = axes[0]
n_show = 2500  # 0.5 s: enough to see a couple of cycles of the 4 Hz alias
ax.plot(t_cont[:n_show], x_cont[:n_show], color=COLORS['neutral'], alpha=0.5, label='Continuous 11 Hz')
ax.stem(t_low[:8], x_low[:8], linefmt='C3--',
        markerfmt='o', basefmt=' ', label=f'Samples at fs={fs_low} Hz')
# Alias waveform: use the SIGNED alias frequency so the curve passes
# through the sample points (sin(2*pi*11*t_k) == sin(-2*pi*4*t_k) here)
x_alias = np.sin(2 * np.pi * f_alias * t_cont)
ax.plot(t_cont[:n_show], x_alias[:n_show], color=COLORS['error'], linewidth=2,
        label=f'Alias: {abs(f_alias):.0f} Hz')
ax.set_title('Time domain')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)
ax.set_xlim(0, t_cont[n_show])
# Frequency domain — DFT of the undersampled signal
ax = axes[1]
X_low = np.fft.rfft(x_low)
freqs_low = np.fft.rfftfreq(len(x_low), 1/fs_low)
ax.stem(freqs_low, np.abs(X_low), linefmt='C0-',
        markerfmt='o', basefmt=' ')
ax.axvline(f_signal, color=COLORS['error'], linestyle='--', label=f'True {f_signal} Hz')
ax.axvline(abs(f_alias), color=COLORS['highlight'], linestyle=':',
           label=f'Alias {abs(f_alias):.0f} Hz')
ax.set_title(f'DFT at fs={fs_low} Hz (Nyquist = {fs_low/2} Hz)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('|X[k]|')
ax.legend(fontsize=9)
fig.tight_layout()
plt.show()
9. Short-Time Fourier Transform (STFT)
The STFT slides a window along the signal and computes the DFT at each position: $X[l, k] = \sum_{n=0}^{N-1} x[n + lH]\, w[n]\, e^{-2\pi i kn/N}$,
where $H$ is the hop size (stride) and $N$ is the window length.
| Parameter | Effect |
|---|---|
| Larger window $N$ | Better frequency resolution, worse time resolution |
| Smaller window $N$ | Better time resolution, worse frequency resolution |
| Hop size $H$ | Controls overlap; $H \le N$ gives overlap-add reconstruction |
COLA condition (Constant Overlap-Add): $\sum_l w[n - lH] = C$ for all $n$ → ensures perfect reconstruction via overlap-add synthesis.
Code cell 30
# STFT from scratch
def stft(x, N=256, H=128, window='hann'):
    """Short-time Fourier transform.

    Returns an array of shape (N//2+1, n_frames): rfft of each windowed
    frame, frames taken every H samples.
    """
    if window == 'hann':
        w = np.hanning(N)
    elif window == 'hamming':
        w = np.hamming(N)
    else:
        w = np.ones(N)  # rectangular
    n_frames = 1 + (len(x) - N) // H
    spectra = [np.fft.rfft(w * x[start:start + N])
               for start in range(0, n_frames * H, H)]
    return np.array(spectra).T
# Test signal: chirp (linearly increasing frequency)
fs = 8000
duration = 1.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
# Phase phi(t) = 200*t + 1500*t^2 -> instantaneous frequency
# phi'(t) = 200 + 3000*t, i.e. 200 Hz at t=0 rising to 3200 Hz at t=1.
x_chirp = np.sin(2 * np.pi * (200 * t + 1500 * t**2))
# Compute STFT
N_win = 256
H_hop = 64
S = stft(x_chirp, N=N_win, H=H_hop, window='hann')
# Plot spectrogram
fig, ax = plt.subplots(figsize=(12, 5))
times = np.arange(S.shape[1]) * H_hop / fs
freqs = np.fft.rfftfreq(N_win, 1/fs)
im = ax.pcolormesh(times, freqs, 20*np.log10(np.abs(S) + 1e-10),
                   cmap='viridis', shading='auto')
fig.colorbar(im, ax=ax, label='Power (dB)')
# BUG FIX: the title previously said "200 → 1700 Hz", but the instantaneous
# frequency is the phase DERIVATIVE 200 + 3000*t, which ends at 3200 Hz.
ax.set_title('STFT spectrogram of a linear chirp (200 → 3200 Hz)')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
fig.tight_layout()
plt.show()
Code cell 31
# Time-frequency trade-off: three window sizes on the same chirp
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('STFT: time-frequency resolution trade-off', fontsize=14)
# Hop is kept at N/4 (75% overlap) for each window size
for ax, (N_w, H_w) in zip(axes, [(64, 16), (256, 64), (1024, 256)]):
    S_w = stft(x_chirp, N=N_w, H=H_w, window='hann')
    t_ax = np.arange(S_w.shape[1]) * H_w / fs
    f_ax = np.fft.rfftfreq(N_w, 1/fs)
    ax.pcolormesh(t_ax, f_ax, 20*np.log10(np.abs(S_w) + 1e-10),
                  cmap='viridis', shading='auto')
    ax.set_title(f'Window N={N_w}')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_ylim(0, 4000)
fig.tight_layout()
plt.show()
print('Short window → good time res, blurry freq res')
print('Long window → good freq res, blurry time res')
Code cell 32
# Verify COLA condition for Hann window with 50% overlap
N_win = 256
H_hop = N_win // 2  # 50% overlap
# BUG FIX: np.hanning returns the SYMMETRIC Hann window (denominator N-1),
# which only approximately satisfies COLA at 50% overlap (~1% ripple), so
# the np.allclose check below printed False. The PERIODIC Hann window
# (denominator N) sums to exactly 1.0 at hop N/2.
w = 0.5 * (1 - np.cos(2 * np.pi * np.arange(N_win) / N_win))  # periodic Hann
signal_len = 2048
cola_sum = np.zeros(signal_len)
for l in range((signal_len - N_win) // H_hop + 1):
    start = l * H_hop
    cola_sum[start:start + N_win] += w
interior = cola_sum[N_win:-N_win]  # exclude edges with partial overlap
print(f'COLA sum: min={interior.min():.6f}, max={interior.max():.6f}')
print(f'Constant overlap-add satisfied: {np.allclose(interior, interior[0])}')
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(cola_sum, color=COLORS['primary'], linewidth=1.5)
ax.axhline(interior.mean(), color=COLORS['error'], linestyle='--',
           label=f'COLA constant = {interior.mean():.3f}')
ax.set_title('COLA (Constant Overlap-Add) verification: Hann window, 50% overlap')
ax.set_xlabel('Sample index')
ax.set_ylabel('Sum of window values')
ax.legend()
fig.tight_layout()
plt.show()
10. Case Study: Whisper's Mel Spectrogram Pipeline
OpenAI's Whisper (Radford et al. 2022) pre-processes audio with a fixed pipeline:
| Step | Value | Rationale |
|---|---|---|
| Sample rate | 16 000 Hz | Telephone-quality speech |
| Window length | 400 samples = 25 ms | Stationary speech frame |
| Hop size | 160 samples = 10 ms | Standard 10 ms frame shift |
| FFT size | 512 points | Zero-pad 400→512 (next power of 2) |
| Mel bins | 80 | Perceptual resolution |
| Frequency range | 0–8 000 Hz | Below Nyquist (8 000 Hz) |
| Compression | , clipped to | Dynamic range compression |
The mel scale maps linear frequency $f$ (Hz) to mel $m = 2595\,\log_{10}(1 + f/700)$, mimicking the logarithmic frequency perception of the human cochlea.
Code cell 34
# Whisper mel filterbank from scratch
def hz_to_mel(f):
    """Linear frequency (Hz) -> mel scale."""
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    """Mel scale -> linear frequency (Hz)."""
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(n_mels=80, n_fft=512, fs=16000, f_min=0, f_max=8000):
    """Triangular mel filterbank matrix of shape (n_mels, n_fft//2 + 1)."""
    n_freq = n_fft // 2 + 1
    # Filter edges: equally spaced in mel, mapped back to Hz, then to FFT bins
    mel_edges = np.linspace(hz_to_mel(f_min), hz_to_mel(f_max), n_mels + 2)
    bins = np.floor((n_fft + 1) * mel_to_hz(mel_edges) / fs).astype(int)
    fbank = np.zeros((n_mels, n_freq))
    # Each filter i rises over [lo, mid) and falls over [mid, hi)
    for i, (lo, mid, hi) in enumerate(zip(bins[:-2], bins[1:-1], bins[2:])):
        rise = np.arange(lo, mid)
        fbank[i, rise] = (rise - lo) / max(mid - lo, 1)
        fall = np.arange(mid, hi)
        fbank[i, fall] = (hi - fall) / max(hi - mid, 1)
    return fbank
fb = mel_filterbank()
freqs = np.fft.rfftfreq(512, 1/16000)  # FFT bin center frequencies (Hz)
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
ax = axes[0]
# Plot every 5th filter to keep the figure readable
for i in range(0, 80, 5):
    ax.plot(freqs, fb[i], alpha=0.7, linewidth=1.2)
ax.set_title('80 mel filterbank triangular filters (0–8000 Hz)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Filter weight')
ax = axes[1]
im = ax.imshow(fb, aspect='auto', origin='lower',
               extent=[freqs[0], freqs[-1], 0, 80], cmap='viridis')
fig.colorbar(im, ax=ax, label='Filter weight')
ax.set_title('Mel filterbank matrix (80 × 257)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Mel bin')
fig.tight_layout()
plt.show()
Code cell 35
# Simulate the full Whisper preprocessing pipeline
np.random.seed(42)
fs_whisper = 16000
# Synthetic speech-like signal: harmonics of a 120 Hz fundamental
t_w = np.linspace(0, 3.0, 3 * fs_whisper, endpoint=False)
x_speech = np.zeros_like(t_w)
for k_harm, (f0, amp) in enumerate([(120, 0.6), (240, 0.3), (360, 0.15),
                                    (480, 0.07), (960, 0.03)]):
    x_speech += amp * np.sin(2 * np.pi * f0 * t_w)
# Amplitude envelope (vowel-like)
env = np.exp(-0.5 * ((t_w - 1.5) / 0.6)**2)
x_speech *= env
# Step 1: STFT (400-sample Hann, hop 160; each frame is zero-padded from
# 400 to the 512-point FFT size before the rfft)
N_w, H_w, N_fft = 400, 160, 512
w_hann = np.hanning(N_w)
S_whisper = []
for l in range((len(x_speech) - N_w) // H_w + 1):
    frame = np.zeros(N_fft)
    frame[:N_w] = x_speech[l*H_w : l*H_w + N_w] * w_hann
    S_whisper.append(np.fft.rfft(frame))
S_whisper = np.array(S_whisper).T  # (257, n_frames)
# Step 2: Power spectrum
power = np.abs(S_whisper)**2
# Step 3: Apply mel filterbank (80 triangular filters)
fb_w = mel_filterbank(n_mels=80, n_fft=N_fft, fs=fs_whisper)
mel_spec = fb_w @ power  # (80, n_frames)
# Step 4: Log compression + clipping to 8 decades below the peak
log_mel = np.log10(np.maximum(mel_spec, 1e-10))
log_mel = np.maximum(log_mel, log_mel.max() - 8.0)
# Plot
frame_times = np.arange(log_mel.shape[1]) * H_w / fs_whisper
fig, axes = plt.subplots(2, 1, figsize=(13, 8))
fig.suptitle('Whisper mel spectrogram pipeline', fontsize=14)
im1 = axes[0].pcolormesh(frame_times,
                         np.fft.rfftfreq(N_fft, 1/fs_whisper),
                         20*np.log10(np.abs(S_whisper) + 1e-10),
                         cmap='viridis', shading='auto')
fig.colorbar(im1, ax=axes[0], label='dB')
axes[0].set_title('STFT spectrogram (linear frequency)')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Frequency (Hz)')
im2 = axes[1].pcolormesh(frame_times, np.arange(80),
                         log_mel, cmap='viridis', shading='auto')
fig.colorbar(im2, ax=axes[1], label='Log power')
axes[1].set_title('Log-mel spectrogram (80 mel bins) — Whisper input')
axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('Mel bin')
fig.tight_layout()
plt.show()
print(f'Log-mel shape: {log_mel.shape} (80 mel bins × {log_mel.shape[1]} frames)')
11. Fourier Neural Operator (FNO)
The Fourier Neural Operator (Li et al. 2021) solves parametric PDEs by learning kernels in the frequency domain, with $O(N \log N)$ cost per layer from the FFT.
Spectral convolution layer: $(\mathcal{K} v)(x) = \mathcal{F}^{-1}\!\big(R \cdot \mathcal{F} v\big)(x)$,
where $R$ is a learnable weight tensor truncated to the lowest $K$ Fourier modes.
Key insight: Truncating to $K$ modes acts as a low-pass filter that captures the smooth, long-range structure of PDE solutions while discarding fine noise.
Code cell 37
# FNO spectral convolution layer (1D, NumPy/pure-Python implementation)
class SpectralConv1d:
    """1D spectral convolution: rfft -> keep lowest n_modes frequencies ->
    learnable complex channel mixing -> irfft (Li et al. 2021, FNO)."""

    def __init__(self, in_channels, out_channels, n_modes, rng=None):
        if rng is None:
            rng = np.random.default_rng(42)
        scale = 1.0 / (in_channels * out_channels)
        # Complex weights R of shape (in_ch, out_ch, n_modes)
        self.weights_real = rng.uniform(-scale, scale,
                                        (in_channels, out_channels, n_modes))
        self.weights_imag = rng.uniform(-scale, scale,
                                        (in_channels, out_channels, n_modes))
        self.weights = self.weights_real + 1j * self.weights_imag
        self.n_modes = n_modes

    def forward(self, x):
        """Apply the layer: x (batch, in_channels, N) -> (batch, out_channels, N)."""
        batch, in_ch, N = x.shape
        out_ch = self.weights.shape[1]
        x_ft = np.fft.rfft(x, axis=-1)  # (batch, in_ch, N//2+1)
        # Frequencies above n_modes stay zero (low-pass truncation)
        out_ft = np.zeros((batch, out_ch, N//2+1), dtype=complex)
        # IMPROVED: one batched einsum instead of a Python loop over the
        # batch dimension — identical result, fewer interpreter round-trips
        out_ft[:, :, :self.n_modes] = np.einsum(
            'bik,iok->bok', x_ft[:, :, :self.n_modes], self.weights
        )
        # IFFT back to spatial domain
        return np.fft.irfft(out_ft, n=N, axis=-1)  # (batch, out_ch, N)
# Demo: 1D FNO spectral layer
N_grid = 64
K_modes = 12
np.random.seed(0)
# Random batch: 4 samples, 3 input channels, 64 grid points.
x_in = np.random.randn(4, 3, N_grid) # batch=4, in_ch=3
layer = SpectralConv1d(in_channels=3, out_channels=8, n_modes=K_modes)
x_out = layer.forward(x_in)
print(f'Input: {x_in.shape}')
print(f'Output: {x_out.shape}')
print(f'Kept {K_modes}/{N_grid//2+1} Fourier modes ({100*K_modes/(N_grid//2+1):.0f}% of spectrum)')
# Visualize: input vs output for one channel
# Outputs look smooth because only the lowest K_modes frequencies survive
# the spectral multiplication — the layer acts as a learned low-pass map.
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax = axes[0]
for i in range(4):
ax.plot(x_in[i, 0, :], alpha=0.7, linewidth=1.5,
label=f'batch {i}')
ax.set_title(f'Input (in_ch=0)')
ax.set_xlabel('Grid point $x$')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)
ax = axes[1]
for i in range(4):
ax.plot(x_out[i, 0, :], alpha=0.7, linewidth=1.5,
label=f'batch {i}')
ax.set_title(f'Output after spectral conv (out_ch=0, K={K_modes} modes)')
ax.set_xlabel('Grid point $x$')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)
fig.tight_layout()
plt.show()
Code cell 38
# Effect of mode truncation: how many modes do you need?
# Test signal: smooth + noisy
N_grid = 128
x_grid = np.linspace(0, 2*np.pi, N_grid, endpoint=False)
# Ground truth: two low-frequency harmonics (modes 1 and 3).
u_smooth = np.sin(x_grid) + 0.5 * np.cos(3 * x_grid)
u_noisy = u_smooth + 0.3 * np.random.randn(N_grid)
# One-sided spectrum of the noisy signal; bin index equals integer mode number.
U = np.fft.rfft(u_noisy)
freqs_grid = np.arange(len(U))
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('FNO mode truncation: reconstruction quality vs K modes', fontsize=13)
# Left panel: magnitude spectrum with the candidate truncation points marked.
ax = axes[0]
ax.plot(np.abs(U), color=COLORS['primary'], linewidth=2, label='|U[k]|')
for K_cut, col in [(4, COLORS['error']), (12, COLORS['tertiary']), (24, COLORS['secondary'])]:
ax.axvline(K_cut, color=col, linestyle='--', alpha=0.8, label=f'K={K_cut}')
ax.set_title('DFT magnitude — where to truncate?')
ax.set_xlabel('Frequency mode $k$')
ax.set_ylabel('|U[k]|')
ax.legend(fontsize=9)
ax.set_yscale('log')
# Right panel: inverse transforms after zeroing all modes >= K.
# MSE is measured against the CLEAN signal, so moderate K also denoises.
ax = axes[1]
ax.plot(x_grid, u_smooth, color=COLORS['neutral'], linewidth=2, linestyle='--',
label='True smooth signal')
ax.plot(x_grid, u_noisy, color=COLORS['neutral'], alpha=0.3, linewidth=1,
label='Noisy input')
for K_cut, col in [(4, COLORS['error']), (12, COLORS['tertiary']), (24, COLORS['secondary'])]:
U_trunc = np.zeros_like(U)
U_trunc[:K_cut] = U[:K_cut]
u_rec = np.fft.irfft(U_trunc, n=N_grid)
err = np.mean((u_rec - u_smooth)**2)
ax.plot(x_grid, u_rec, color=col, linewidth=1.8,
label=f'K={K_cut} (MSE={err:.4f})')
ax.set_title('Reconstructions from K truncated modes')
ax.set_xlabel('$x$')
ax.set_ylabel('$u(x)$')
ax.legend(fontsize=9)
fig.tight_layout()
plt.show()
12. Monarch Matrices and Butterfly Factorizations
Monarch matrices (Dao et al. 2022) generalize butterfly matrices to a 2-factor product $M = P_1 B_1 P_2 B_2$, where the $B_i$ are block-diagonal, each with $\sqrt{N}$ blocks of size $\sqrt{N} \times \sqrt{N}$, and the $P_i$ are fixed permutations.
Parameter count: $O(N\sqrt{N})$ vs $O(N^2)$ for dense.
Key property: The DFT matrix is a butterfly matrix — it factors as $F_N = B_{\log_2 N} \cdots B_2 B_1 P$ where each $B_i$ is block-diagonal with two non-zeros per row. This is precisely the Cooley-Tukey FFT factorization!
FlashFFTConv (Dao et al. 2023) uses Monarch structure to implement long convolutions (very long sequence lengths $N$) via state space models with hardware-efficient operations on GPU.
Code cell 40
# Visualize butterfly structure of the DFT matrix
N_demo = 8
# NOTE(review): omega uses e^{+2*pi*i/N} (the inverse-DFT sign convention);
# only |F8| is plotted below, so the sign does not affect the figure.
omega = np.exp(2j * np.pi / N_demo)
# Unitary-normalized 8x8 DFT-style matrix, built entry by entry.
F8 = np.array([[omega**(k*n) for n in range(N_demo)]
for k in range(N_demo)]) / np.sqrt(N_demo)
# The DFT matrix sparsity pattern (thresholded |entry| > 0.01)
# For comparison: the butterfly factors B1, B2, B3
def butterfly_factor(N, stage):
    """Return the (N, N) matrix of one Cooley-Tukey butterfly stage.

    Stage ``s`` (0-based, s = 0 .. log2(N)-1) partitions the rows into
    contiguous blocks of size 2**(s+1) and, within each block, couples
    index pairs that are 2**s apart with twiddle factors
    exp(-2j*pi*j / 2**(s+1)). Each row holds exactly two non-zeros.
    The result is divided by 2 to normalize.
    """
    span = 2 ** (stage + 1)   # block size at this stage
    radius = span >> 1        # distance between paired indices
    mat = np.zeros((N, N), dtype=complex)
    for base in range(0, N, span):
        for offset in range(radius):
            top = base + offset
            bot = top + radius
            w = np.exp(-2j * np.pi * offset / span)
            # Butterfly: top row forms (u + w*v), bottom row (u - w*v).
            mat[top, top] = 1.0
            mat[top, bot] = w
            mat[bot, top] = 1.0
            mat[bot, bot] = -w
    return mat / 2  # normalize
# log2(N) sparse butterfly stages together reconstruct the full DFT.
log2_N = int(np.log2(N_demo))
fig, axes = plt.subplots(1, log2_N + 1, figsize=(16, 4))
fig.suptitle(f'Butterfly factorization of F_{N_demo}: F = B0 * B1 * B2', fontsize=13)
# DFT matrix
# Panel 0: the dense DFT matrix — every entry has magnitude 1/sqrt(N).
axes[0].imshow(np.abs(F8), cmap='viridis', vmin=0)
axes[0].set_title(f'|F_{N_demo}| (full DFT)')
axes[0].set_xlabel('Column')
axes[0].set_ylabel('Row')
# Butterfly stages
# Panels 1..log2(N): each sparse stage has exactly 2 non-zeros per row.
for stage in range(log2_N):
B = butterfly_factor(N_demo, stage)
axes[stage + 1].imshow(np.abs(B), cmap='viridis', vmin=0)
axes[stage + 1].set_title(f'|B{stage}| (stage {stage})')
axes[stage + 1].set_xlabel('Column')
axes[stage + 1].set_ylabel('Row')
fig.tight_layout()
plt.show()
print(f'Each butterfly stage has O(N) non-zeros; log2(N)={log2_N} stages total')
# NOTE(review): each stage actually stores 2 non-zeros per row (2*N total),
# so the exact count is 2*N*log2(N); the printout reports N*log2(N).
print(f'Total non-zeros in factored form: {N_demo * log2_N} vs {N_demo**2} in full matrix')
13. Looking Ahead: DFT → Convolution Theorem
We have established the DFT as a unitary transform with FFT computation. The next section (§04 Convolution Theorem) builds on this foundation:
| DFT result (§03) | What §04 does with it |
|---|---|
| DFT is a ring isomorphism on $\mathbb{C}^N$ | Circular convolution $\leftrightarrow$ pointwise multiplication |
| FFT reduces DFT to $O(N \log N)$ | Reduces convolution to $O(N \log N)$ |
| Spectral leakage, windowing | Applied in filter design |
| STFT frames | Overlap-add = convolution in disguise |
| FNO spectral layer | Full treatment in §04 (CNNs, S4, Mamba) |
Core equation to carry forward: $\mathrm{DFT}(x \circledast y) = \mathrm{DFT}(x) \odot \mathrm{DFT}(y)$.
The DFT turns the $O(N^2)$ circular convolution sum into $O(N \log N)$ operations.
Code cell 42
# Summary: key results and their parameter dependencies
# (Print-only recap cell; the numbers restate results derived earlier in
# the chapter. Sidelobe levels are the classic peak-sidelobe figures in dB.)
print('=== Chapter §03 Summary ===')
print()
print('DFT definition: X[k] = sum_{n=0}^{N-1} x[n] * exp(-2pi*i*n*k/N)')
print('IDFT: x[n] = (1/N) sum_{k=0}^{N-1} X[k] * exp(2pi*i*n*k/N)')
print()
print('FFT complexity: O(N log N) vs O(N^2) naive')
print('Frequency resolution: Δf = fs / N (CANNOT be improved by zero-padding)')
print('Nyquist criterion: fs >= 2 * f_max (or aliases occur)')
print()
print('Window sidelobe levels:')
windows = [('Rectangular', -13), ('Hann', -31.5), ('Hamming', -42.7),
('Blackman', -58), ('Kaiser beta=8.6', -69)]
for name, sl in windows:
print(f' {name:<20} {sl:>6} dB')
print()
print('Whisper pipeline: 16kHz, 400-sample Hann, hop 160, FFT 512, 80 mel bins')
print('FNO: keeps K lowest modes (K << N//2+1)')
print('Monarch: O(N sqrt(N)) params vs O(N^2) dense')
14. Two-Dimensional DFT
The 2D DFT extends naturally to images and 2D fields:
$$X[k_1, k_2] = \sum_{n_1=0}^{N_1-1} \sum_{n_2=0}^{N_2-1} x[n_1, n_2]\, e^{-2\pi i \left(k_1 n_1 / N_1 + k_2 n_2 / N_2\right)}$$
It decomposes a 2D signal into sinusoidal spatial frequency components.
Low frequencies concentrate near the center (after fftshift); high frequencies
lie at the edges.
Applications: Image compression (JPEG uses DCT, a DFT variant), CNNs in the frequency domain, MRI reconstruction (k-space = 2D Fourier space of tissue).
Code cell 44
# 2D DFT: image frequency analysis
from matplotlib.colors import LogNorm
N1, N2 = 128, 128
# Index grids so axis 0 is x1 (rows) and axis 1 is x2 (columns).
x1, x2 = np.meshgrid(np.arange(N1), np.arange(N2), indexing='ij')
# Test image: a separable 2D wave (4 cycles along x1, 6 along x2), a pure
# x1-direction wave (12 cycles), plus white noise.
img = (np.sin(2*np.pi*4*x1/N1) * np.cos(2*np.pi*6*x2/N2)
+ 0.5 * np.sin(2*np.pi*12*x1/N1)
+ 0.3 * np.random.randn(N1, N2))
IMG = np.fft.fft2(img)
# fftshift moves the DC component to the center of the display.
IMG_shifted = np.fft.fftshift(IMG)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('2D DFT: image to frequency domain', fontsize=14)
axes[0].imshow(img, cmap='viridis', aspect='auto')
axes[0].set_title('Input image')
axes[0].set_xlabel('$x_2$'); axes[0].set_ylabel('$x_1$')
axes[1].imshow(np.abs(IMG_shifted), cmap='viridis', aspect='auto',
norm=LogNorm(vmin=1))
axes[1].set_title('|2D DFT| (log scale, DC at center)')
axes[1].set_xlabel('$k_2$ (shifted)'); axes[1].set_ylabel('$k_1$ (shifted)')
# Low-pass filter: keep central r modes
# NOTE(review): the mask keeps a (2r)x(2r) SQUARE around DC, not a
# radius-r disc — "r modes" here means the half-width of that square.
r = 10
mask = np.zeros((N1, N2), dtype=bool)
cx, cy = N1//2, N2//2
mask[cx-r:cx+r, cy-r:cy+r] = True
IMG_lp = IMG.copy()
# The mask was built in shifted (DC-centered) coordinates, so un-shift it
# before applying it to the raw (DC-at-corner) spectrum.
IMG_lp[~np.fft.ifftshift(mask)] = 0
img_rec = np.real(np.fft.ifft2(IMG_lp))
axes[2].imshow(img_rec, cmap='viridis', aspect='auto')
axes[2].set_title(f'Low-pass filtered (r={r} modes)')
axes[2].set_xlabel('$x_2$'); axes[2].set_ylabel('$x_1$')
fig.tight_layout()
plt.show()
print(f'Kept {mask.sum()} of {N1*N2} Fourier modes ({100*mask.sum()/(N1*N2):.1f}%)')
15. Discrete Cosine Transform (DCT)
The DCT-II is defined as:
$$X[k] = \sum_{n=0}^{N-1} x[n] \cos\!\left[\frac{\pi}{N}\left(n + \tfrac{1}{2}\right)k\right], \qquad k = 0, \dots, N-1.$$
Why DCT over DFT for real signals?
| Property | DFT | DCT-II |
|---|---|---|
| Output | Complex (even for real input) | Real |
| Boundary conditions | Periodic | Even-symmetric |
| Energy compaction | Moderate | Excellent (Gibbs-free) |
| JPEG/MPEG blocks | — | Yes (8x8 blocks) |
The DCT achieves better energy compaction than DFT for smooth signals because it implicitly uses even extension -- eliminating the Gibbs phenomenon at boundaries.
Code cell 46
from scipy.fft import dct
# Test signal: three harmonics (modes 3, 7, 15) on a 64-point grid.
N_dct = 64
t_dct = np.arange(N_dct)
x_dct = (np.sin(2*np.pi*3*t_dct/N_dct) + 0.5*np.cos(2*np.pi*7*t_dct/N_dct)
+ 0.2 * np.sin(2*np.pi*15*t_dct/N_dct))
# One-sided DFT vs orthonormal DCT-II of the same signal.
X_dft = np.fft.rfft(x_dct)
X_dct2 = dct(x_dct, type=2, norm='ortho')
# Cumulative energy captured by the first k coefficients of each transform.
# NOTE(review): the DFT fraction uses one-sided rfft bins without doubling
# the non-DC terms — fine for a relative comparison, not exact Parseval.
K_vals = np.arange(1, N_dct//2 + 1)
dft_efrac = [np.sum(np.abs(X_dft[:k])**2) / np.sum(np.abs(X_dft)**2) for k in K_vals]
dct_efrac = [np.sum(X_dct2[:k]**2) / np.sum(X_dct2**2) for k in K_vals]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(np.abs(X_dft), color=COLORS['primary'], linewidth=2, label='|DFT|')
axes[0].plot(np.abs(X_dct2[:N_dct//2+1]), color=COLORS['secondary'],
linewidth=2, label='|DCT-II|')
axes[0].set_title('Coefficient magnitudes: DFT vs DCT-II')
axes[0].set_xlabel('Coefficient index $k$'); axes[0].set_ylabel('Magnitude')
axes[0].legend()
axes[1].plot(K_vals, dft_efrac, color=COLORS['primary'], linewidth=2, label='DFT')
axes[1].plot(K_vals, dct_efrac, color=COLORS['secondary'], linewidth=2, label='DCT-II')
axes[1].axhline(0.99, color=COLORS['neutral'], linestyle='--', alpha=0.7, label='99%')
axes[1].set_title('Cumulative energy: DCT compacts energy faster')
axes[1].set_xlabel('Number of coefficients $K$')
axes[1].set_ylabel('Fraction of total energy')
axes[1].legend(fontsize=9)
fig.tight_layout(); plt.show()
# First K reaching 99% cumulative energy for each transform.
k99_dft = K_vals[next(i for i,e in enumerate(dft_efrac) if e >= 0.99)]
k99_dct = K_vals[next(i for i,e in enumerate(dct_efrac) if e >= 0.99)]
print(f'Coefficients for 99% energy -- DFT: {k99_dft}, DCT-II: {k99_dct}')
16. Number Theoretic Transform (NTT)
The NTT is the DFT over a finite field $\mathbb{Z}_p$ instead of $\mathbb{C}$:
$$X[k] = \sum_{n=0}^{N-1} x[n]\, g^{nk} \bmod p,$$
where $g$ is a primitive $N$-th root of unity modulo a prime $p$.
Why it matters for AI:
- Fully Homomorphic Encryption (FHE): Privacy-preserving ML inference uses NTT-based polynomial multiplication in encrypted rings (CKKS, BFV schemes).
- Exact arithmetic: No floating-point rounding error.
- Same butterfly algorithm as Cooley-Tukey FFT.
Code cell 48
# Number Theoretic Transform: exact polynomial multiplication mod p
MOD = 998244353  # NTT-friendly prime: 119 * 2^23 + 1, so 2^23 divides MOD-1
G_ROOT = 3       # primitive root modulo MOD


def ntt(a, invert=False):
    """Iterative radix-2 Number Theoretic Transform over Z_MOD.

    Parameters
    ----------
    a : sequence of int
        Coefficients. ``len(a)`` must be a power of two (and at most 2^23,
        the largest power of two dividing MOD - 1).
    invert : bool
        If True, compute the inverse transform, including the 1/n factor.

    Returns
    -------
    list of int
        Transformed coefficients modulo MOD. The input is not modified.

    Raises
    ------
    ValueError
        If the length is not a power of two (the original version silently
        produced wrong results in that case).
    """
    n = len(a)
    if n & (n - 1):
        raise ValueError('ntt length must be a power of two')
    a = list(a)
    # Bit-reversal permutation (standard iterative Cooley-Tukey reordering).
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    # Butterfly passes: doubling sub-transform length each pass.
    length = 2
    while length <= n:
        exp = (MOD - 1) // length
        if invert:
            # g^(p-1-e) is the modular inverse of g^e (Fermat).
            exp = MOD - 1 - exp
        w = pow(G_ROOT, exp, MOD)
        for i in range(0, n, length):
            wn = 1
            for k in range(length // 2):
                u = a[i + k]
                v = a[i + k + length // 2] * wn % MOD
                a[i + k] = (u + v) % MOD
                a[i + k + length // 2] = (u - v) % MOD
                wn = wn * w % MOD
        length <<= 1
    if invert:
        # Divide by n via the modular inverse n^(p-2) mod p.
        n_inv = pow(n, MOD - 2, MOD)
        a = [x * n_inv % MOD for x in a]
    return a


def poly_mul_ntt(p1, p2):
    """Multiply two polynomials modulo MOD via the NTT.

    Coefficients are ordered lowest degree first; any integer sequence is
    accepted. Returns a list of length len(p1) + len(p2) - 1, or [] if
    either input is empty. Runs in O(n log n).
    """
    if not p1 or not p2:
        return []  # product with the zero-length polynomial is empty
    rlen = len(p1) + len(p2) - 1
    # Round up to the next power of two so the radix-2 NTT applies.
    n = 1
    while n < rlen:
        n <<= 1
    fa = ntt(list(p1) + [0] * (n - len(p1)))
    fb = ntt(list(p2) + [0] * (n - len(p2)))
    # Pointwise product in the transform domain = convolution in coefficients.
    fc = [(x * y) % MOD for x, y in zip(fa, fb)]
    return ntt(fc, invert=True)[:rlen]
# Sanity check: multiply (1 + 2x + 3x^2) by (4 + 5x); coefficients are
# stored lowest degree first.
p1, p2 = [1, 2, 3], [4, 5]
result = poly_mul_ntt(p1, p2)
# np.polymul expects highest-degree-first order, hence the reversals;
# .astype(int) converts NumPy's integer dtype for the list comparison.
expected = list(np.polymul(p1[::-1], p2[::-1])[::-1].astype(int))
print(f'p1 = {p1} (1 + 2x + 3x^2)')
print(f'p2 = {p2} (4 + 5x)')
print(f'NTT result: {result}')
print(f'Expected: {expected}')
print(f'Match: {result == expected}')
17. Magnitude and Phase Spectra
The DFT output has two components:
- Magnitude spectrum $|X[k]|$: how much of each frequency
- Phase spectrum $\angle X[k] = \arg X[k]$: when (alignment) of each frequency
Phase is perceptually critical for speech: Scrambling phase while preserving magnitude renders speech unintelligible.
Group delay: $\tau_g(\omega) = -\,\mathrm{d}\angle X(\omega)/\mathrm{d}\omega$ measures time delay per frequency. Constant group delay = linear phase = no dispersion.
Code cell 50
# Swap magnitude and phase between two signals
np.random.seed(7)
N_ph = 256
t_ph = np.arange(N_ph) / 1000.0
# Signal A: harmonic series (fundamental 80 Hz, amplitudes decaying as 1/k).
sig_A = sum(np.sin(2*np.pi*k*80*t_ph + np.pi*k/4)/k for k in range(1, 8))
# Signal B: white noise shaped by a Hann envelope (a "noise burst").
sig_B = np.random.randn(N_ph) * np.hanning(N_ph)
FA = np.fft.rfft(sig_A)
FB = np.fft.rfft(sig_B)
# Hybrids: magnitude of one signal combined with the phase of the other.
sig_AB = np.fft.irfft(np.abs(FA) * np.exp(1j * np.angle(FB)))
sig_BA = np.fft.irfft(np.abs(FB) * np.exp(1j * np.angle(FA)))
fig, axes = plt.subplots(2, 2, figsize=(13, 7))
fig.suptitle('Phase vs magnitude: swapping spectral components', fontsize=14)
for ax, (sig, title) in zip(axes.flatten(),
[(sig_A, 'Signal A (harmonic)'),
(sig_B, 'Signal B (noise burst)'),
(sig_AB, 'Magnitude(A) + Phase(B)'),
(sig_BA, 'Magnitude(B) + Phase(A)')]):
ax.plot(t_ph*1000, sig, color=COLORS['primary'], linewidth=1.5)
ax.set_title(title)
ax.set_xlabel('Time (ms)'); ax.set_ylabel('Amplitude')
fig.tight_layout(); plt.show()
print('Phase(A) + any magnitude -> character of A preserved.')
print('Magnitude alone does not determine signal structure.')
Code cell 51
# DFT as complete basis: partial reconstruction from K frequency pairs
N_rec = 64
t_rec = np.arange(N_rec)
# Three harmonics at modes 2, 5, 11 — all energy lies below k = 16.
x_orig = (1.5*np.sin(2*np.pi*2*t_rec/N_rec)
+ 0.8*np.cos(2*np.pi*5*t_rec/N_rec)
+ 0.4*np.sin(2*np.pi*11*t_rec/N_rec))
X_orig = np.fft.fft(x_orig)
fig, axes = plt.subplots(2, 2, figsize=(13, 8))
fig.suptitle('Partial reconstruction: using K of N/2 complex frequencies', fontsize=13)
axes.flatten()[0].plot(t_rec, x_orig, color=COLORS['primary'], linewidth=2)
axes.flatten()[0].set_title('Original signal')
axes.flatten()[0].set_xlabel('Sample $n$'); axes.flatten()[0].set_ylabel('Amplitude')
for ax, K_keep in zip(axes.flatten()[1:], [2, 5, 16]):
# Keep bins 0..K-1 plus the mirrored tail N-K..N-1 so the spectrum is
# (almost) conjugate-symmetric.
# NOTE(review): bin N-K pairs with bin K, which is NOT kept, so one
# conjugate pair is split; np.real below discards the small imaginary
# residue this creates.
X_partial = np.zeros_like(X_orig)
X_partial[:K_keep] = X_orig[:K_keep]
X_partial[-K_keep:] = X_orig[-K_keep:]
x_rec = np.real(np.fft.ifft(X_partial))
err = np.mean((x_rec - x_orig)**2)
ax.plot(t_rec, x_orig, color=COLORS['neutral'], alpha=0.4, linestyle='--',
linewidth=1.5, label='Original')
ax.plot(t_rec, x_rec, color=COLORS['error'], linewidth=2,
label=f'Reconstruction (MSE={err:.4f})')
ax.set_title(f'K={K_keep} frequency pairs')
ax.set_xlabel('Sample $n$'); ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)
fig.tight_layout(); plt.show()
Code cell 52
# Interactive: verify all DFT properties numerically
# Each property is checked on random complex vectors; errN holds the
# maximum absolute deviation, which should sit at machine precision.
print('=== DFT Property Verification Suite ===')
n_pts = 16
gen = np.random.default_rng(42)
u = gen.standard_normal(n_pts) + 1j * gen.standard_normal(n_pts)
v = gen.standard_normal(n_pts) + 1j * gen.standard_normal(n_pts)
alpha, beta, shift = 2.3 + 1.1j, -0.7 + 0.4j, 3
U, V = np.fft.fft(u), np.fft.fft(v)

# 1. Linearity: DFT(alpha*u + beta*v) == alpha*U + beta*V
err1 = np.max(np.abs(np.fft.fft(alpha * u + beta * v) - (alpha * U + beta * V)))
print(f'1. Linearity: max error = {err1:.2e}')

# 2. Circular shift theorem: DFT(u[n-m]) == exp(-2*pi*i*k*m/N) * U[k]
phase = np.exp(-2j * np.pi * shift / n_pts) ** np.arange(n_pts)
err2 = np.max(np.abs(np.fft.fft(np.roll(u, shift)) - phase * U))
print(f'2. Circular shift: max error = {err2:.2e}')

# 3. Convolution theorem: IFFT(U*V) equals the direct circular convolution
conv_fft = np.fft.ifft(U * V)
conv_direct = np.array([np.sum(u * np.roll(v[::-1], k + 1)) for k in range(n_pts)])
err3 = np.max(np.abs(conv_fft - conv_direct))
print(f'3. Circular conv: max error = {err3:.2e}')

# 4. Parseval: time-domain energy equals frequency-domain energy / N
err4 = abs(np.sum(np.abs(u)**2) - np.sum(np.abs(U)**2) / n_pts)
print(f'4. Parseval: |diff| = {err4:.2e}')

# 5. Unitarity: the DFT matrix satisfies F F^* / N = I
F = np.fft.fft(np.eye(n_pts), axis=0)
err5 = np.max(np.abs(F @ F.conj().T / n_pts - np.eye(n_pts)))
print(f'5. Unitarity F F*/N=I: max error = {err5:.2e}')

# 6. Conjugate symmetry for real input: X[k] == conj(X[N-k])
u_real = gen.standard_normal(n_pts)
U_real = np.fft.fft(u_real)
err6 = np.max(np.abs(U_real[1:] - np.conj(U_real[-1:0:-1])))
print(f'6. Conj. symmetry: max error = {err6:.2e}')
print()
print('All properties verified to machine precision.')
18. The DFT as a Group Algebra Isomorphism
The DFT is the Fourier transform on the cyclic group $\mathbb{Z}/N\mathbb{Z}$:
$$\hat{f}[k] = \sum_{n \in \mathbb{Z}/N\mathbb{Z}} f[n]\, e^{-2\pi i nk/N}.$$
The group algebra $\mathbb{C}[\mathbb{Z}/N\mathbb{Z}]$ has multiplication given by circular convolution. The DFT diagonalizes this multiplication: $\widehat{f * g} = \hat{f} \odot \hat{g}$.
This is why the Convolution Theorem (§04) is so fundamental: it follows directly from the DFT being a ring isomorphism.
For non-abelian groups, this generalizes to the non-abelian Fourier transform (representation theory), which underlies geometric deep learning on symmetric spaces.
19. Connections to Other Chapters
| Topic | This Chapter | Connection |
|---|---|---|
| Fourier Series (§01) | DFT is discrete/finite analogue | Periodic signal <-> discrete spectrum |
| Fourier Transform (§02) | DFT approximates continuous FT | Sampling theorem bridges them |
| Convolution Theorem (§04) | DFT enables FFT-based convolution | |
| Wavelets (§05) | DWT = iterated filter banks | Multi-resolution via FFT |
| Linear Algebra (§02-03) | $F_N/\sqrt{N}$ is a unitary matrix | DFT = change of basis |
| Compressed Sensing | Sparsity in Fourier basis | CS recovery for MRI |
| CNNs | Spectral convolution | Circular conv = pointwise mult (§04) |
| SSMs (S4/Mamba) | Long convolutions via FFT | Sequences as spectral filters |
The Cooley-Tukey insight (1965): the DFT's $O(N^2)$ computation collapses to $O(N \log N)$ by exploiting the factorization $W_N^{2nk} = W_{N/2}^{nk}$. This single algorithm makes Fourier analysis practical for all of signal processing, communications, scientific computing, and modern deep learning.
Code cell 55
# Application: spectral analysis of a synthetic LLM training loss curve
np.random.seed(42)
n_steps = 2000
steps = np.arange(n_steps)
# Synthetic loss = exponential decay + two periodic artifacts + noise.
trend = 2.0 * np.exp(-steps / 800)
epoch_osc = 0.05 * np.sin(2*np.pi*steps/100) # period 100 steps ("epoch")
lr_osc = 0.03 * np.sin(2*np.pi*steps/500) # period 500 steps ("LR schedule")
noise_c = 0.02 * np.random.randn(n_steps)
loss = trend + epoch_osc + lr_osc + noise_c
# Remove the known trend so the DFT sees only the oscillatory part.
loss_detrend = loss - trend
L = np.fft.rfft(loss_detrend)
freqs_loss = np.fft.rfftfreq(n_steps)
fig, axes = plt.subplots(2, 1, figsize=(13, 8))
fig.suptitle('Spectral analysis of LLM training loss', fontsize=14)
axes[0].plot(steps, loss, color=COLORS['primary'], linewidth=1, alpha=0.8)
axes[0].plot(steps, trend, color=COLORS['error'], linewidth=2, linestyle='--',
label='Trend')
axes[0].set_title('Training loss curve')
axes[0].set_xlabel('Training step'); axes[0].set_ylabel('Loss')
axes[0].legend()
# Plot against period (1/frequency) so the 100- and 500-step artifacts land
# at readable x positions; bin 0 (DC) is skipped to avoid division by zero.
period_axis = 1.0 / (freqs_loss[1:] + 1e-12)
axes[1].plot(period_axis, np.abs(L[1:]),
color=COLORS['primary'], linewidth=1.5)
for p, label in [(100, 'epoch\n(100 steps)'), (500, 'LR schedule\n(500 steps)')]:
axes[1].axvline(p, color=COLORS['error'], linestyle='--', alpha=0.8)
axes[1].text(p+8, np.abs(L[1:]).max()*0.85, label, fontsize=9,
color=COLORS['error'])
axes[1].set_title('DFT magnitude of detrended loss')
axes[1].set_xlabel('Period (steps)'); axes[1].set_ylabel('|L[k]|')
axes[1].set_xlim(10, 1000)
fig.tight_layout(); plt.show()
Code cell 56
# Final summary: key quantitative results
# (Print-only recap; the numbers restate results derived earlier in the
# chapter — sidelobe figures are the classic peak levels in dB.)
print('=== Chapter 03 Key Results ===')
print()
print('DFT definition: X[k] = sum_{n} x[n] exp(-2pi*i*nk/N)')
print('IDFT: x[n] = (1/N) sum_{k} X[k] exp(2pi*i*nk/N)')
print()
print('FFT complexity: O(N log2 N) vs O(N^2) naive DFT')
print('Frequency resolution: Δf = fs / N (set by observation time T = N/fs)')
print('Zero-padding: Interpolates spectrum, does NOT improve resolution')
print('Nyquist criterion: fs >= 2 f_max (else aliasing via Poisson summation)')
print()
print('Window sidelobe attenuation:')
for name, sl in [('Rectangular', -13), ('Hann', -31.5),
('Hamming', -42.7), ('Blackman', -58), ('Kaiser beta=8.6', -69)]:
print(f' {name:<22} {sl:>6} dB')
print()
print('Whisper pipeline: 16kHz, 400-pt Hann, hop 160, FFT 512, 80 mel bins')
print('FNO spectral conv: keep K lowest modes (K << N//2+1), learn R_theta in C^K')
print('Monarch matrices: O(N sqrt(N)) params vs O(N^2) dense')
print('NTT: DFT over Z_p, O(N log N), exact -- powers FHE')
Chapter Summary
This notebook developed the DFT and FFT from first principles to modern ML applications.
What we proved:
- DFT is a unitary transformation: $F_N F_N^* / N = I$
- Cooley-Tukey FFT reduces $O(N^2)$ to $O(N \log N)$ via butterfly factorization
- Frequency resolution is $\Delta f = f_s / N$, set by observation time $T = N/f_s$ -- not zero-padding
- COLA condition enables perfect signal reconstruction from STFT frames
What we built:
- Recursive Cooley-Tukey FFT matching `scipy.fft` to machine precision
- Whisper mel spectrogram pipeline (16 kHz -> 80 mel bins)
- FNO spectral convolution layer ($K$-mode truncation)
- NTT for exact polynomial multiplication over $\mathbb{Z}_p$
What's next (§04 Convolution Theorem):
This duality ($x \circledast y \leftrightarrow \hat{x} \odot \hat{y}$) reduces convolution to $O(N \log N)$ and underpins CNNs, WaveNet, S4/Mamba, and Hyena long convolutions.