Converted from theory.ipynb for web reading.

Discrete Fourier Transform and FFT

"The FFT is the most important numerical algorithm of our lifetime." -- Gilbert Strang, MIT Mathematics

This notebook provides interactive derivations and visualizations covering:

  • DFT definition, matrix form, and unitarity
  • Cooley-Tukey FFT algorithm from scratch
  • Spectral leakage and window functions
  • STFT spectrograms and time-frequency analysis
  • Whisper mel spectrogram pipeline
  • Fourier Neural Operator spectral convolution layer
  • Monarch butterfly matrix factorization

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import scipy.fft as sp_fft
import scipy.signal as sp_sig
from scipy.special import i0 as bessel_i0

import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style='whitegrid', palette='colorblind')
    HAS_SNS = True
except ImportError:
    plt.style.use('seaborn-v0_8-whitegrid')
    HAS_SNS = False

mpl.rcParams.update({
    'figure.figsize':   (10, 6),
    'figure.dpi':        120,
    'font.size':          13,
    'axes.titlesize':     15,
    'axes.labelsize':     13,
    'lines.linewidth':     2.0,
    'axes.spines.top':   False,
    'axes.spines.right': False,
})
np.random.seed(42)

COLORS = {
    'primary':   '#0077BB',
    'secondary': '#EE7733',
    'tertiary':  '#009988',
    'error':     '#CC3311',
    'neutral':   '#555555',
    'highlight': '#EE3377',
}

print('Setup complete. NumPy', np.__version__)

1. Intuition

1.1 The DFT as Change of Basis

The $N$-point DFT transforms a vector $\mathbf{x} \in \mathbb{C}^N$ into its coordinates in the Fourier basis $\{\mathbf{f}_k\}_{k=0}^{N-1}$:

$$X[k] = \sum_{n=0}^{N-1} x[n]\,\omega_N^{-nk}, \qquad \omega_N = e^{2\pi i/N}$$

Below we visualize the Fourier basis vectors as sampled complex sinusoids.

Code cell 5

# === 1.1 Fourier Basis Vectors ===
N = 16
omega_N = np.exp(2j * np.pi / N)
n = np.arange(N)

fig, axes = plt.subplots(2, 4, figsize=(14, 6))
fig.suptitle(f'Fourier Basis Vectors for N={N} (real parts)', fontsize=14)

for idx, k in enumerate([0, 1, 2, 3, 4, 6, 7, 8]):
    ax = axes[idx // 4, idx % 4]
    f_k = np.exp(2j * np.pi * k * n / N)  # basis vector f_k
    ax.stem(n, f_k.real, linefmt='C0-',
            markerfmt='o', basefmt='k-')
    ax.set_title(f'k={k}, freq={k}/{N}')
    ax.set_xlabel('n')
    ax.set_ylabel('Re[f_k]')
    ax.set_ylim(-1.3, 1.3)

fig.tight_layout()
plt.show()
print(f'Fourier basis: {N} orthogonal vectors in C^{N} (each has norm sqrt({N}))')
print(f'f_k completes k cycles over {N} samples, k = 0 ... {N-1}')

Code cell 6

# === 1.2 Orthonormality of Fourier Basis ===
N = 8
# Build normalized DFT matrix (1/sqrt(N) * F_N)
k = np.arange(N)[:, None]  # column
n = np.arange(N)[None, :]  # row
F_N = np.exp(-2j * np.pi * k * n / N) / np.sqrt(N)

# Check unitarity: F_N @ F_N.conj().T should = I
product = F_N @ F_N.conj().T
print('F_N @ F_N* (showing real part):')
print(np.round(product.real, 8))
print()

ok = np.allclose(product, np.eye(N), atol=1e-12)
print(f'PASS - Normalized DFT matrix is unitary: {ok}')

# Frobenius norm of the error
err = np.linalg.norm(product - np.eye(N), 'fro')
print(f'Frobenius error ||F_N F_N* - I||_F = {err:.2e}')

2. Formal Definitions

2.1 The DFT by Hand: N=4 Example

Let $\mathbf{x} = (1, 0, -1, 0)$. We have $\omega_4 = e^{2\pi i/4} = i$:

$$X[k] = \sum_{n=0}^{3} x[n]\, i^{-nk}$$

  • $X[0] = 1 + 0 - 1 + 0 = 0$
  • $X[1] = 1 + 0\cdot(-i) + (-1)\cdot(-1) + 0\cdot i = 2$
  • $X[2] = 1 + 0\cdot(-1) + (-1)\cdot 1 + 0\cdot(-1) = 0$
  • $X[3] = 1 + 0\cdot i + (-1)\cdot(-1) + 0\cdot(-i) = 2$

Code cell 8

# === 2.1 DFT by Hand Verification ===
x = np.array([1, 0, -1, 0], dtype=complex)
N = len(x)
omega_4 = np.exp(2j * np.pi / N)

# Manual DFT
X_manual = np.zeros(N, dtype=complex)
for k in range(N):
    for n_idx in range(N):
        X_manual[k] += x[n_idx] * omega_4**(-n_idx * k)

X_numpy = np.fft.fft(x)

print('Manual DFT:', np.round(X_manual, 6))
print('NumPy FFT: ', np.round(X_numpy, 6))
ok = np.allclose(X_manual, X_numpy)
print(f'PASS - Manual DFT matches NumPy: {ok}')

# Parseval check
energy_time = np.sum(np.abs(x)**2)
energy_freq = np.sum(np.abs(X_numpy)**2) / N
print(f'\nParseval: sum|x|^2 = {energy_time:.4f}, sum|X|^2/N = {energy_freq:.4f}')
print(f'PASS - Parseval holds: {np.isclose(energy_time, energy_freq)}')

Code cell 9

# === 2.3 DFT Matrix Visualization ===
N = 8
k_idx = np.arange(N)[:, None]
n_idx = np.arange(N)[None, :]
F8 = np.exp(-2j * np.pi * k_idx * n_idx / N)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

im0 = axes[0].imshow(F8.real, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
axes[0].set_title('Re(F_8): real part of DFT matrix')
axes[0].set_xlabel('Column (time n)')
axes[0].set_ylabel('Row (frequency k)')
fig.colorbar(im0, ax=axes[0])

im1 = axes[1].imshow(F8.imag, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
axes[1].set_title('Im(F_8): imaginary part of DFT matrix')
axes[1].set_xlabel('Column (time n)')
axes[1].set_ylabel('Row (frequency k)')
fig.colorbar(im1, ax=axes[1])

fig.tight_layout()
plt.show()

# Verify: F8 @ F8.conj().T = 8 * I
ok = np.allclose(F8 @ F8.conj().T, N * np.eye(N))
print(f'PASS - F_8 F_8* = {N}*I: {ok}')
print(f'Max entry of |F_8 F_8* - {N}I| = {np.max(np.abs(F8 @ F8.conj().T - N*np.eye(N))):.2e}')

2.5 Relation to the Continuous Fourier Transform

The DFT approximates the continuous FT of a sampled signal:

$$X[k] \approx \hat{x}(k\,\Delta f) \cdot f_s, \qquad \Delta f = \frac{f_s}{N}$$

Below we compare the DFT of a Gaussian to its known continuous FT.

Code cell 11

# === 2.5 DFT vs Continuous FT ===
N = 256
fs = 100.0  # Hz
dt = 1.0 / fs
t = np.arange(N) * dt
df = fs / N

# Gaussian: x(t) = exp(-pi (t - t0)^2 / sigma^2), centered at t0 = t[N//2]
# Its continuous FT has magnitude sigma * exp(-pi sigma^2 f^2)
sigma = 0.1
x = np.exp(-np.pi * (t - t[N//2])**2 / sigma**2)  # centered Gaussian

# DFT
X = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(x))) * dt
freqs = np.fft.fftshift(np.fft.fftfreq(N, d=dt))

# Analytical FT
X_analytical = sigma * np.exp(-np.pi * sigma**2 * freqs**2)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(t, x, color=COLORS['primary'])
axes[0].set_title(f'Gaussian signal (sigma={sigma})')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('x(t)')

mask = np.abs(freqs) < 15
axes[1].plot(freqs[mask], np.abs(X[mask]),
             color=COLORS['primary'], label='|DFT|', lw=2)
axes[1].plot(freqs[mask], X_analytical[mask],
             '--', color=COLORS['error'], label='Analytical FT', lw=2)
axes[1].set_title('DFT magnitude vs analytical FT')
axes[1].set_xlabel('Frequency (Hz)')
axes[1].set_ylabel('|X(f)|')
axes[1].legend()

fig.tight_layout()
plt.show()

err = np.max(np.abs(np.abs(X[mask]) - X_analytical[mask]))
print(f'Max error |DFT| vs analytical FT: {err:.4f}')
print(f'PASS - DFT approximates continuous FT: {err < 0.01}')

3. Properties of the DFT

3.2 Circular Shift Property

If $y[n] = x[(n - m) \bmod N]$, then $Y[k] = \omega_N^{-mk} X[k]$: a circular shift in time multiplies each DFT coefficient by a pure phase factor.

Code cell 13

# === 3.2 Circular Shift Property ===
N = 32
x = np.zeros(N); x[:8] = 1.0  # rectangular pulse
m = 5  # shift by 5 samples

# Circular shift in time
y = np.roll(x, m)

X = np.fft.fft(x)
Y = np.fft.fft(y)

# Expected: Y[k] = omega_N^{-m*k} * X[k]
k = np.arange(N)
omega_N = np.exp(2j * np.pi / N)
Y_predicted = omega_N**(-m * k) * X

ok = np.allclose(Y, Y_predicted, atol=1e-12)
print(f'PASS - Circular shift property: {ok}')
print(f'Max error: {np.max(np.abs(Y - Y_predicted)):.2e}')

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
n = np.arange(N)
axes[0].stem(n, x, linefmt='C0-', markerfmt='o',
             basefmt='k-', label='x[n]')
axes[0].stem(n, y, linefmt='C1-', markerfmt='s',
             basefmt='k-', label=f'y[n]=x[n-{m}]')
axes[0].set_title('Circular shift in time domain')
axes[0].set_xlabel('n'); axes[0].set_ylabel('Amplitude')
axes[0].legend()

axes[1].plot(k, np.abs(X), color=COLORS['primary'], label='|X[k]|')
axes[1].plot(k, np.abs(Y), '--', color=COLORS['secondary'], label='|Y[k]|')
axes[1].set_title('Magnitudes unchanged by circular shift')
axes[1].set_xlabel('k'); axes[1].set_ylabel('|DFT|')
axes[1].legend()
fig.tight_layout(); plt.show()

Code cell 14

# === 3.4 Conjugate Symmetry for Real Inputs ===
N = 16
x_real = np.random.randn(N)  # real-valued signal
X = np.fft.fft(x_real)

print('Conjugate symmetry check: X[N-k] == X[k].conj()')
for k in range(1, N//2):
    diff = abs(X[N-k] - X[k].conj())
    if diff > 1e-12:
        print(f'  FAIL at k={k}: diff={diff:.2e}')

ok = np.allclose(X[1:N//2], X[N-1:N//2:-1].conj(), atol=1e-12)
print(f'PASS - Conjugate symmetry X[N-k]=X[k]*: {ok}')

# DC and Nyquist are always real
print(f'X[0]   (DC):      {X[0].real:.4f} + {X[0].imag:.4f}i  (should be real)')
print(f'X[N/2] (Nyquist): {X[N//2].real:.4f} + {X[N//2].imag:.4f}i (should be real)')
print(f'Only {N//2+1} independent values out of {N} (rfft would return {N//2+1})')

Code cell 15

# === 3.5 Parseval's Theorem ===
np.random.seed(0)
for N in [8, 64, 512, 4096]:
    x = np.random.randn(N)
    X = np.fft.fft(x)
    lhs = np.sum(np.abs(x)**2)
    rhs = np.sum(np.abs(X)**2) / N
    ok = np.isclose(lhs, rhs, rtol=1e-10)
    print(f'N={N:5d}: sum|x|^2={lhs:.4f}, sum|X|^2/N={rhs:.4f}, '
          f'error={abs(lhs-rhs):.2e}  {"PASS" if ok else "FAIL"}')

4. The Fast Fourier Transform Algorithm

4.1-4.2 Cooley-Tukey: Recursive DFT

The key identity splits an NN-point DFT into two (N/2)(N/2)-point DFTs:

$$X[k] = E[k] + \omega_N^{-k}\, O[k], \qquad X[k + N/2] = E[k] - \omega_N^{-k}\, O[k]$$

where $E$ is the $(N/2)$-point DFT of the even-indexed samples and $O$ is the $(N/2)$-point DFT of the odd-indexed samples.
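Before building the full recursive algorithm, the split identity can be checked directly against NumPy (a small sketch; `N` and the test vector are arbitrary):

```python
import numpy as np

N = 16
rng = np.random.default_rng(0)
x = rng.standard_normal(N) + 1j * rng.standard_normal(N)

E = np.fft.fft(x[0::2])            # (N/2)-point DFT of even-indexed samples
O = np.fft.fft(x[1::2])            # (N/2)-point DFT of odd-indexed samples
k = np.arange(N // 2)
tw = np.exp(-2j * np.pi * k / N)   # twiddle factors omega_N^{-k}

# First half: E + tw*O; second half: E - tw*O
X_split = np.concatenate([E + tw * O, E - tw * O])
print('split identity holds:', np.allclose(X_split, np.fft.fft(x)))
```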

Code cell 17

# === 4.1 Naive DFT vs FFT Complexity ===
import time

def naive_dft(x):
    N = len(x)
    k = np.arange(N)[:, None]
    n = np.arange(N)[None, :]
    W = np.exp(-2j * np.pi * k * n / N)
    return W @ x

print('N        Naive DFT (ms)   NumPy FFT (ms)   Speedup')
print('-' * 55)
for N in [64, 256, 1024, 4096]:
    x = np.random.randn(N).astype(complex)
    
    if N <= 1024:
        t0 = time.perf_counter()
        for _ in range(10): naive_dft(x)
        t_naive = (time.perf_counter() - t0) / 10 * 1000
    else:
        t_naive = float('inf')
    
    t0 = time.perf_counter()
    for _ in range(100): np.fft.fft(x)
    t_fft = (time.perf_counter() - t0) / 100 * 1000
    
    speedup = t_naive/t_fft if t_naive < float('inf') else 'N/A'
    speedup_str = f'{speedup:.0f}x' if isinstance(speedup, float) else speedup
    print(f"{N:5d}    {t_naive:10.3f}   {t_fft:12.4f}   {speedup_str:>8}")

Code cell 18

# === 4.2-4.6 Cooley-Tukey FFT from Scratch ===

def fft_recursive(x):
    # Radix-2 Cooley-Tukey FFT (recursive)
    N = len(x)
    if N == 1:
        return x
    if N % 2 != 0:
        raise ValueError(f'length must be a power of 2 (odd length {N} reached)')
    # Split even/odd
    even = fft_recursive(x[0::2])  # DFT of even-indexed samples
    odd  = fft_recursive(x[1::2])  # DFT of odd-indexed samples
    # Twiddle factors
    k = np.arange(N // 2)
    twiddle = np.exp(-2j * np.pi * k / N)
    # Butterfly combine
    X = np.empty(N, dtype=complex)
    X[:N//2] = even + twiddle * odd
    X[N//2:] = even - twiddle * odd
    return X

# Test on various sizes
print('Testing FFT from scratch vs NumPy:')
for N in [8, 64, 512, 4096]:
    x = np.random.randn(N) + 1j * np.random.randn(N)
    X_scratch = fft_recursive(x)
    X_numpy   = np.fft.fft(x)
    err = np.max(np.abs(X_scratch - X_numpy))
    ok = err < 1e-10
    print(f'  N={N:5d}: max error = {err:.2e}  {"PASS" if ok else "FAIL"}')

Code cell 19

# === 4.3 Butterfly Diagram: 8-point DFT ===
# Visualize the 3 stages of the 8-point Cooley-Tukey FFT

fig, ax = plt.subplots(figsize=(12, 6))
ax.set_xlim(-0.5, 4.5)
ax.set_ylim(-0.5, 8.5)
ax.axis('off')
ax.set_title('8-point DIT-FFT Butterfly Diagram', fontsize=14)

N_pts = 8
stages = 3
bit_rev = [0, 4, 2, 6, 1, 5, 3, 7]  # bit-reversed order

# Draw nodes
for stage in range(stages + 1):
    for pos in range(N_pts):
        ax.plot(stage, N_pts - 1 - pos, 'o',
                color=COLORS['primary'], ms=10, zorder=3)

# Input labels
for pos in range(N_pts):
    ax.text(-0.3, N_pts - 1 - pos, f'x[{bit_rev[pos]}]',
            ha='right', va='center', fontsize=10)

# Output labels
for pos in range(N_pts):
    ax.text(stages + 0.3, N_pts - 1 - pos, f'X[{pos}]',
            ha='left', va='center', fontsize=10)

# Stage connections and twiddle factors
butterfly_groups = [
    [(0,4),(1,5),(2,6),(3,7)],  # stage 1: stride=4
    [(0,2),(1,3),(4,6),(5,7)],  # stage 2: stride=2
    [(0,1),(2,3),(4,5),(6,7)],  # stage 3: stride=1
]
twiddles = [
    ['W^0','W^0','W^0','W^0'],
    ['W^0','W^1','W^0','W^1'],
    ['W^0','W^1','W^2','W^3'],
]

for s, (butterflies, tws) in enumerate(zip(butterfly_groups, twiddles)):
    for bi, ((top, bot), tw) in enumerate(zip(butterflies, tws)):
        y_top = N_pts - 1 - top
        y_bot = N_pts - 1 - bot
        # Straight line (top)
        ax.plot([s, s+1], [y_top, y_top], '-',
                color=COLORS['primary'], lw=1.5)
        # Crossing line (bottom input to top output)
        ax.plot([s, s+1], [y_bot, y_top], '-',
                color=COLORS['secondary'], lw=1.0, alpha=0.7)
        # Straight line (bottom)
        ax.plot([s, s+1], [y_bot, y_bot], '-',
                color=COLORS['primary'], lw=1.5)
        # Twiddle label
        mid_x = s + 0.5
        mid_y = (y_top + y_bot) / 2
        ax.text(mid_x, mid_y, tw, ha='center', va='center',
                fontsize=8, color=COLORS['error'],
                bbox=dict(boxstyle='round,pad=0.2', fc='white', alpha=0.8))

# Stage labels
for s in range(stages):
    ax.text(s + 0.5, -0.3, f'Stage {s+1}', ha='center', va='top', fontsize=10)

fig.tight_layout()
plt.show()
print('8-point DIT-FFT: 3 stages, 4 butterflies per stage = 12 complex multiplications')
print(f'vs naive DFT: {8**2} = 64 complex multiplications')
print(f'Speedup: {64/12:.1f}x')
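As a cross-check on the butterfly count, an instrumented variant of the recursive FFT (a sketch; the tally is kept in a one-element list) confirms the $\tfrac{N}{2}\log_2 N$ complex-multiplication count:

```python
import numpy as np

def fft_count(x, counter):
    # Radix-2 recursive FFT that tallies twiddle multiplications
    N = len(x)
    if N == 1:
        return x
    even = fft_count(x[0::2], counter)
    odd  = fft_count(x[1::2], counter)
    k = np.arange(N // 2)
    twiddle = np.exp(-2j * np.pi * k / N)
    counter[0] += N // 2            # one complex multiply per butterfly
    X = np.empty(N, dtype=complex)
    X[:N//2] = even + twiddle * odd
    X[N//2:] = even - twiddle * odd
    return X

for N in [8, 64, 1024]:
    counter = [0]
    fft_count(np.random.randn(N).astype(complex), counter)
    expected = (N // 2) * int(np.log2(N))
    print(f'N={N:5d}: counted {counter[0]:5d}, N/2*log2(N) = {expected:5d}')
```

For N=8 this reproduces the 12 multiplications counted in the butterfly diagram above.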

Code cell 20

# === 4.5 Bit-Reversal Permutation ===

def bit_reverse(n, num_bits):
    result = 0
    for _ in range(num_bits):
        result = (result << 1) | (n & 1)
        n >>= 1
    return result

N = 8
m = int(np.log2(N))  # number of bits

print(f'Bit-reversal permutation for N={N} ({m} bits):')
print(f"{"Index":>8} {"Binary":>8} {"Bit-reversed":>14} {"Dec (br)":>10}")
print('-' * 44)
for i in range(N):
    br = bit_reverse(i, m)
    print(f"{i:>8} {bin(i)[2:].zfill(m):>8} {bin(br)[2:].zfill(m):>14} {br:>10}")

# The iterative (in-place) FFT consumes its input in bit-reversed order;
# permute a sample vector to show the reordering
x = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=complex)
perm = [bit_reverse(i, m) for i in range(N)]
x_br = x[perm]
print(f'\nBit-reversal permutation indices: {perm}')
print(f'Reordered input: {x_br.real.astype(int)}')

Code cell 21

# === 4.6 O(N log N) Complexity Validation ===
import time

sizes = [2**k for k in range(6, 22)]
times = []

for N in sizes:
    x = np.random.randn(N).astype(np.float64)
    # Warm up
    np.fft.rfft(x)
    reps = max(3, 1000 // N) if N <= 10000 else 3
    t0 = time.perf_counter()
    for _ in range(reps):
        np.fft.rfft(x)
    times.append((time.perf_counter() - t0) / reps)

# Fit N log N model
log_N = np.log2(sizes)
log_T = np.log2(times)
coeffs = np.polyfit(log_N, log_T, 1)
print(f'Empirical power law exponent: {coeffs[0]:.3f} (expected ~1.0 for O(N log N))')

fig, ax = plt.subplots(figsize=(10, 5))
ax.loglog(sizes, times, 'o-', color=COLORS['primary'], label='NumPy rfft')
# Theoretical O(N log2 N) reference
ref = np.array(sizes) * np.log2(sizes) * times[4] / (sizes[4] * np.log2(sizes[4]))
ax.loglog(sizes, ref, '--', color=COLORS['neutral'],
          label=r'$O(N\log_2 N)$ reference')
ax.set_xlabel('FFT size N')
ax.set_ylabel('Time (seconds)')
ax.set_title('NumPy FFT timing: verifying O(N log N) complexity')
ax.legend()
fig.tight_layout()
plt.show()

5. Spectral Leakage and Windowing

5.1 The Leakage Problem

When a sinusoid's frequency does not align exactly with a DFT bin, energy 'leaks' from the true frequency into all other bins. This occurs because the DFT implicitly assumes the signal is periodic with period $N$. Truncation is equivalent to multiplication by a rectangular window, whose DFT (the Dirichlet kernel) has large sidelobes.
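The Dirichlet kernel can be checked numerically against the zero-padded DFT of the rectangular window, including its familiar first sidelobe near -13 dB (a quick sketch; `N` and the padding length are illustrative):

```python
import numpy as np

N = 64
w = np.ones(N)                        # rectangular window
N_pad = 4096
W = np.abs(np.fft.fft(w, n=N_pad))    # densely sampled window spectrum
f = np.fft.fftfreq(N_pad)             # normalized frequency (cycles/sample)

# Closed form: |D_N(f)| = |sin(pi f N) / sin(pi f)|, valid away from f = 0
mask = np.abs(np.sin(np.pi * f)) > 1e-9
D = np.abs(np.sin(np.pi * f[mask] * N) / np.sin(np.pi * f[mask]))
print('matches closed form:', np.allclose(W[mask], D))

# Peak sidelobe: search beyond the first null at f = 1/N
first_null = N_pad // N
peak_sl_db = 20 * np.log10(np.max(W[first_null:N_pad // 2]) / N)
print(f'rectangular-window peak sidelobe: {peak_sl_db:.1f} dB')
```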

Code cell 23

# === 5.1 Spectral Leakage Demo ===
N = 64
n = np.arange(N)

# Integer frequency (bin-aligned) vs non-integer frequency
f_int  = 5.0   # exactly 5 cycles in N=64 -> no leakage
f_frac = 5.5   # 5.5 cycles -> leakage

x_int  = np.cos(2 * np.pi * f_int  * n / N)
x_frac = np.cos(2 * np.pi * f_frac * n / N)

X_int  = np.fft.fft(x_int)
X_frac = np.fft.fft(x_frac)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
k = np.arange(N)

axes[0].stem(k[:N//2+1], np.abs(X_int[:N//2+1]) / N * 2,
             linefmt='C0-', markerfmt='o', basefmt='k-')
axes[0].set_title(f'Bin-aligned: f={f_int} (no leakage)')
axes[0].set_xlabel('Frequency bin k')
axes[0].set_ylabel('Amplitude')
axes[0].axvline(f_int, color=COLORS['error'], ls='--', label='True freq')
axes[0].legend()

axes[1].stem(k[:N//2+1], np.abs(X_frac[:N//2+1]) / N * 2,
             linefmt='C1-', markerfmt='o', basefmt='k-')
axes[1].set_title(f'Non-bin-aligned: f={f_frac} (leakage visible)')
axes[1].set_xlabel('Frequency bin k')
axes[1].set_ylabel('Amplitude')
axes[1].axvline(f_frac, color=COLORS['error'], ls='--', label='True freq')
axes[1].legend()

fig.tight_layout(); plt.show()

# Quantify leakage: the main lobe sits near k=5.5 and, since the signal
# is real, its mirror sits near k=N-5.5
main_lobe_bins = {4, 5, 6, N - 6, N - 5, N - 4}
total_power = np.sum(np.abs(X_frac)**2)
main_power = sum(np.abs(X_frac[b])**2 for b in main_lobe_bins)
leakage_frac = 1 - main_power / total_power
print(f'Leakage fraction (energy outside main-lobe bins): {leakage_frac:.1%}')

Code cell 24

# === 5.2 Window Functions ===
N = 64
n = np.arange(N)

def hann_window(N):
    n = np.arange(N)
    return 0.5 * (1 - np.cos(2*np.pi*n/(N-1)))

def hamming_window(N):
    n = np.arange(N)
    return 0.54 - 0.46*np.cos(2*np.pi*n/(N-1))

def blackman_window(N):
    n = np.arange(N)
    return 0.42 - 0.5*np.cos(2*np.pi*n/(N-1)) + 0.08*np.cos(4*np.pi*n/(N-1))

def kaiser_window(N, beta=8.6):
    n = np.arange(N)
    x = beta * np.sqrt(1 - (2*n/(N-1) - 1)**2)
    return bessel_i0(x) / bessel_i0(beta)

windows = {
    'Rectangular': np.ones(N),
    'Hann':       hann_window(N),
    'Hamming':    hamming_window(N),
    'Blackman':   blackman_window(N),
    'Kaiser(8.6)': kaiser_window(N, 8.6),
}
colors_list = [COLORS['neutral'], COLORS['primary'], COLORS['secondary'],
               COLORS['tertiary'], COLORS['highlight']]

# Compute frequency responses (zero-padded for smooth curve)
N_pad = 1024
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

print(f"{'Window':<15} {'Coherent gain':>15} {'Peak sidelobe (dB)':>20}")
print('-' * 52)

for (name, w), c in zip(windows.items(), colors_list):
    # Time domain
    axes[0].plot(np.arange(N), w, color=c, label=name, alpha=0.85)
    # Frequency domain (zero-padded, dB)
    W = np.fft.fft(w, n=N_pad)
    W_db = 20*np.log10(np.abs(np.fft.fftshift(W)) / np.max(np.abs(W)) + 1e-15)
    freqs_norm = np.fft.fftshift(np.fft.fftfreq(N_pad))
    axes[1].plot(freqs_norm, W_db, color=c, label=name, alpha=0.85)
    # Metrics
    cg = np.sum(w) / N
    main_peak = np.max(np.abs(W))
    # Peak sidelobe: walk down the main lobe to its first null,
    # then take the maximum of everything beyond it
    mag = np.abs(W[:N_pad // 2])
    i = 1
    while i < len(mag) - 1 and mag[i + 1] < mag[i]:
        i += 1
    peak_sl = 20*np.log10(np.max(mag[i:]) / main_peak + 1e-15)
    print(f"{name:<15} {cg:>15.4f} {peak_sl:>20.1f}")

axes[0].set_title('Window functions (time domain)')
axes[0].set_xlabel('Sample n'); axes[0].set_ylabel('w[n]')
axes[0].legend(fontsize=9)

axes[1].set_xlim(-0.4, 0.4); axes[1].set_ylim(-100, 5)
axes[1].set_title('Window spectral responses (dB)')
axes[1].set_xlabel('Normalized frequency'); axes[1].set_ylabel('dB')
axes[1].axhline(-60, color='k', ls=':', lw=0.8)
axes[1].legend(fontsize=9)
fig.tight_layout(); plt.show()

7. Zero-Padding and Frequency Resolution

Zero-padding appends zeros to the signal before computing the DFT. It interpolates the spectrum (producing more DFT bins) but does not improve the fundamental frequency resolution $\Delta f = f_s / N_{\text{orig}}$.

Key insight: resolution is determined by the observation window $T = N/f_s$. More zeros give finer interpolation; only a longer observation gives a true resolution improvement.

Code cell 26

# Zero-padding: interpolation vs resolution
fs = 100.0          # Hz
N_orig = 32
t = np.arange(N_orig) / fs
f_true = 11.3       # Hz (deliberately non-bin-aligned)
x = np.cos(2 * np.pi * f_true * t)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('Zero-padding interpolates the spectrum but does not improve resolution',
             fontsize=14)

for ax, N_pad in zip(axes, [32, 128, 512]):
    xp = np.zeros(N_pad)
    xp[:N_orig] = x
    X = np.fft.rfft(xp)
    freqs = np.fft.rfftfreq(N_pad, 1/fs)
    ax.plot(freqs, np.abs(X), color=COLORS['primary'], linewidth=1.5)
    ax.axvline(f_true, color=COLORS['error'], linestyle='--', linewidth=1.2,
               label=f'True freq {f_true} Hz')
    ax.set_title(f'N_pad = {N_pad} (N_orig = {N_orig})')
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('|X[k]|')
    ax.legend(fontsize=9)
    ax.set_xlim(0, 30)

fig.tight_layout()
plt.show()
print(f'Resolution with N_orig={N_orig}: Δf = {fs/N_orig:.2f} Hz  (unchanged by zero-padding)')
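A complementary sketch with two tones 2 Hz apart makes the same point without plots: padding a 32-sample record never separates them, while a 256-sample record does (frequencies, lengths, and the peak-counting threshold are illustrative choices):

```python
import numpy as np

fs = 100.0
f1, f2 = 10.0, 12.0                 # two tones 2 Hz apart

def peak_count(N_obs, N_fft):
    # Count distinct spectral peaks between 7 and 15 Hz
    t = np.arange(N_obs) / fs
    x = np.cos(2*np.pi*f1*t) + np.cos(2*np.pi*f2*t)
    mag = np.abs(np.fft.rfft(x * np.hanning(N_obs), n=N_fft))
    freqs = np.fft.rfftfreq(N_fft, 1/fs)
    m = mag[(freqs > 7) & (freqs < 15)]
    return sum(1 for i in range(1, len(m) - 1)
               if m[i] > m[i-1] and m[i] > m[i+1] and m[i] > 0.5 * m.max())

print('32 samples zero-padded to 1024 bins:', peak_count(32, 1024), 'peak(s)')
print('256 samples (longer observation):   ', peak_count(256, 1024), 'peak(s)')
```

Only the longer record yields two distinct peaks, because only it shrinks the main-lobe width.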

8. Sampling, Nyquist, and Aliasing

Nyquist-Shannon Sampling Theorem: a signal bandlimited to $f_{\max}$ Hz can be perfectly reconstructed from samples taken at $f_s \geq 2 f_{\max}$.

Aliasing occurs when $f_s < 2 f_{\max}$: high-frequency components fold into lower frequencies. The alias of a sinusoid at frequency $f$ sampled at $f_s$ appears at:

$$f_{\text{alias}} = f - \operatorname{round}(f / f_s) \cdot f_s$$

This follows from the Poisson summation formula:

$$\hat{x}_d(\xi) = f_s \sum_{k=-\infty}^{\infty} \hat{x}(\xi - k f_s)$$

Code cell 28

# Aliasing demonstration
fs_high = 1000.0   # adequate sampling rate
fs_low  = 15.0     # below the 22 Hz Nyquist rate for the 11 Hz component
f_signal = 11.0    # Hz

t_cont = np.linspace(0, 1, 5000, endpoint=False)
x_cont = np.sin(2 * np.pi * f_signal * t_cont)

t_high = np.arange(0, 1, 1/fs_high)
t_low  = np.arange(0, 1, 1/fs_low)
x_high = np.sin(2 * np.pi * f_signal * t_high)
x_low  = np.sin(2 * np.pi * f_signal * t_low)

# The alias frequency
f_alias = f_signal - round(f_signal / fs_low) * fs_low
print(f'Signal: {f_signal} Hz  |  fs_low: {fs_low} Hz  |  Alias at: {f_alias} Hz')

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('Aliasing: high-frequency signal appears at wrong frequency', fontsize=14)

# Time-domain
ax = axes[0]
ax.plot(t_cont[:2500], x_cont[:2500], color=COLORS['neutral'], alpha=0.5,
        label='Continuous 11 Hz')
ax.stem(t_low[:8], x_low[:8], linefmt='C3--',
        markerfmt='o', basefmt=' ', label=f'Samples at fs={fs_low} Hz')
# Alias waveform (keep the sign of f_alias so the phase matches the samples)
x_alias = np.sin(2 * np.pi * f_alias * t_cont)
ax.plot(t_cont[:2500], x_alias[:2500], color=COLORS['error'], linewidth=2,
        label=f'Alias: {abs(f_alias):.0f} Hz')
ax.set_title('Time domain')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)
ax.set_xlim(0, 0.5)

# Frequency domain — DFT of low-sampled signal
ax = axes[1]
X_low = np.fft.rfft(x_low)
freqs_low = np.fft.rfftfreq(len(x_low), 1/fs_low)
ax.stem(freqs_low, np.abs(X_low), linefmt='C0-',
        markerfmt='o', basefmt=' ')
ax.axvline(f_signal, color=COLORS['error'], linestyle='--', label=f'True {f_signal} Hz')
ax.axvline(abs(f_alias), color=COLORS['highlight'], linestyle=':',
           label=f'Alias {abs(f_alias):.0f} Hz')
ax.set_title(f'DFT at fs={fs_low} Hz  (Nyquist = {fs_low/2} Hz)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('|X[k]|')
ax.legend(fontsize=9)

fig.tight_layout()
plt.show()

9. Short-Time Fourier Transform (STFT)

The STFT slides a window $w[m]$ along the signal and computes the DFT at each position:

$$S[l, k] = \sum_{m=0}^{N-1} x[lH + m]\, w[m]\, e^{-2\pi i mk/N}$$

where $H$ is the hop size (stride) and $N$ is the window length.

| Parameter | Effect |
| --- | --- |
| Larger window $N$ | Better frequency resolution, worse time resolution |
| Smaller window $N$ | Better time resolution, worse frequency resolution |
| Hop size $H$ | Controls overlap; $H < N$ enables overlap-add reconstruction |

COLA condition (Constant Overlap-Add): $\sum_l w[n - lH] = C$ for all $n$, which ensures perfect reconstruction via overlap-add synthesis.

Code cell 30

# STFT from scratch
def stft(x, N=256, H=128, window='hann'):
    if window == 'hann':
        w = np.hanning(N)
    elif window == 'hamming':
        w = np.hamming(N)
    else:  # rectangular
        w = np.ones(N)
    frames = []
    n_frames = 1 + (len(x) - N) // H
    for l in range(n_frames):
        frame = x[l*H : l*H + N] * w
        frames.append(np.fft.rfft(frame))
    return np.array(frames).T   # shape: (N//2+1, n_frames)

# Test signal: chirp (linearly increasing frequency)
fs = 8000
duration = 1.0
t = np.linspace(0, duration, int(fs * duration), endpoint=False)
x_chirp = np.sin(2 * np.pi * (200 * t + 1500 * t**2))

# Compute STFT
N_win = 256
H_hop = 64
S = stft(x_chirp, N=N_win, H=H_hop, window='hann')

# Plot spectrogram
fig, ax = plt.subplots(figsize=(12, 5))
times = np.arange(S.shape[1]) * H_hop / fs
freqs = np.fft.rfftfreq(N_win, 1/fs)
im = ax.pcolormesh(times, freqs, 20*np.log10(np.abs(S) + 1e-10),
                   cmap='viridis', shading='auto')
fig.colorbar(im, ax=ax, label='Power (dB)')
ax.set_title('STFT spectrogram of a linear chirp (200 → 3200 Hz)')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
fig.tight_layout()
plt.show()

Code cell 31

# Time-frequency trade-off: three window sizes on the same chirp
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('STFT: time-frequency resolution trade-off', fontsize=14)

for ax, (N_w, H_w) in zip(axes, [(64, 16), (256, 64), (1024, 256)]):
    S_w = stft(x_chirp, N=N_w, H=H_w, window='hann')
    t_ax = np.arange(S_w.shape[1]) * H_w / fs
    f_ax = np.fft.rfftfreq(N_w, 1/fs)
    ax.pcolormesh(t_ax, f_ax, 20*np.log10(np.abs(S_w) + 1e-10),
                 cmap='viridis', shading='auto')
    ax.set_title(f'Window N={N_w}')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_ylim(0, 4000)

fig.tight_layout()
plt.show()
print('Short window → good time res, blurry freq res')
print('Long window  → good freq res, blurry time res')

Code cell 32

# Verify COLA condition for Hann window with 50% overlap
N_win = 256
H_hop = N_win // 2   # 50% overlap
w = np.hanning(N_win)

signal_len = 2048
cola_sum = np.zeros(signal_len)
for l in range((signal_len - N_win) // H_hop + 1):
    start = l * H_hop
    cola_sum[start:start + N_win] += w

interior = cola_sum[N_win:-N_win]   # exclude edges
print(f'COLA sum: min={interior.min():.6f}, max={interior.max():.6f}')
print(f'Constant overlap-add satisfied: {np.allclose(interior, interior[0])}')

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(cola_sum, color=COLORS['primary'], linewidth=1.5)
ax.axhline(interior.mean(), color=COLORS['error'], linestyle='--',
           label=f'COLA constant = {interior.mean():.3f}')
ax.set_title('COLA (Constant Overlap-Add) verification: Hann window, 50% overlap')
ax.set_xlabel('Sample index')
ax.set_ylabel('Sum of window values')
ax.legend()
fig.tight_layout()
plt.show()
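COLA is exactly what makes the STFT invertible: overlap-add the inverse DFTs of the frames, then divide by the window-sum envelope. A self-contained round-trip sketch (it restates a minimal forward STFT so the block runs on its own; the signal and parameters are illustrative):

```python
import numpy as np

def stft_frames(x, N=256, H=128):
    # Hann-windowed analysis frames -> rfft per frame
    w = np.hanning(N)
    n_frames = 1 + (len(x) - N) // H
    S = np.array([np.fft.rfft(x[l*H:l*H + N] * w) for l in range(n_frames)]).T
    return S, w

def istft_overlap_add(S, w, H=128):
    # Overlap-add the inverse DFT of each frame, then divide by the
    # window-sum envelope sum_l w[n - lH] (constant in the interior by COLA)
    N = len(w)
    n_frames = S.shape[1]
    out = np.zeros((n_frames - 1) * H + N)
    norm = np.zeros_like(out)
    for l in range(n_frames):
        out[l*H:l*H + N] += np.fft.irfft(S[:, l], n=N)
        norm[l*H:l*H + N] += w
    return out / np.maximum(norm, 1e-12)

fs = 8000
t = np.arange(2048) / fs
x = np.sin(2*np.pi*440*t) + 0.3*np.sin(2*np.pi*1200*t)

S, w = stft_frames(x)
x_rec = istft_overlap_add(S, w)
err = np.max(np.abs(x_rec[256:-256] - x[256:len(x_rec) - 256]))
print(f'max round-trip error (interior samples): {err:.2e}')
```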

10. Case Study: Whisper's Mel Spectrogram Pipeline

OpenAI's Whisper (Radford et al. 2022) pre-processes audio with a fixed pipeline:

| Step | Value | Rationale |
| --- | --- | --- |
| Sample rate | 16 000 Hz | Wideband speech standard |
| Window length | 400 samples = 25 ms | Speech is quasi-stationary over a frame |
| Hop size | 160 samples = 10 ms | Standard 10 ms frame shift |
| FFT size | 512 points | Zero-pad 400 → 512 (next power of 2) |
| Mel bins | 80 | Perceptual resolution |
| Frequency range | 0–8 000 Hz | Up to Nyquist (8 000 Hz) |
| Compression | $\log_{10}(\max(x, 10^{-10}))$, clipped to $[\max - 8, \infty)$ | Dynamic range compression |

The mel scale maps linear frequency $f$ to mel $m = 2595 \log_{10}(1 + f/700)$, mimicking the roughly logarithmic frequency perception of the human cochlea.

Code cell 34

# Whisper mel filterbank from scratch
def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(n_mels=80, n_fft=512, fs=16000, f_min=0, f_max=8000):
    n_freq = n_fft // 2 + 1
    m_min, m_max = hz_to_mel(f_min), hz_to_mel(f_max)
    mel_pts = np.linspace(m_min, m_max, n_mels + 2)
    freq_pts = mel_to_hz(mel_pts)
    bins = np.floor((n_fft + 1) * freq_pts / fs).astype(int)
    fbank = np.zeros((n_mels, n_freq))
    for m in range(1, n_mels + 1):
        f_m_minus = bins[m - 1]
        f_m       = bins[m]
        f_m_plus  = bins[m + 1]
        for k in range(f_m_minus, f_m):
            fbank[m-1, k] = (k - f_m_minus) / max(f_m - f_m_minus, 1)
        for k in range(f_m, f_m_plus):
            fbank[m-1, k] = (f_m_plus - k) / max(f_m_plus - f_m, 1)
    return fbank

fb = mel_filterbank()
freqs = np.fft.rfftfreq(512, 1/16000)

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

ax = axes[0]
for i in range(0, 80, 5):
    ax.plot(freqs, fb[i], alpha=0.7, linewidth=1.2)
ax.set_title('80 mel filterbank triangular filters (0–8000 Hz)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Filter weight')

ax = axes[1]
im = ax.imshow(fb, aspect='auto', origin='lower',
               extent=[freqs[0], freqs[-1], 0, 80], cmap='viridis')
fig.colorbar(im, ax=ax, label='Filter weight')
ax.set_title('Mel filterbank matrix (80 × 257)')
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Mel bin')

fig.tight_layout()
plt.show()

Code cell 35

# Simulate the full Whisper preprocessing pipeline
np.random.seed(42)
fs_whisper = 16000
# Synthetic speech-like signal: sum of harmonics with amplitude envelope
t_w = np.linspace(0, 3.0, 3 * fs_whisper, endpoint=False)
x_speech = np.zeros_like(t_w)
for k_harm, (f0, amp) in enumerate([(120, 0.6), (240, 0.3), (360, 0.15),
                                     (480, 0.07), (960, 0.03)]):
    x_speech += amp * np.sin(2 * np.pi * f0 * t_w)
# Amplitude envelope (vowel-like)
env = np.exp(-0.5 * ((t_w - 1.5) / 0.6)**2)
x_speech *= env

# Step 1: STFT (400-sample Hann, hop 160, FFT 512)
N_w, H_w, N_fft = 400, 160, 512
w_hann = np.hanning(N_w)
S_whisper = []
for l in range((len(x_speech) - N_w) // H_w + 1):
    frame = np.zeros(N_fft)
    frame[:N_w] = x_speech[l*H_w : l*H_w + N_w] * w_hann
    S_whisper.append(np.fft.rfft(frame))
S_whisper = np.array(S_whisper).T   # (257, n_frames)

# Step 2: Power spectrum
power = np.abs(S_whisper)**2

# Step 3: Apply mel filterbank
fb_w = mel_filterbank(n_mels=80, n_fft=N_fft, fs=fs_whisper)
mel_spec = fb_w @ power   # (80, n_frames)

# Step 4: Log compression + clipping
log_mel = np.log10(np.maximum(mel_spec, 1e-10))
log_mel = np.maximum(log_mel, log_mel.max() - 8.0)

# Plot
frame_times = np.arange(log_mel.shape[1]) * H_w / fs_whisper
fig, axes = plt.subplots(2, 1, figsize=(13, 8))
fig.suptitle('Whisper mel spectrogram pipeline', fontsize=14)

im1 = axes[0].pcolormesh(frame_times,
                          np.fft.rfftfreq(N_fft, 1/fs_whisper),
                          20*np.log10(np.abs(S_whisper) + 1e-10),
                          cmap='viridis', shading='auto')
fig.colorbar(im1, ax=axes[0], label='dB')
axes[0].set_title('STFT spectrogram (linear frequency)')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Frequency (Hz)')

im2 = axes[1].pcolormesh(frame_times, np.arange(80),
                          log_mel, cmap='viridis', shading='auto')
fig.colorbar(im2, ax=axes[1], label='Log power')
axes[1].set_title('Log-mel spectrogram (80 mel bins) — Whisper input')
axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('Mel bin')

fig.tight_layout()
plt.show()
print(f'Log-mel shape: {log_mel.shape}  (80 mel bins × {log_mel.shape[1]} frames)')

11. Fourier Neural Operator (FNO)

The Fourier Neural Operator (Li et al. 2021) solves parametric PDEs by learning in the frequency domain; each spectral layer costs $O(N \log N)$ via the FFT.

Spectral convolution layer:

$$(\mathcal{K} v)(x) = \mathcal{F}^{-1}\bigl(R_\theta \cdot (\mathcal{F} v)\bigr)(x)$$

where $R_\theta \in \mathbb{C}^{d_{\mathrm{in}} \times d_{\mathrm{out}} \times K}$ is a learnable weight tensor truncated to the $K$ lowest Fourier modes.

Key insight: Truncating to $K$ modes acts as a low-pass filter that captures the smooth, long-range structure of PDE solutions while discarding fine-scale noise.

Code cell 37

# FNO spectral convolution layer (1D, NumPy/pure-Python implementation)
class SpectralConv1d:
    def __init__(self, in_channels, out_channels, n_modes, rng=None):
        if rng is None:
            rng = np.random.default_rng(42)
        scale = 1.0 / (in_channels * out_channels)
        self.weights_real = rng.uniform(-scale, scale,
                                        (in_channels, out_channels, n_modes))
        self.weights_imag = rng.uniform(-scale, scale,
                                        (in_channels, out_channels, n_modes))
        self.weights = self.weights_real + 1j * self.weights_imag
        self.n_modes = n_modes

    def forward(self, x):
        # x: (batch, in_channels, N)
        batch, in_ch, N = x.shape
        out_ch = self.weights.shape[1]
        # FFT
        x_ft = np.fft.rfft(x, axis=-1)   # (batch, in_ch, N//2+1)
        # Truncate to n_modes and multiply by learnable weights
        out_ft = np.zeros((batch, out_ch, N//2+1), dtype=complex)
        for b in range(batch):
            # x_ft[b, :, :n_modes]: (in_ch, n_modes)
            out_ft[b, :, :self.n_modes] = np.einsum(
                'ix,iox->ox', x_ft[b, :, :self.n_modes], self.weights
            )
        # IFFT back to spatial domain
        return np.fft.irfft(out_ft, n=N, axis=-1)   # (batch, out_ch, N)

# Demo: 1D FNO spectral layer
N_grid = 64
K_modes = 12
np.random.seed(0)
x_in = np.random.randn(4, 3, N_grid)   # batch=4, in_ch=3
layer = SpectralConv1d(in_channels=3, out_channels=8, n_modes=K_modes)
x_out = layer.forward(x_in)
print(f'Input:  {x_in.shape}')
print(f'Output: {x_out.shape}')
print(f'Kept {K_modes}/{N_grid//2+1} Fourier modes ({100*K_modes/(N_grid//2+1):.0f}% of spectrum)')

# Visualize: input vs output for one channel
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax = axes[0]
for i in range(4):
    ax.plot(x_in[i, 0, :], alpha=0.7, linewidth=1.5,
            label=f'batch {i}')
ax.set_title(f'Input (in_ch=0)')
ax.set_xlabel('Grid point $x$')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)

ax = axes[1]
for i in range(4):
    ax.plot(x_out[i, 0, :], alpha=0.7, linewidth=1.5,
            label=f'batch {i}')
ax.set_title(f'Output after spectral conv (out_ch=0, K={K_modes} modes)')
ax.set_xlabel('Grid point $x$')
ax.set_ylabel('Amplitude')
ax.legend(fontsize=9)

fig.tight_layout()
plt.show()

Code cell 38

# Effect of mode truncation: how many modes do you need?
# Test signal: smooth + noisy
N_grid = 128
x_grid = np.linspace(0, 2*np.pi, N_grid, endpoint=False)
u_smooth = np.sin(x_grid) + 0.5 * np.cos(3 * x_grid)
u_noisy  = u_smooth + 0.3 * np.random.randn(N_grid)

U = np.fft.rfft(u_noisy)
freqs_grid = np.arange(len(U))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle('FNO mode truncation: reconstruction quality vs K modes', fontsize=13)

ax = axes[0]
ax.plot(np.abs(U), color=COLORS['primary'], linewidth=2, label='|U[k]|')
for K_cut, col in [(4, COLORS['error']), (12, COLORS['tertiary']), (24, COLORS['secondary'])]:
    ax.axvline(K_cut, color=col, linestyle='--', alpha=0.8, label=f'K={K_cut}')
ax.set_title('DFT magnitude — where to truncate?')
ax.set_xlabel('Frequency mode $k$')
ax.set_ylabel('|U[k]|')
ax.legend(fontsize=9)
ax.set_yscale('log')

ax = axes[1]
ax.plot(x_grid, u_smooth, color=COLORS['neutral'], linewidth=2, linestyle='--',
        label='True smooth signal')
ax.plot(x_grid, u_noisy, color=COLORS['neutral'], alpha=0.3, linewidth=1,
        label='Noisy input')
for K_cut, col in [(4, COLORS['error']), (12, COLORS['tertiary']), (24, COLORS['secondary'])]:
    U_trunc = np.zeros_like(U)
    U_trunc[:K_cut] = U[:K_cut]
    u_rec = np.fft.irfft(U_trunc, n=N_grid)
    err = np.mean((u_rec - u_smooth)**2)
    ax.plot(x_grid, u_rec, color=col, linewidth=1.8,
            label=f'K={K_cut} (MSE={err:.4f})')
ax.set_title('Reconstructions from K truncated modes')
ax.set_xlabel('$x$')
ax.set_ylabel('$u(x)$')
ax.legend(fontsize=9)

fig.tight_layout()
plt.show()

12. Monarch Matrices and Butterfly Factorizations

Monarch matrices (Dao et al. 2022) generalize butterfly matrices to a product of two block-diagonal factors interleaved with fixed permutations: up to permutations, $M = L R$, where $L, R$ are block-diagonal, each with $\sqrt{N}$ blocks of size $\sqrt{N} \times \sqrt{N}$.

Parameter count: $O(N\sqrt{N})$ vs $O(N^2)$ for dense.

Key property: The DFT matrix $F_N$ is a butterfly matrix: it factors as $F_N = B_1 B_2 \cdots B_{\log_2 N}\, P$, where each $B_j$ is sparse with two non-zeros per row and $P$ is the bit-reversal permutation. This is precisely the Cooley-Tukey FFT factorization!

FlashFFTConv (Fu et al. 2023) uses a Monarch decomposition of the FFT to compute the long convolutions used in state-space models with $O(N \sqrt{N})$-style hardware-efficient matrix operations on GPU.
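The block-diagonal-plus-permutation structure can be sketched as a matvec in NumPy. This is a minimal illustration of the idea (not the paper's exact parameterization); all names and the specific permutation convention here are illustrative:

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
m = 4                                        # sqrt(N) blocks of size sqrt(N)
N = m * m
L_blocks = rng.standard_normal((m, m, m))    # m blocks, each m x m
R_blocks = rng.standard_normal((m, m, m))

def monarch_matvec(L_blocks, R_blocks, x):
    """y = P B_L P B_R x, with P the reshape/transpose ('perfect shuffle') permutation."""
    m = L_blocks.shape[0]
    X = x.reshape(m, m)                          # split x into m chunks of length m
    Y = np.einsum('bij,bj->bi', R_blocks, X)     # block-diagonal R: one m x m matmul per chunk
    Z = np.einsum('bij,bj->bi', L_blocks, Y.T)   # permute (transpose), then block-diagonal L
    return Z.T.reshape(-1)                       # undo the permutation, flatten

# Dense equivalent for verification
P = np.zeros((N, N))
for i in range(m):
    for j in range(m):
        P[j*m + i, i*m + j] = 1.0                # the transpose permutation (an involution)
M = P @ block_diag(*L_blocks) @ P @ block_diag(*R_blocks)

x = rng.standard_normal(N)
assert np.allclose(monarch_matvec(L_blocks, R_blocks, x), M @ x)
print(f'Monarch params: {2*m**3} vs dense: {N*N}')   # 2*N^1.5 = 128 vs N^2 = 256
```

The factored matvec touches only $2 N^{3/2}$ parameters (two batched $\sqrt{N} \times \sqrt{N}$ matmuls) instead of $N^2$.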

Code cell 40

# Visualize butterfly structure of the DFT matrix
N_demo = 8
omega = np.exp(2j * np.pi / N_demo)
F8 = np.array([[omega**(k*n) for n in range(N_demo)]
               for k in range(N_demo)]) / np.sqrt(N_demo)

# The DFT matrix sparsity pattern (thresholded |entry| > 0.01)
# For comparison: the butterfly factors B1, B2, B3

def butterfly_factor(N, stage):
    # Stage = 0, 1, ..., log2(N)-1
    # Returns a (N, N) sparse matrix representing one butterfly stage
    B = np.zeros((N, N), dtype=complex)
    stride = 2**(stage + 1)
    half = stride // 2
    for start in range(0, N, stride):
        for j in range(half):
            twiddle = np.exp(-2j * np.pi * j / stride)
            B[start + j,       start + j]        = 1.0
            B[start + j,       start + j + half] = twiddle
            B[start + j + half, start + j]       = 1.0
            B[start + j + half, start + j + half] = -twiddle
    return B / 2  # normalize

log2_N = int(np.log2(N_demo))
fig, axes = plt.subplots(1, log2_N + 1, figsize=(16, 4))
fig.suptitle(f'Butterfly factorization of F_{N_demo}: F = B0 * B1 * B2', fontsize=13)

# DFT matrix
axes[0].imshow(np.abs(F8), cmap='viridis', vmin=0)
axes[0].set_title(f'|F_{N_demo}| (full DFT)')
axes[0].set_xlabel('Column')
axes[0].set_ylabel('Row')

# Butterfly stages
for stage in range(log2_N):
    B = butterfly_factor(N_demo, stage)
    axes[stage + 1].imshow(np.abs(B), cmap='viridis', vmin=0)
    axes[stage + 1].set_title(f'|B{stage}| (stage {stage})')
    axes[stage + 1].set_xlabel('Column')
    axes[stage + 1].set_ylabel('Row')

fig.tight_layout()
plt.show()
print(f'Each butterfly stage has O(N) non-zeros (two per row); log2(N)={log2_N} stages total')
print(f'Total non-zeros in factored form: {2 * N_demo * log2_N} vs {N_demo**2} in full matrix')

13. Looking Ahead: DFT → Convolution Theorem

We have established the DFT as a unitary transform computable in $O(N \log N)$ via the FFT. The next section (§04 Convolution Theorem) builds on this foundation:

| DFT result (§03) | What §04 does with it |
|---|---|
| DFT is a ring isomorphism on $\mathbb{Z}/N\mathbb{Z}$ | Circular convolution $\leftrightarrow$ pointwise multiplication |
| FFT reduces DFT to $O(N \log N)$ | Reduces $O(N^2)$ convolution to $O(N \log N)$ |
| Spectral leakage, windowing | Applied in filter design |
| STFT frames | Overlap-add = convolution in disguise |
| FNO spectral layer | Full treatment in §04 (CNNs, S4, Mamba) |

Core equation to carry forward:

$$x \circledast y = \mathcal{F}^{-1}\bigl(\mathcal{F}(x) \cdot \mathcal{F}(y)\bigr)$$

The DFT turns the $O(N^2)$ circular convolution sum into $O(N \log N)$ operations.
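A quick numerical sketch of this identity (§04 develops it in full):

```python
import numpy as np

rng = np.random.default_rng(0)
N = 8
x, y = rng.standard_normal(N), rng.standard_normal(N)

# Direct circular convolution: O(N^2) sum
direct = np.array([sum(x[n] * y[(k - n) % N] for n in range(N))
                   for k in range(N)])

# Via the DFT: O(N log N) with the FFT
via_fft = np.fft.ifft(np.fft.fft(x) * np.fft.fft(y)).real

assert np.allclose(direct, via_fft)
```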

Code cell 42

# Summary: key results and their parameter dependencies
print('=== Chapter §03 Summary ===')
print()
print('DFT definition:        X[k] = sum_{n=0}^{N-1} x[n] * exp(-2pi*i*n*k/N)')
print('IDFT:                  x[n] = (1/N) sum_{k=0}^{N-1} X[k] * exp(2pi*i*n*k/N)')
print()
print('FFT complexity:        O(N log N)  vs  O(N^2) naive')
print('Frequency resolution:  Δf = fs / N  (CANNOT be improved by zero-padding)')
print('Nyquist criterion:     fs >= 2 * f_max  (or aliases occur)')
print()
print('Window sidelobe levels:')
windows = [('Rectangular', -13), ('Hann', -31.5), ('Hamming', -42.7),
           ('Blackman', -58), ('Kaiser beta=8.6', -69)]
for name, sl in windows:
    print(f'  {name:<20} {sl:>6} dB')
print()
print('Whisper pipeline:      16kHz, 400-sample Hann, hop 160, FFT 512, 80 mel bins')
print('FNO:                   keeps K lowest modes (K << N//2+1)')
print('Monarch:               O(N sqrt(N)) params vs O(N^2) dense')

14. Two-Dimensional DFT

The 2D DFT extends naturally to images and 2D fields:

$$X[k_1, k_2] = \sum_{n_1=0}^{N_1-1} \sum_{n_2=0}^{N_2-1} x[n_1, n_2]\, e^{-2\pi i(n_1 k_1/N_1 + n_2 k_2/N_2)}$$

It decomposes a 2D signal into sinusoidal spatial frequency components. Low frequencies concentrate near the center (after fftshift); high frequencies lie at the edges.

Applications: Image compression (JPEG uses DCT, a DFT variant), CNNs in the frequency domain, MRI reconstruction (k-space = 2D Fourier space of tissue).

Code cell 44

# 2D DFT: image frequency analysis
from matplotlib.colors import LogNorm

N1, N2 = 128, 128
x1, x2 = np.meshgrid(np.arange(N1), np.arange(N2), indexing='ij')
img = (np.sin(2*np.pi*4*x1/N1) * np.cos(2*np.pi*6*x2/N2)
       + 0.5 * np.sin(2*np.pi*12*x1/N1)
       + 0.3 * np.random.randn(N1, N2))

IMG = np.fft.fft2(img)
IMG_shifted = np.fft.fftshift(IMG)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('2D DFT: image to frequency domain', fontsize=14)

axes[0].imshow(img, cmap='viridis', aspect='auto')
axes[0].set_title('Input image')
axes[0].set_xlabel('$x_2$'); axes[0].set_ylabel('$x_1$')

axes[1].imshow(np.abs(IMG_shifted), cmap='viridis', aspect='auto',
               norm=LogNorm(vmin=1))
axes[1].set_title('|2D DFT| (log scale, DC at center)')
axes[1].set_xlabel('$k_2$ (shifted)'); axes[1].set_ylabel('$k_1$ (shifted)')

# Low-pass filter: keep central r modes
r = 10
mask = np.zeros((N1, N2), dtype=bool)
cx, cy = N1//2, N2//2
mask[cx-r:cx+r, cy-r:cy+r] = True
IMG_lp = IMG.copy()
IMG_lp[~np.fft.ifftshift(mask)] = 0
img_rec = np.real(np.fft.ifft2(IMG_lp))

axes[2].imshow(img_rec, cmap='viridis', aspect='auto')
axes[2].set_title(f'Low-pass filtered (r={r} modes)')
axes[2].set_xlabel('$x_2$'); axes[2].set_ylabel('$x_1$')

fig.tight_layout()
plt.show()
print(f'Kept {mask.sum()} of {N1*N2} Fourier modes ({100*mask.sum()/(N1*N2):.1f}%)')

15. Discrete Cosine Transform (DCT)

The DCT-II is defined as:

$$X[k] = \sum_{n=0}^{N-1} x[n] \cos\!\left(\frac{\pi k(2n+1)}{2N}\right)$$

Why DCT over DFT for real signals?

| Property | DFT | DCT-II |
|---|---|---|
| Output | Complex (even for real input) | Real |
| Boundary conditions | Periodic | Even-symmetric |
| Energy compaction | Moderate | Excellent (Gibbs-free) |
| JPEG/MPEG blocks | No | Yes (8×8 blocks) |

The DCT achieves better energy compaction than the DFT for smooth signals because it implicitly extends the signal even-symmetrically -- removing the artificial boundary jump that causes the Gibbs phenomenon and slow coefficient decay.
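The even-extension claim can be verified directly: the (unnormalized) DCT-II coefficients are, up to a phase factor, the DFT of the length-$2N$ even-symmetric extension. A minimal check:

```python
import numpy as np
from scipy.fft import dct

rng = np.random.default_rng(0)
N = 16
x = rng.standard_normal(N)

# Even-symmetric extension: [x0 ... x_{N-1}, x_{N-1} ... x0], length 2N
y = np.concatenate([x, x[::-1]])
Y = np.fft.fft(y)
k = np.arange(N)

# DCT-II recovered from the DFT of the extension (phase-shifted real part)
X_via_fft = np.real(np.exp(-1j * np.pi * k / (2 * N)) * Y[:N])

assert np.allclose(X_via_fft, dct(x, type=2))   # scipy's unnormalized DCT-II
```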

Code cell 46

from scipy.fft import dct

N_dct = 64
t_dct = np.arange(N_dct)
x_dct = (np.sin(2*np.pi*3*t_dct/N_dct) + 0.5*np.cos(2*np.pi*7*t_dct/N_dct)
         + 0.2 * np.sin(2*np.pi*15*t_dct/N_dct))

X_dft = np.fft.rfft(x_dct)
X_dct2 = dct(x_dct, type=2, norm='ortho')

K_vals = np.arange(1, N_dct//2 + 1)
dft_efrac = [np.sum(np.abs(X_dft[:k])**2) / np.sum(np.abs(X_dft)**2) for k in K_vals]
dct_efrac = [np.sum(X_dct2[:k]**2) / np.sum(X_dct2**2) for k in K_vals]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(np.abs(X_dft), color=COLORS['primary'], linewidth=2, label='|DFT|')
axes[0].plot(np.abs(X_dct2[:N_dct//2+1]), color=COLORS['secondary'],
             linewidth=2, label='|DCT-II|')
axes[0].set_title('Coefficient magnitudes: DFT vs DCT-II')
axes[0].set_xlabel('Coefficient index $k$'); axes[0].set_ylabel('Magnitude')
axes[0].legend()

axes[1].plot(K_vals, dft_efrac, color=COLORS['primary'], linewidth=2, label='DFT')
axes[1].plot(K_vals, dct_efrac, color=COLORS['secondary'], linewidth=2, label='DCT-II')
axes[1].axhline(0.99, color=COLORS['neutral'], linestyle='--', alpha=0.7, label='99%')
axes[1].set_title('Cumulative energy: DCT compacts energy faster')
axes[1].set_xlabel('Number of coefficients $K$')
axes[1].set_ylabel('Fraction of total energy')
axes[1].legend(fontsize=9)

fig.tight_layout(); plt.show()

k99_dft = K_vals[next(i for i,e in enumerate(dft_efrac) if e >= 0.99)]
k99_dct = K_vals[next(i for i,e in enumerate(dct_efrac) if e >= 0.99)]
print(f'Coefficients for 99% energy -- DFT: {k99_dft}, DCT-II: {k99_dct}')

16. Number Theoretic Transform (NTT)

The NTT is the DFT over a finite field $\mathbb{Z}_p$ instead of $\mathbb{C}$:

$$X[k] = \sum_{n=0}^{N-1} x[n]\, g^{nk} \pmod{p}$$

where $g$ is a primitive $N$-th root of unity modulo a prime $p$.

Why it matters for AI:

  • Fully Homomorphic Encryption (FHE): Privacy-preserving ML inference uses NTT-based polynomial multiplication in encrypted rings (CKKS, BFV schemes).
  • Exact arithmetic: No floating-point rounding error.
  • Same $O(N \log N)$ butterfly algorithm as the Cooley-Tukey FFT.

Code cell 48

# Number Theoretic Transform: exact polynomial multiplication mod p
MOD = 998244353  # NTT-friendly prime: 119 * 2^23 + 1
G_ROOT = 3

def ntt(a, invert=False):
    n = len(a)
    a = list(a)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit; bit >>= 1
        j ^= bit
        if i < j: a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        w = pow(G_ROOT, MOD - 1 - (MOD-1)//length if invert else (MOD-1)//length, MOD)
        for i in range(0, n, length):
            wn = 1
            for k_ntt in range(length // 2):
                u = a[i + k_ntt]
                v = a[i + k_ntt + length//2] * wn % MOD
                a[i + k_ntt]             = (u + v) % MOD
                a[i + k_ntt + length//2] = (u - v) % MOD
                wn = wn * w % MOD
        length <<= 1
    if invert:
        n_inv = pow(n, MOD - 2, MOD)
        a = [x * n_inv % MOD for x in a]
    return a

def poly_mul_ntt(p1, p2):
    rlen = len(p1) + len(p2) - 1
    n = 1
    while n < rlen: n <<= 1
    fa = ntt(p1 + [0]*(n-len(p1)))
    fb = ntt(p2 + [0]*(n-len(p2)))
    fc = [(a*b) % MOD for a,b in zip(fa, fb)]
    return ntt(fc, invert=True)[:rlen]

p1, p2 = [1, 2, 3], [4, 5]
result = poly_mul_ntt(p1, p2)
expected = list(np.polymul(p1[::-1], p2[::-1])[::-1].astype(int))
print(f'p1 = {p1}  (1 + 2x + 3x^2)')
print(f'p2 = {p2}  (4 + 5x)')
print(f'NTT result:  {result}')
print(f'Expected:    {expected}')
print(f'Match: {result == expected}')

17. Magnitude and Phase Spectra

The DFT output $X[k] = |X[k]|\, e^{i \angle X[k]}$ has two components:

  • Magnitude spectrum $|X[k]|$: how much of each frequency is present
  • Phase spectrum $\angle X[k]$: how each frequency is aligned in time

Phase is perceptually critical for speech: scrambling the phase while preserving the magnitude renders speech unintelligible.

Group delay: $\tau_g(\omega) = -\frac{d}{d\omega} \angle X(\omega)$ measures the time delay of each frequency component. Constant group delay = linear phase = no dispersion.
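A minimal numerical check of constant group delay, using a hypothetical 3-tap symmetric (hence linear-phase) filter chosen for illustration:

```python
import numpy as np

# Symmetric 3-tap filter: H(w) = e^{-iw} (0.5 + 0.5 cos w), so phase(w) = -w exactly
h = np.array([0.25, 0.5, 0.25])
H = np.fft.rfft(h, 1024)
phase = np.unwrap(np.angle(H[:-1]))   # drop w = pi, where H = 0 exactly
dw = 2 * np.pi / 1024
group_delay = -np.diff(phase) / dw    # numerical -d(phase)/dw, in samples

# Symmetric length-L filter => constant group delay of (L-1)/2 = 1 sample
assert np.allclose(group_delay, 1.0)
```

Because the filter is symmetric, every frequency is delayed by exactly one sample: no dispersion.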

Code cell 50

# Swap magnitude and phase between two signals
np.random.seed(7)
N_ph = 256
t_ph = np.arange(N_ph) / 1000.0
sig_A = sum(np.sin(2*np.pi*k*80*t_ph + np.pi*k/4)/k for k in range(1, 8))
sig_B = np.random.randn(N_ph) * np.hanning(N_ph)

FA = np.fft.rfft(sig_A)
FB = np.fft.rfft(sig_B)
sig_AB = np.fft.irfft(np.abs(FA) * np.exp(1j * np.angle(FB)))
sig_BA = np.fft.irfft(np.abs(FB) * np.exp(1j * np.angle(FA)))

fig, axes = plt.subplots(2, 2, figsize=(13, 7))
fig.suptitle('Phase vs magnitude: swapping spectral components', fontsize=14)
for ax, (sig, title) in zip(axes.flatten(),
    [(sig_A, 'Signal A (harmonic)'),
     (sig_B, 'Signal B (noise burst)'),
     (sig_AB, 'Magnitude(A) + Phase(B)'),
     (sig_BA, 'Magnitude(B) + Phase(A)')]):
    ax.plot(t_ph*1000, sig, color=COLORS['primary'], linewidth=1.5)
    ax.set_title(title)
    ax.set_xlabel('Time (ms)'); ax.set_ylabel('Amplitude')
fig.tight_layout(); plt.show()
print('Phase(A) + any magnitude -> character of A preserved.')
print('Magnitude alone does not determine signal structure.')

Code cell 51

# DFT as complete basis: partial reconstruction from K frequency pairs
N_rec = 64
t_rec = np.arange(N_rec)
x_orig = (1.5*np.sin(2*np.pi*2*t_rec/N_rec)
          + 0.8*np.cos(2*np.pi*5*t_rec/N_rec)
          + 0.4*np.sin(2*np.pi*11*t_rec/N_rec))
X_orig = np.fft.fft(x_orig)

fig, axes = plt.subplots(2, 2, figsize=(13, 8))
fig.suptitle('Partial reconstruction: using K of N/2 complex frequencies', fontsize=13)
axes.flatten()[0].plot(t_rec, x_orig, color=COLORS['primary'], linewidth=2)
axes.flatten()[0].set_title('Original signal')
axes.flatten()[0].set_xlabel('Sample $n$'); axes.flatten()[0].set_ylabel('Amplitude')

for ax, K_keep in zip(axes.flatten()[1:], [2, 5, 16]):
    X_partial = np.zeros_like(X_orig)
    X_partial[:K_keep] = X_orig[:K_keep]
    X_partial[-K_keep:] = X_orig[-K_keep:]
    x_rec = np.real(np.fft.ifft(X_partial))
    err = np.mean((x_rec - x_orig)**2)
    ax.plot(t_rec, x_orig, color=COLORS['neutral'], alpha=0.4, linestyle='--',
            linewidth=1.5, label='Original')
    ax.plot(t_rec, x_rec, color=COLORS['error'], linewidth=2,
            label=f'Reconstruction (MSE={err:.4f})')
    ax.set_title(f'K={K_keep} frequency pairs')
    ax.set_xlabel('Sample $n$'); ax.set_ylabel('Amplitude')
    ax.legend(fontsize=9)
fig.tight_layout(); plt.show()

Code cell 52

# Interactive: verify all DFT properties numerically
print('=== DFT Property Verification Suite ===')
N_v = 16
rng_v = np.random.default_rng(42)
x = rng_v.standard_normal(N_v) + 1j*rng_v.standard_normal(N_v)
y = rng_v.standard_normal(N_v) + 1j*rng_v.standard_normal(N_v)
a, b, m = 2.3+1.1j, -0.7+0.4j, 3
X, Y = np.fft.fft(x), np.fft.fft(y)

# 1. Linearity
err1 = np.max(np.abs(np.fft.fft(a*x + b*y) - (a*X + b*Y)))
print(f'1. Linearity:          max error = {err1:.2e}')

# 2. Circular shift
rhs2 = np.exp(-2j*np.pi*m/N_v)**np.arange(N_v) * X
err2 = np.max(np.abs(np.fft.fft(np.roll(x, m)) - rhs2))
print(f'2. Circular shift:     max error = {err2:.2e}')

# 3. Circular convolution via IFFT(X*Y)
xy_conv = np.fft.ifft(X * Y)
xy_direct = np.array([np.sum(x * np.roll(y[::-1], k+1)) for k in range(N_v)])
err3 = np.max(np.abs(xy_conv - xy_direct))
print(f'3. Circular conv:      max error = {err3:.2e}')

# 4. Parseval
err4 = abs(np.sum(np.abs(x)**2) - np.sum(np.abs(X)**2)/N_v)
print(f'4. Parseval:           |diff| = {err4:.2e}')

# 5. Unitarity F F*/N = I
F = np.fft.fft(np.eye(N_v), axis=0)
err5 = np.max(np.abs(F @ F.conj().T / N_v - np.eye(N_v)))
print(f'5. Unitarity F F*/N=I: max error = {err5:.2e}')

# 6. Conjugate symmetry for real input
x_real = rng_v.standard_normal(N_v)
X_real = np.fft.fft(x_real)
err6 = np.max(np.abs(X_real[1:] - np.conj(X_real[-1:0:-1])))
print(f'6. Conj. symmetry:     max error = {err6:.2e}')
print()
print('All properties verified to machine precision.')

18. The DFT as a Group Algebra Isomorphism

The DFT is the Fourier transform on the cyclic group $\mathbb{Z}/N\mathbb{Z}$:

$$\mathcal{F}: \mathbb{C}[\mathbb{Z}/N\mathbb{Z}] \xrightarrow{\sim} \mathbb{C}^N$$

The group algebra $\mathbb{C}[\mathbb{Z}/N\mathbb{Z}]$ has multiplication given by circular convolution. The DFT diagonalizes this multiplication:

$$\mathcal{F}(x \circledast y) = \mathcal{F}(x) \cdot \mathcal{F}(y) \quad (\text{pointwise})$$

This is why the Convolution Theorem (§04) is so fundamental: it follows directly from the DFT being a ring isomorphism.

For non-abelian groups, this generalizes to the non-abelian Fourier transform (representation theory), which underlies geometric deep learning on symmetric spaces.

19. Connections to Other Chapters

| Topic | This Chapter | Connection |
|---|---|---|
| Fourier Series (§01) | DFT is the discrete/finite analogue | Period-$N$ signal $\leftrightarrow$ discrete spectrum |
| Fourier Transform (§02) | DFT approximates the continuous FT | Sampling theorem bridges them |
| Convolution Theorem (§04) | DFT enables FFT-based convolution | $\circledast \to \cdot$ |
| Wavelets (§05) | DWT = iterated filter banks | Multi-resolution via FFT |
| Linear Algebra (§02-03) | $F_N/\sqrt{N}$ is a unitary matrix | DFT = change of basis |
| Compressed Sensing | Sparsity in the Fourier basis | CS recovery for MRI |
| CNNs | Spectral convolution | Circular conv = pointwise mult (§04) |
| SSMs (S4/Mamba) | Long convolutions via FFT | Sequences as spectral filters |

The Cooley-Tukey insight (1965): the DFT's $O(N^2)$ computation collapses to $O(N \log N)$ by exploiting the factorization $F_N = B_1 B_2 \cdots B_{\log_2 N}$. This single algorithm makes Fourier analysis practical for all of signal processing, communications, scientific computing, and modern deep learning.
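As a compact restatement of that factorization, the whole algorithm fits in a few lines (mirroring the recursive implementation built earlier in this chapter):

```python
import numpy as np

def fft_radix2(x):
    """Minimal recursive radix-2 Cooley-Tukey FFT (len(x) must be a power of 2)."""
    N = len(x)
    if N == 1:
        return np.asarray(x, dtype=complex)
    E = fft_radix2(x[0::2])                               # DFT of even-indexed samples
    O = fft_radix2(x[1::2])                               # DFT of odd-indexed samples
    tw = np.exp(-2j * np.pi * np.arange(N // 2) / N)      # twiddle factors
    # Butterfly combine: X[k] = E[k] + w^k O[k],  X[k+N/2] = E[k] - w^k O[k]
    return np.concatenate([E + tw * O, E - tw * O])

x = np.random.default_rng(1).standard_normal(64)
assert np.allclose(fft_radix2(x), np.fft.fft(x))
```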

Code cell 55

# Application: spectral analysis of a synthetic LLM training loss curve
np.random.seed(42)
n_steps = 2000
steps = np.arange(n_steps)
trend     = 2.0 * np.exp(-steps / 800)
epoch_osc = 0.05 * np.sin(2*np.pi*steps/100)
lr_osc    = 0.03 * np.sin(2*np.pi*steps/500)
noise_c   = 0.02 * np.random.randn(n_steps)
loss = trend + epoch_osc + lr_osc + noise_c

loss_detrend = loss - trend
L = np.fft.rfft(loss_detrend)
freqs_loss = np.fft.rfftfreq(n_steps)

fig, axes = plt.subplots(2, 1, figsize=(13, 8))
fig.suptitle('Spectral analysis of LLM training loss', fontsize=14)

axes[0].plot(steps, loss, color=COLORS['primary'], linewidth=1, alpha=0.8)
axes[0].plot(steps, trend, color=COLORS['error'], linewidth=2, linestyle='--',
             label='Trend')
axes[0].set_title('Training loss curve')
axes[0].set_xlabel('Training step'); axes[0].set_ylabel('Loss')
axes[0].legend()

period_axis = 1.0 / (freqs_loss[1:] + 1e-12)
axes[1].plot(period_axis, np.abs(L[1:]),
             color=COLORS['primary'], linewidth=1.5)
for p, label in [(100, 'epoch\n(100 steps)'), (500, 'LR schedule\n(500 steps)')]:
    axes[1].axvline(p, color=COLORS['error'], linestyle='--', alpha=0.8)
    axes[1].text(p+8, np.abs(L[1:]).max()*0.85, label, fontsize=9,
                color=COLORS['error'])
axes[1].set_title('DFT magnitude of detrended loss')
axes[1].set_xlabel('Period (steps)'); axes[1].set_ylabel('|L[k]|')
axes[1].set_xlim(10, 1000)
fig.tight_layout(); plt.show()

Code cell 56

# Final summary: key quantitative results
print('=== Chapter 03 Key Results ===')
print()
print('DFT definition:        X[k] = sum_{n} x[n] exp(-2pi*i*nk/N)')
print('IDFT:                  x[n] = (1/N) sum_{k} X[k] exp(2pi*i*nk/N)')
print()
print('FFT complexity:        O(N log2 N)  vs  O(N^2) naive DFT')
print('Frequency resolution:  Δf = fs / N  (set by observation time T = N/fs)')
print('Zero-padding:          Interpolates spectrum, does NOT improve resolution')
print('Nyquist criterion:     fs >= 2 f_max  (else aliasing via Poisson summation)')
print()
print('Window sidelobe attenuation:')
for name, sl in [('Rectangular', -13), ('Hann', -31.5),
                  ('Hamming', -42.7), ('Blackman', -58), ('Kaiser beta=8.6', -69)]:
    print(f'  {name:<22} {sl:>6} dB')
print()
print('Whisper pipeline:  16kHz, 400-pt Hann, hop 160, FFT 512, 80 mel bins')
print('FNO spectral conv: keep K lowest modes (K << N//2+1), learn R_theta in C^K')
print('Monarch matrices:  O(N sqrt(N)) params vs O(N^2) dense')
print('NTT:               DFT over Z_p, O(N log N), exact -- powers FHE')

Chapter Summary

This notebook developed the DFT and FFT from first principles to modern ML applications.

What we proved:

  • The DFT satisfies $F_N F_N^* = N I$, so $F_N/\sqrt{N}$ is unitary
  • The Cooley-Tukey FFT reduces $O(N^2)$ to $O(N \log N)$ via butterfly factorization
  • Frequency resolution is $\Delta f = f_s/N$, set by observation time -- not by zero-padding
  • The COLA condition enables perfect signal reconstruction from STFT frames

What we built:

  • Recursive Cooley-Tukey FFT matching scipy.fft to machine precision
  • Whisper mel spectrogram pipeline (16 kHz -> 80 mel bins)
  • FNO spectral convolution layer ($K$-mode truncation)
  • NTT for exact polynomial multiplication over $\mathbb{Z}_p$

What's next (§04 Convolution Theorem):

$$x \circledast y \;\overset{\mathcal{F}}{\longleftrightarrow}\; X[k] \cdot Y[k]$$

This duality reduces $O(N^2)$ convolution to $O(N \log N)$ and underpins CNNs, WaveNet, S4/Mamba, and Hyena long convolutions.