Theory Notebook — Math for LLMs
Fourier Analysis and Signal Processing / Wavelets
Converted from theory.ipynb for web reading.

§20-05 Wavelets and Multiresolution Analysis

"Wavelets are a mathematical microscope: by changing the magnification and the position of the lens, one can examine local features at any desired scale." — Stéphane Mallat

Interactive theory notebook covering: CWT scalograms, MRA axioms, Mallat fast DWT, Daubechies wavelet construction, 2D image DWT, scattering networks, and wavelet denoising.

Code cell 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

try:
    import seaborn as sns
    sns.set_theme(style="whitegrid", palette="colorblind")
    HAS_SNS = True
except ImportError:
    plt.style.use("seaborn-v0_8-whitegrid")
    HAS_SNS = False

mpl.rcParams.update({
    "figure.figsize":    (10, 6),
    "figure.dpi":         120,
    "font.size":           13,
    "axes.titlesize":      15,
    "axes.labelsize":      13,
    "xtick.labelsize":     11,
    "ytick.labelsize":     11,
    "legend.fontsize":     11,
    "legend.framealpha":   0.85,
    "lines.linewidth":      2.0,
    "axes.spines.top":     False,
    "axes.spines.right":   False,
    "savefig.bbox":       "tight",
    "savefig.dpi":         150,
})
np.random.seed(42)
print("Plot setup complete.")

Code cell 3

import numpy as np
import scipy.signal as sig
from scipy.signal import chirp

try:
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    try:
        import seaborn as sns
        sns.set_theme(style='whitegrid', palette='colorblind')
        HAS_SNS = True
    except ImportError:
        plt.style.use('seaborn-v0_8-whitegrid')
        HAS_SNS = False
    mpl.rcParams.update({
        'figure.figsize': (10, 6),
        'figure.dpi': 100,
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'lines.linewidth': 2.0,
        'axes.spines.top': False,
        'axes.spines.right': False,
    })
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

try:
    import pywt
    HAS_PYWT = True
    print(f'PyWavelets {pywt.__version__} available')
except ImportError:
    HAS_PYWT = False
    print('PyWavelets not available — some cells will use manual implementations')

np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)

COLORS = {
    'primary':   '#0077BB',
    'secondary': '#EE7733',
    'tertiary':  '#009988',
    'error':     '#CC3311',
    'neutral':   '#555555',
    'highlight': '#EE3377',
}
print('Setup complete.')

1. Intuition: The Time-Frequency Problem

The Fourier transform gives perfect frequency resolution but zero time resolution. Wavelets resolve this by using basis functions that are both localized in time and in frequency.

The STFT uses a fixed-width window — same time resolution for all frequencies. Wavelets use a scale-adaptive window — narrow at high frequencies (good time resolution), wide at low frequencies (good frequency resolution).
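The scale-adaptive trade-off can be checked numerically: measuring the RMS time width Δt and RMS frequency width Δf of a modulated Gaussian at several scales shows Δt growing and Δf shrinking proportionally, with their product roughly constant. A numpy-only sketch (the scale values and the 5/a carrier frequency are illustrative choices, not from the notebook):

```python
import numpy as np

def tf_widths(psi, t):
    """RMS time width and RMS frequency width of a window, from |psi|^2."""
    p = np.abs(psi)**2
    p = p / p.sum()
    t0 = np.sum(t * p)
    dt = np.sqrt(np.sum((t - t0)**2 * p))
    P = np.abs(np.fft.fft(psi))**2
    P = P / P.sum()
    f = np.fft.fftfreq(len(t), d=t[1] - t[0])
    f0 = np.sum(f * P)
    df = np.sqrt(np.sum((f - f0)**2 * P))
    return dt, df

t = np.linspace(-8, 8, 4096)
dts, dfs = [], []
for a in [0.5, 1.0, 2.0, 4.0]:
    # Modulated Gaussian at scale a: carrier 5/a Hz, envelope width proportional to a
    psi = np.cos(2*np.pi*5*t/a) * np.exp(-(t/a)**2 / 2)
    dt, df = tf_widths(psi, t)
    dts.append(dt); dfs.append(df)
    print(f'a={a:4.1f}  dt={dt:6.3f}  df={df:6.3f}  dt*df={dt*df:.3f}')
```

Doubling the scale doubles the time width and halves the frequency width, so the tile area stays fixed while its shape adapts — the constant-Q behavior the CWT exploits.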

Code cell 5

# === 1.1 Fourier vs STFT vs Wavelet Time-Frequency Tiling ===

# Generate a non-stationary signal: chirp + transient
fs = 1000      # sampling rate (Hz)
t = np.linspace(0, 1, fs, endpoint=False)
N = len(t)

# Component 1: chirp sweeping 50→300 Hz over [0, 0.7s]
f_chirp = chirp(t, f0=50, f1=300, t1=0.7, method='linear')
# Component 2: short burst at 400 Hz in [0.5, 0.55s]
burst = np.zeros(N)
burst_idx = (t >= 0.5) & (t < 0.55)
burst[burst_idx] = np.sin(2*np.pi*400*t[burst_idx])

signal = f_chirp + burst

# Method 1: Fourier magnitude (no time info)
X = np.fft.rfft(signal)
freqs = np.fft.rfftfreq(N, 1/fs)

# Method 2: STFT (fixed 50ms window)
f_stft, t_stft, Zxx = sig.stft(signal, fs, nperseg=50, noverlap=40)

if HAS_MPL:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(t, signal, color=COLORS['primary'])
    axes[0].set_title('Signal (chirp + burst)')
    axes[0].set_xlabel('Time (s)')
    axes[0].set_ylabel('Amplitude')

    axes[1].plot(freqs[:300], np.abs(X[:300]), color=COLORS['secondary'])
    axes[1].set_title('Fourier Magnitude (no time info)')
    axes[1].set_xlabel('Frequency (Hz)')
    axes[1].set_ylabel('|X(f)|')

    axes[2].pcolormesh(t_stft, f_stft[:40], np.abs(Zxx[:40]),
                       shading='gouraud', cmap='viridis')
    axes[2].set_title('STFT Spectrogram (fixed window)')
    axes[2].set_xlabel('Time (s)')
    axes[2].set_ylabel('Frequency (Hz)')

    plt.tight_layout()
    plt.savefig('/tmp/wav_tfr.png', dpi=100, bbox_inches='tight')
    plt.show()

print(f'Signal: {N} samples, chirp 50-300 Hz + 400 Hz burst at t=0.5s')
print('Fourier: sees both components, cannot locate burst in time')
print('STFT: fixed window — same resolution for all frequencies')

2. Wavelet Families

Different wavelets suit different applications. The key properties are:

  • Vanishing moments $N$: number of polynomial moments annihilated ($\int t^k \psi\,dt = 0$ for $k < N$)
  • Support length: $2N-1$ for db$N$
  • Regularity: Hölder exponent $\alpha \approx 0.2075\,N$
  • Symmetry: db$N$ is asymmetric; sym$N$ is near-symmetric

Code cell 7

# === 2.1 Wavelet Gallery ===

if HAS_PYWT:
    wavelet_names = ['haar', 'db2', 'db4', 'db8', 'sym4', 'coif4', 'mexh', 'morl']
    fig, axes = plt.subplots(2, 4, figsize=(14, 6))

    for i, wname in enumerate(wavelet_names):
        ax = axes[i//4, i%4]
        try:
            w = pywt.Wavelet(wname)
            # Get wavelet function values
            phi, psi, x = w.wavefun(level=8)
            ax.plot(x, psi, color=COLORS['primary'], lw=1.5)
            ax.axhline(0, color='gray', lw=0.5)
            ax.set_title(wname, fontsize=11)
        except Exception:
            # Continuous wavelets (mexh, morl): ContinuousWavelet.wavefun returns (psi, x)
            psi_c, x_c = pywt.ContinuousWavelet(wname).wavefun()
            ax.plot(x_c, np.real(psi_c), color=COLORS['primary'], lw=1.5)
            ax.axhline(0, color='gray', lw=0.5)
            ax.set_title(wname, fontsize=11)
        ax.set_xlabel('t')
        ax.set_yticks([])

    plt.suptitle('Wavelet Gallery: ψ(t) for Common Families', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig('/tmp/wav_gallery.png', dpi=100, bbox_inches='tight')
    plt.show()

    # Print properties
    print('Daubechies wavelet properties:')
    print(f"{'Name':<8} {'VanMom':>8} {'Support':>9} {'FilterLen':>10}")
    for n in [1, 2, 4, 6, 8]:
        w = pywt.Wavelet(f'db{n}')
        print(f'db{n:<6} {w.vanishing_moments_psi:>8} '
              f'[0,{2*n-1}]{"":>3} {len(w.dec_lo):>10}')
else:
    print('PyWavelets not available. Install with: pip install PyWavelets')

2.2 Continuous Wavelet Transform: Scalogram

The scalogram $|W_\psi f(a, b)|^2$ shows signal energy as a function of time $b$ and scale $a$. Unlike the STFT, the time-frequency tiles are logarithmically spaced — fine at high frequencies, coarse at low frequencies (constant-Q analysis).

Code cell 9

# === 2.2 CWT Scalogram of Chirp + Burst ===

if HAS_PYWT:
    # CWT with Morlet wavelet
    scales_cwt = np.logspace(0.3, 2.5, 80)  # log-spaced scales
    coefs_cwt, freqs_cwt = pywt.cwt(signal, scales_cwt, 'morl', sampling_period=1/fs)

    if HAS_MPL:
        fig, axes = plt.subplots(2, 1, figsize=(12, 8))

        axes[0].plot(t, signal, color=COLORS['primary'], lw=1)
        axes[0].set_title('Signal')
        axes[0].set_xlabel('Time (s)')

        im = axes[1].pcolormesh(
            t, freqs_cwt, np.abs(coefs_cwt)**2,
            shading='gouraud', cmap='plasma'
        )
        axes[1].set_yscale('log')
        axes[1].set_ylim([20, 500])
        axes[1].set_title('CWT Scalogram (Morlet) — logarithmic time-frequency tiling')
        axes[1].set_xlabel('Time (s)')
        axes[1].set_ylabel('Frequency (Hz)')
        plt.colorbar(im, ax=axes[1], label='Power')

        plt.tight_layout()
        plt.savefig('/tmp/wav_scalogram.png', dpi=100, bbox_inches='tight')
        plt.show()

    # Verify: chirp ridge at linear freq, burst localized at t=0.5
    peak_time_idx = np.argmax(np.abs(coefs_cwt[freqs_cwt > 350, :][:5]).mean(axis=0))
    peak_time = t[peak_time_idx]
    print(f'Peak energy near 400 Hz at t = {peak_time:.3f}s (burst is at 0.50-0.55s)')
    ok = 0.48 <= peak_time <= 0.57
    print(f"{'PASS' if ok else 'FAIL'} — CWT correctly localizes burst")
else:
    print('PyWavelets not available.')

3. Multiresolution Analysis (MRA)

MRA provides the algebraic framework for wavelet filter banks. The key idea: represent a signal at successively coarser scales, with each scale's detail captured by the detail space $W_j$:

$$L^2(\mathbb{R}) = \bigoplus_{j \in \mathbb{Z}} W_j$$

The Mallat algorithm computes this decomposition in $O(N)$ via iterated low-pass / high-pass filtering.
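The one-level splitting $V_0 = V_1 \oplus W_1$ behind this direct sum can be checked directly with Haar averages and differences: the two reconstructed components are exactly orthogonal and sum back to the signal. A minimal numpy-only sketch (the length-16 random signal is an illustrative choice):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(16)

# One Haar analysis step: a lives in V_1 (pairwise averages), d in W_1 (differences)
a = (x[0::2] + x[1::2]) / np.sqrt(2)
d = (x[0::2] - x[1::2]) / np.sqrt(2)

# Synthesize each part back into V_0 separately
xa = np.zeros_like(x)
xd = np.zeros_like(x)
xa[0::2] = a / np.sqrt(2); xa[1::2] = a / np.sqrt(2)    # projection onto V_1
xd[0::2] = d / np.sqrt(2); xd[1::2] = -d / np.sqrt(2)   # projection onto W_1

print('components orthogonal:', abs(np.dot(xa, xd)) < 1e-12)
print('xa + xd reconstructs x:', np.allclose(xa + xd, x))
```

The same argument iterates: splitting $V_1$ again peels off $W_2$, and so on, which is exactly the cascade the Mallat algorithm implements.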

Code cell 11

# === 3.1 MRA: Nested Approximation Spaces ===
# Visualize how V_j spaces nest and approximate a signal

N_mra = 128
t_mra = np.linspace(0, 1, N_mra, endpoint=False)

# Test signal: mixture of smooth + piecewise
f_mra = np.sin(2*np.pi*3*t_mra) + 0.5*(t_mra > 0.5).astype(float)

if HAS_PYWT:
    fig, axes = plt.subplots(4, 1, figsize=(12, 10))

    axes[0].plot(t_mra, f_mra, color=COLORS['primary'], lw=2)
    axes[0].set_title('Original Signal f')

    for j, level in enumerate([1, 2, 4]):
        coeffs_mra = pywt.wavedec(f_mra, 'db2', level=level)
        # Reconstruct approximation only (zero detail coefficients)
        zeros = [np.zeros_like(c) for c in coeffs_mra[1:]]
        approx = pywt.waverec([coeffs_mra[0]] + zeros, 'db2')[:N_mra]

        axes[j+1].plot(t_mra, f_mra, color=COLORS['neutral'], alpha=0.3, lw=1, label='Original')
        axes[j+1].plot(t_mra, approx, color=COLORS['secondary'], lw=2,
                       label=f'V_{level} approximation')
        axes[j+1].legend(loc='upper right')
        axes[j+1].set_title(f'Projection onto V_{level} ({N_mra//2**level} coefficients)')

    for ax in axes:
        ax.set_xlabel('t')

    plt.suptitle('MRA: Nested Approximation Spaces V_j', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig('/tmp/wav_mra.png', dpi=100, bbox_inches='tight')
    plt.show()
print('As level increases (coarser V_j), approximation loses fine detail.')

Code cell 12

# === 3.2 Scaling Function Cascade Algorithm ===
# Build phi(t) by iterating phi(t) = sqrt(2) * sum_k h_k phi(2t - k)

def cascade_algorithm(h, n_iter=8):
    """Compute scaling function phi by iterating the refinement equation."""
    # Start with box function on [0, 1]
    p = len(h) - 1  # support of phi in [0, p]
    # phi at iteration 0: box function
    phi = np.zeros(2**n_iter * p + 1)
    phi[0:2**n_iter] = 1.0

    for _ in range(n_iter):
        # Apply refinement: phi_new(t) = sqrt(2) * sum_k h_k phi(2t - k)
        N_fine = len(phi) * 2 - 1
        phi_up = np.zeros(N_fine)
        phi_up[::2] = phi  # upsample
        phi_fine = np.sqrt(2) * np.convolve(phi_up, h[::-1], mode='full')
        phi = phi_fine[:N_fine]

    t_phi = np.linspace(0, p, len(phi))
    return t_phi, phi / (phi.sum() * (t_phi[1] - t_phi[0]))  # normalize

if HAS_PYWT:
    wavelet_names_cascade = ['haar', 'db2', 'db4', 'db6']
    fig, axes = plt.subplots(2, 4, figsize=(14, 6))

    for i, wname in enumerate(wavelet_names_cascade):
        w = pywt.Wavelet(wname)
        # wavefun runs the same cascade iteration as cascade_algorithm above
        phi_vals, psi_vals, x_w = w.wavefun(level=10)

        axes[0, i].plot(x_w, phi_vals, color=COLORS['primary'], lw=2)
        axes[0, i].set_title(f'{wname} scaling φ')
        axes[0, i].axhline(0, color='gray', lw=0.5)

        axes[1, i].plot(x_w, psi_vals, color=COLORS['secondary'], lw=2)
        axes[1, i].set_title(f'{wname} wavelet ψ')
        axes[1, i].axhline(0, color='gray', lw=0.5)

    plt.suptitle('Scaling Functions φ and Wavelets ψ (Daubechies family)', fontsize=13)
    plt.tight_layout()
    plt.savefig('/tmp/wav_cascade.png', dpi=100, bbox_inches='tight')
    plt.show()
    print('Note how db1=Haar is discontinuous, while db4,db6 become increasingly smooth.')


4. The Mallat Algorithm (Fast DWT)

The Mallat algorithm computes the DWT via iterated convolution + downsampling:

$$a^{(j+1)}_k = \sum_n h_{n-2k}\, a^{(j)}_n \qquad \text{(low-pass + downsample)}$$
$$d^{(j+1)}_k = \sum_n g_{n-2k}\, a^{(j)}_n \qquad \text{(high-pass + downsample)}$$

This costs $O(N)$ total — asymptotically faster than the FFT's $O(N\log N)$.
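The $O(N)$ bound follows from a geometric series: each level processes a signal half as long, so the total number of outputs is $N(1 + 1/2 + 1/4 + \cdots) < 2N$, each costing one filter application. A quick operation count (the 8-tap filter length is an illustrative choice, matching db4):

```python
# Each DWT level halves the signal length, so total cost is a geometric series:
# N*(1 + 1/2 + 1/4 + ...) < 2N outputs, each costing L multiply-adds per filter.
N = 1024   # signal length
L = 8      # filter tap count (db4 has 8 taps) -- illustrative
ops = 0
n = N
while n > 1:
    ops += 2 * L * (n // 2)  # low-pass + high-pass branch, n//2 outputs each
    n //= 2
print(f'Total multiply-adds: {ops}  (bound 2*L*N = {2*L*N})')
print(f'Per-sample cost: {ops / N:.1f} -- independent of N, i.e. O(N)')
```

The per-sample cost depends only on the filter length, not on $N$, which is the sense in which the cascade beats the FFT's $\log N$ factor.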

Code cell 14

# === 4.1 Haar DWT from Scratch ===

def haar_step(x):
    """One step of Haar DWT: return (approximation, detail)."""
    n = len(x) // 2
    a = (x[0::2] + x[1::2]) / np.sqrt(2)
    d = (x[0::2] - x[1::2]) / np.sqrt(2)
    return a, d

def haar_dwt(x, J):
    """Multi-level Haar DWT. Returns [aJ, dJ, d(J-1), ..., d1]."""
    coeffs = []
    a = x.copy()
    for j in range(J):
        a, d = haar_step(a)
        coeffs.append(d)
    coeffs.append(a)
    return list(reversed(coeffs))  # [aJ, dJ, ..., d1]

def haar_istep(a, d):
    """One step of Haar IDWT."""
    n = len(a)
    x = np.zeros(2*n)
    x[0::2] = (a + d) / np.sqrt(2)
    x[1::2] = (a - d) / np.sqrt(2)
    return x

def haar_idwt(coeffs, J):
    """Multi-level Haar IDWT."""
    a = coeffs[0]
    for j in range(J):
        d = coeffs[j+1]
        a = haar_istep(a, d)
    return a

# Test signal
np.random.seed(42)
N_test = 64
x_test = np.sin(2*np.pi*np.arange(N_test)/16) + 0.3*np.random.randn(N_test)
J_levels = 4

coeffs_haar = haar_dwt(x_test, J_levels)
x_rec = haar_idwt(coeffs_haar, J_levels)

# Perfect reconstruction check
err = np.max(np.abs(x_rec - x_test))
print(f'Perfect reconstruction error: {err:.2e}')
print(f"PASS: {'yes' if err < 1e-10 else 'no'}")

# Parseval check
energy_orig = np.sum(x_test**2)
energy_wav = sum(np.sum(c**2) for c in coeffs_haar)
print(f'\nParseval: ||x||² = {energy_orig:.6f}')
print(f'           sum coeffs² = {energy_wav:.6f}')
print(f"PASS: {abs(energy_orig - energy_wav) < 1e-10}")

# Print coefficient sizes
print(f'\nCoefficient sizes: {[len(c) for c in coeffs_haar]}')
print(f'Total: {sum(len(c) for c in coeffs_haar)} (= N = {N_test})')

Code cell 15

# === 4.2 Mallat Algorithm Visualization ===

if HAS_MPL:
    N_vis = 256
    t_vis = np.linspace(0, 1, N_vis)
    # Signal: smooth + jump + high-freq burst
    f_vis = (np.sin(2*np.pi*2*t_vis)
             + 0.5*(t_vis > 0.6)
             + 0.3*np.sin(2*np.pi*50*t_vis)*(t_vis > 0.3)*(t_vis < 0.4))

    J_vis = 4
    coeffs_vis = haar_dwt(f_vis, J_vis)

    fig, axes = plt.subplots(J_vis+2, 1, figsize=(12, 12))
    axes[0].plot(t_vis, f_vis, color=COLORS['primary'], lw=2)
    axes[0].set_title('Original Signal')
    axes[0].set_ylabel('f(t)')

    axes[1].plot(coeffs_vis[0], color=COLORS['tertiary'], lw=2)
    axes[1].set_title(f'Approximation a{J_vis} (scale {2**J_vis}x coarser)')
    axes[1].set_ylabel(f'a{J_vis}')

    for j in range(1, J_vis+1):
        t_d = np.linspace(0, 1, len(coeffs_vis[j]))
        axes[j+1].stem(t_d, coeffs_vis[j], linefmt='C1-',
                       markerfmt='o', basefmt='gray')
        axes[j+1].set_title(f'Detail d{J_vis-j+1} (frequency band octave {J_vis-j+1})')
        axes[j+1].set_ylabel(f'd{J_vis-j+1}')

    axes[-1].set_xlabel('Position')
    plt.suptitle('Mallat DWT Decomposition Tree (Haar, 4 levels)', fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig('/tmp/wav_mallat.png', dpi=100, bbox_inches='tight')
    plt.show()

print('Note: detail at level 1 (d1) captures high-frequency burst at t=0.3-0.4')
print('       detail at level 4 (d4) captures the low-frequency content')
print('       approximation shows the very smooth global structure')

4.3 QMF Relation and Perfect Reconstruction

The wavelet filter $g$ is derived from the scaling filter $h$ via:

$$g_k = (-1)^k h_{1-k}$$

This ensures the analysis filter bank is power-complementary: $H$ and $G$ together cover the entire frequency axis without gaps or overlap.

Perfect Reconstruction (PR): alias cancellation + the distortion-free condition together guarantee exact reconstruction, $\text{IDWT}(\text{DWT}(x)) = x$.

Code cell 17

# === 4.3 QMF Verification: db4 ===

if HAS_PYWT:
    w_db4 = pywt.Wavelet('db4')
    h = np.array(w_db4.dec_lo)  # analysis low-pass
    g = np.array(w_db4.dec_hi)  # analysis high-pass
    K = len(h)

    # Verify QMF (pywt sign convention): g[k] = (-1)^(k+1) * h[K-1-k]
    # (equivalent to the textbook g_k = (-1)^k h_{1-k} up to an overall sign and index shift)
    k_idx = np.arange(K)
    g_qmf = (-1)**(k_idx + 1) * h[K-1-k_idx]
    err_qmf = np.max(np.abs(g - g_qmf))
    print(f'QMF relation g_k = (-1)^(k+1) h_{{K-1-k}}:')
    print(f'  Max error: {err_qmf:.2e}')
    print(f'  PASS: {err_qmf < 1e-10}')

    # Power complementary: |H(xi)|^2 + |H(xi+0.5)|^2 = 2
    # (the sum is 2, not 1, because the filters are normalized so that sum(h) = sqrt(2))
    N_grid = 512
    xi = np.linspace(0, 0.5, N_grid)
    H_full = np.fft.fft(h, n=N_grid*2)
    H_half = H_full[:N_grid]
    H_shift = H_full[N_grid:]  # H(xi + 0.5)
    power_sum = np.abs(H_half)**2 + np.abs(H_shift)**2
    err_pc = np.max(np.abs(power_sum - 2.0))
    print(f'\nPower complementary |H(xi)|^2 + |H(xi+0.5)|^2 = 2:')
    print(f'  Max error: {err_pc:.2e}')
    print(f'  PASS: {err_pc < 1e-10}')

    # Perfect reconstruction test
    x_pr = np.random.randn(256)
    coeffs_pr = pywt.wavedec(x_pr, 'db4', level=4)
    x_pr_rec = pywt.waverec(coeffs_pr, 'db4')[:256]
    err_pr = np.max(np.abs(x_pr - x_pr_rec))
    print(f'\nPerfect Reconstruction (db4, 4 levels):')
    print(f'  Max error: {err_pr:.2e}')
    print(f'  PASS: {err_pr < 1e-10}')

    if HAS_MPL:
        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        xi_plot = np.linspace(0, 0.5, N_grid)
        ax.plot(xi_plot, np.abs(H_half)**2,
                color=COLORS['primary'], lw=2, label='|H(ξ)|² (low-pass)')
        ax.plot(xi_plot, np.abs(H_shift)**2,
                color=COLORS['secondary'], lw=2, label='|H(ξ+0.5)|² (alias)')
        ax.plot(xi_plot, power_sum,
                color=COLORS['tertiary'], lw=1.5, ls='--', label='Sum (should = 2)')
        ax.set_xlabel('Normalized frequency ξ')
        ax.set_ylabel('|H(ξ)|²')
        ax.set_title('Power Complementary Condition (db4)')
        ax.legend()
        plt.tight_layout()
        plt.savefig('/tmp/wav_qmf.png', dpi=100, bbox_inches='tight')
        plt.show()


5. Daubechies Wavelet Construction

5.1 Vanishing Moments

A wavelet $\psi$ has $N$ vanishing moments if $\int t^k\psi(t)\,dt = 0$ for $k=0,\ldots,N-1$.

Implication: if $f$ is well-approximated by a degree-$(N-1)$ polynomial over the support of $\psi_{j,k}$, then $d_{j,k} = \langle f, \psi_{j,k}\rangle \approx 0$. Smooth signals have sparse wavelet representations — ideal for compression.

Code cell 19

# === 5.1 Vanishing Moments: Polynomial Annihilation ===

import numpy as np

COLORS = {
    'primary':   '#0077BB',
    'secondary': '#EE7733',
    'tertiary':  '#009988',
    'error':     '#CC3311',
    'neutral':   '#555555',
    'highlight': '#EE3377',
}

try:
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

try:
    import pywt
    HAS_PYWT = True
except ImportError:
    HAS_PYWT = False

np.random.seed(42)

if HAS_PYWT:
    # Test vanishing moments for db1, db2, db4
    print('Vanishing moment verification:')
    print(f"{'Wavelet':<10} {'VM':<5} m=0      m=1      m=2      m=3      m=4")

    for wname in ['db1', 'db2', 'db3', 'db4']:
        w = pywt.Wavelet(wname)
        g = np.array(w.dec_hi)  # high-pass = wavelet filter
        K = len(g)
        k = np.arange(K)

        # Compute sum_k k^m * g[k] for m = 0..4
        moments = [np.sum((k**m) * g) for m in range(5)]
        moments_str = '  '.join(f'{m:7.4f}' for m in moments)
        print(f"{wname:<10} {w.vanishing_moments_psi:<5} {moments_str}")

    print()
    print('Zero entries = vanishing moments confirmed.')
    print('db4 has 4 vanishing moments: first 4 moments are ~0.')

    # Sparse representation test
    N_sp = 256
    t_sp = np.linspace(0, 1, N_sp)

    # Signal 1: smooth polynomial (should be very sparse in db4)
    f_poly = 0.5*t_sp**3 - 1.2*t_sp**2 + 0.8*t_sp + 0.3
    # Signal 2: non-smooth (piecewise constant + jump)
    f_jump = np.sign(np.sin(2*np.pi*3*t_sp)).astype(float)

    for fname, fsig in [('Smooth polynomial', f_poly), ('Piecewise constant', f_jump)]:
        coeffs = pywt.wavedec(fsig, 'db4', level=5)
        all_detail = np.concatenate(coeffs[1:])
        pct_small = np.mean(np.abs(all_detail) < 0.01) * 100
        print(f"{fname}: {pct_small:.1f}% of detail coefficients < 0.01")


Code cell 20

# === 5.2 db2 Filter Construction via Spectral Factorization ===

# Exact db2 filter from Daubechies spectral factorization
sqrt3 = np.sqrt(3)
h_db2 = np.array([
    (1 + sqrt3) / (4*np.sqrt(2)),
    (3 + sqrt3) / (4*np.sqrt(2)),
    (3 - sqrt3) / (4*np.sqrt(2)),
    (1 - sqrt3) / (4*np.sqrt(2)),
])

print('db2 filter coefficients:')
for i, h in enumerate(h_db2):
    print(f'  h[{i}] = {h:.10f}')

# Verification
print('\nVerification:')
print(f'  sum(h) = {h_db2.sum():.10f} (should be sqrt(2) = {np.sqrt(2):.10f})')
print(f'  sum(h²) = {np.sum(h_db2**2):.10f} (should be 1.0)')
print(f'  sum(h[k]*h[k-2]) = {np.sum(h_db2[:-2]*h_db2[2:]):.2e} (should be 0)')

# Compare with pywt (pywt stores the analysis filter time-reversed in dec_lo)
if HAS_PYWT:
    w_db2 = pywt.Wavelet('db2')
    h_pywt = np.array(w_db2.dec_lo)[::-1]  # reverse to recover (h0, h1, h2, h3) order
    err = np.max(np.abs(h_db2 - h_pywt))
    print(f'\nMatch with pywt db2: {err:.2e}')
    print(f'PASS: {err < 1e-10}')

# Check vanishing moments for db2
g_db2 = np.array([(-1)**k * h_db2[3-k] for k in range(4)])
k_idx = np.arange(4)
vm0 = np.sum(g_db2)  # should be 0
vm1 = np.sum(k_idx * g_db2)  # should be 0
vm2 = np.sum(k_idx**2 * g_db2)  # should be nonzero
print(f'\nVanishing moments check (high-pass g):')
print(f'  m=0: {vm0:.2e} (should be ~0)')
print(f'  m=1: {vm1:.2e} (should be ~0)')
print(f'  m=2: {vm2:.4f} (should be nonzero)')
print(f'PASS: 2 vanishing moments confirmed')

6. DWT in Practice

6.1 Multi-Level Wavelet Decomposition

The standard DWT tree applies the filter bank only to the approximation branch. After $J$ levels: $\{a_J, d_J, d_{J-1}, \ldots, d_1\}$ with total length $N$ — no redundancy.

Code cell 22

# === 6.1 Multi-Level DWT Decomposition and Reconstruction ===

if HAS_PYWT:
    np.random.seed(42)
    N_ml = 512
    t_ml = np.linspace(0, 1, N_ml)

    # ECG-like signal: baseline + QRS + T-wave
    ecg = (0.3*np.sin(2*np.pi*1.5*t_ml)     # slow drift
           + 2.0*np.exp(-0.5*((t_ml-0.3)/0.02)**2)   # QRS peak
           + 0.5*np.exp(-0.5*((t_ml-0.55)/0.06)**2)  # T-wave
           + 0.1*np.random.randn(N_ml))      # noise

    J_ml = 6
    coeffs_ml = pywt.wavedec(ecg, 'db4', level=J_ml)

    print('DWT coefficient structure:')
    total = 0
    for i, c in enumerate(coeffs_ml):
        label = f'a{J_ml}' if i == 0 else f'd{J_ml-i+1}'
        print(f'  {label}: {len(c)} coefficients')
        total += len(c)
    print(f'  Total: {total} = N = {N_ml}')

    if HAS_MPL:
        fig, axes = plt.subplots(J_ml+2, 1, figsize=(12, 14))
        axes[0].plot(t_ml, ecg, color=COLORS['primary'], lw=1.5)
        axes[0].set_title('ECG-like Signal')

        axes[1].plot(coeffs_ml[0], color=COLORS['tertiary'], lw=2)
        axes[1].set_title(f'Approx. a{J_ml} (global slow trend)')

        # The 1 s signal has fs = N_ml Hz; detail level l covers roughly [fs/2^(l+1), fs/2^l]
        for j in range(1, J_ml+1):
            level = J_ml - j + 1
            t_d = np.linspace(0, 1, len(coeffs_ml[j]))
            axes[j+1].plot(t_d, coeffs_ml[j], color=COLORS['secondary'], lw=1)
            axes[j+1].set_title(f'd{level} ({len(coeffs_ml[j])} coeff, '
                                f'~{N_ml/2**(level+1):.0f}-{N_ml/2**level:.0f} Hz band)')

        plt.suptitle('Multi-Level DWT (db4, 6 levels) — ECG Signal', fontsize=13, fontweight='bold')
        plt.tight_layout()
        plt.savefig('/tmp/wav_multilevel.png', dpi=100, bbox_inches='tight')
        plt.show()


6.3 2D DWT: Image Subband Decomposition

The 2D DWT applies 1D DWT separably (rows then columns), producing 4 subbands per level:

  • LL — approximation (coarse image)
  • LH — horizontal detail (vertical edges)
  • HL — vertical detail (horizontal edges)
  • HH — diagonal detail (diagonal edges)
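Separability can be verified by hand: one Haar analysis step along the rows, then along the columns of each row-result, yields four subbands whose energies sum to the image energy, and whose approximation entries are scaled 2×2 block sums. A numpy-only sketch on a random 8×8 array (subband naming conventions vary across references; the labels here just distinguish the four outputs):

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((8, 8))

def haar_rows(x):
    """One Haar analysis step applied along each row: (approx, detail)."""
    a = (x[:, 0::2] + x[:, 1::2]) / np.sqrt(2)
    d = (x[:, 0::2] - x[:, 1::2]) / np.sqrt(2)
    return a, d

# Separable 2D step: filter rows first, then columns of each row-result
L1, H1 = haar_rows(A)
aL, dL = haar_rows(L1.T)
aH, dH = haar_rows(H1.T)
LL, LH = aL.T, dL.T     # approximation + one detail subband
HL, HH = aH.T, dH.T     # remaining two detail subbands

# Orthonormal filters preserve energy across the four subbands
e_in = np.sum(A**2)
e_out = sum(np.sum(b**2) for b in (LL, LH, HL, HH))
print('energy preserved:', np.isclose(e_in, e_out))
# Each LL entry is the corresponding 2x2 block sum divided by 2
print('LL[0,0] = 2x2 block sum / 2:', np.isclose(LL[0, 0], A[:2, :2].sum() / 2))
```

Iterating the same two-axis step on LL produces the multi-level pyramid that the PyWavelets cell below computes with `wavedec2`.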

Code cell 24

# === 6.3 2D DWT Image Decomposition ===

if HAS_PYWT:
    # Try to load the scipy sample image; fall back to a synthetic one
    try:
        from scipy.datasets import face
        img = face(gray=True).astype(float)[:256, :256] / 255.0
    except Exception:
        # Synthetic image: smooth gradient + sharp edges
        N_img = 256
        x_img, y_img = np.meshgrid(np.linspace(0, 1, N_img), np.linspace(0, 1, N_img))
        img = (np.sin(2*np.pi*3*x_img) * np.cos(2*np.pi*2*y_img)
               + 0.5*(x_img > 0.5).astype(float)
               + 0.3*(y_img > 0.6).astype(float))
        img = (img - img.min()) / (img.max() - img.min())

    # Apply 3-level 2D DWT
    coeffs_2d = pywt.wavedec2(img, 'db2', level=3)
    # coeffs_2d = [cA3, (cH3,cV3,cD3), (cH2,cV2,cD2), (cH1,cV1,cD1)]
    # Note: coeffs_2d[-1] holds the FINEST (level-1) detail subbands.

    if HAS_MPL:
        # Show the original, the three level-1 detail subbands, and LL3
        fig, axes2 = plt.subplots(1, 5, figsize=(15, 3))
        axes2[0].imshow(img, cmap='gray')
        axes2[0].set_title('Original')

        cH1, cV1, cD1 = coeffs_2d[-1]
        for ax, sub, name in zip(axes2[1:4], (cH1, cV1, cD1),
                                 ('|LH1| (vertical edges)',
                                  '|HL1| (horizontal edges)',
                                  '|HH1| (diagonal edges)')):
            ax.imshow(np.abs(sub), cmap='hot')
            ax.set_title(name)

        axes2[4].imshow(coeffs_2d[0], cmap='gray')
        axes2[4].set_title('LL3 (approximation)')

        for ax in axes2:
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('/tmp/wav_2d.png', dpi=100, bbox_inches='tight')
        plt.show()

    # Image compression proxy: keep only the largest 10% of coefficients
    all_coeffs = np.concatenate([coeffs_2d[0].ravel()] +
                                [c.ravel() for level in coeffs_2d[1:] for c in level])
    total_coeffs = len(all_coeffs)
    threshold = np.percentile(np.abs(all_coeffs), 90)  # keep top 10%
    kept = np.mean(np.abs(all_coeffs) >= threshold) * 100
    print(f'Total coefficients: {total_coeffs}')
    print(f'Keeping top 10% (threshold={threshold:.4f}): {kept:.1f}% kept')

7. Time-Frequency Analysis: Scalogram

The scalogram $|W_\psi f(a,b)|^2$ shows how energy is distributed in time and scale. Unlike the STFT, which uses uniform tiles, the CWT uses logarithmically spaced tiles — matching the constant-Q structure of natural signals.

Code cell 26

# === 7.1 Scalogram: Chirp + Burst ===

if HAS_PYWT and HAS_MPL:
    from scipy.signal import chirp

    fs_sc = 2000
    t_sc = np.linspace(0, 1, fs_sc)

    # Multi-component: chirp + pure tone + transient
    s1 = chirp(t_sc, 50, 1.0, 400, method='quadratic')      # accelerating chirp
    s2 = 0.5 * np.sin(2*np.pi*150*t_sc)                     # steady 150 Hz
    s3 = np.zeros_like(t_sc)                                 # short transient
    s3[(t_sc > 0.7) & (t_sc < 0.72)] = 2.0

    signal_sc = s1 + s2 + s3

    # CWT with Morlet
    scales_sc = np.logspace(0.2, 2.2, 100)
    coefs_sc, freqs_sc = pywt.cwt(signal_sc, scales_sc, 'morl',
                                   sampling_period=1/fs_sc)

    fig, axes = plt.subplots(2, 1, figsize=(13, 9), sharex=True)

    axes[0].plot(t_sc, signal_sc, color=COLORS['primary'], lw=0.8)
    axes[0].set_title('Signal: Quadratic Chirp + 150 Hz Tone + Transient')
    axes[0].set_ylabel('Amplitude')

    power = np.abs(coefs_sc)**2
    im = axes[1].pcolormesh(t_sc, freqs_sc, power, shading='gouraud', cmap='plasma')
    axes[1].set_yscale('log')
    axes[1].set_ylim([30, 600])
    axes[1].set_title('Morlet CWT Scalogram — Logarithmic Tiling')
    axes[1].set_xlabel('Time (s)')
    axes[1].set_ylabel('Frequency (Hz)')
    plt.colorbar(im, ax=axes[1], label='Power')

    # Annotate components
    axes[1].annotate('Quadratic chirp\n(parabolic ridge)', xy=(0.5, 200),
                     color='white', fontsize=9, ha='center')
    axes[1].axhline(150, color='cyan', ls='--', lw=1, alpha=0.7)
    axes[1].axvline(0.71, color='yellow', ls='--', lw=1, alpha=0.7)

    plt.tight_layout()
    plt.savefig('/tmp/wav_scalogram2.png', dpi=100, bbox_inches='tight')
    plt.show()

    print('The scalogram clearly shows:')
    print('  - Quadratic chirp as a parabolic ridge')
    print('  - Steady 150 Hz as a horizontal line (cyan dashed)')
    print('  - Transient at t=0.71s as a vertical streak (yellow dashed)')
    print('  - Low frequencies: wide tiles (coarse time, fine frequency)')
    print('  - High frequencies: narrow tiles (fine time, coarse frequency)')


8. Machine Learning Applications

8.1 Mallat Scattering Networks

The scattering transform provides provably stable, translation-invariant features without learned parameters. It cascades wavelet transforms with a pointwise modulus:

$$S_1[j](x) = |x * \psi_j| * \phi_J$$
$$S_2[j_1, j_2](x) = \bigl|\,|x * \psi_{j_1}| * \psi_{j_2}\bigr| * \phi_J$$

Key theorem: $\|Sf - S(T_\tau f)\| \leq C|\tau|\,\|f\|$ — scattering is Lipschitz-stable under diffeomorphisms $\tau$.

Code cell 28

# === 8.1 Scattering Transform (Order 1 and 2) ===

if HAS_PYWT:
    def scattering_1d(x, wavelet='db4', J=5):
        """Simplified 1D scattering transform (orders 0, 1, 2)."""
        N_sc = len(x)

        # Order 0: low-pass average
        coeffs0 = pywt.wavedec(x, wavelet, level=J)
        S0 = np.abs(coeffs0[0]).mean()

        # Order 1: |DWT_j x| averaged at coarsest scale
        S1 = []
        modulus_coeffs = []
        for j in range(1, J+1):
            coeffs_j = pywt.wavedec(x, wavelet, level=j)
            mod_j = np.abs(coeffs_j[1])  # detail band at level j (coarsest detail of this decomposition)
            S1.append(mod_j.mean())  # average = zeroth-order scattering of mod_j
            modulus_coeffs.append(mod_j)

        # Order 2: ||DWT_j1 x| * psi_j2| averaged
        S2 = []
        for j1 in range(len(modulus_coeffs)):
            for j2 in range(j1+1, min(j1+3, len(modulus_coeffs))):
                # Apply wavelet j2 to |wavelet j1 response|
                sig_j1 = modulus_coeffs[j1]
                if len(sig_j1) < 4:
                    continue
                c_j2 = pywt.wavedec(sig_j1, wavelet, level=1)
                s2 = np.abs(c_j2[1]).mean()
                S2.append(s2)

        return S0, np.array(S1), np.array(S2)

    # Test translation invariance
    np.random.seed(42)
    N_sc = 256
    n_sc = np.arange(N_sc)
    x_sc = np.sin(2*np.pi*0.1*n_sc) + 0.3*np.sin(2*np.pi*0.25*n_sc)

    # Translations
    shifts = [0, 5, 10, 20, 40]
    print('Scattering translation invariance test:')
    print(f"{'Shift':<8} {'|S1 diff|/|S1|':>16} {'|S2 diff|/|S2|':>16}")

    S0_ref, S1_ref, S2_ref = scattering_1d(x_sc)

    for shift in shifts:
        x_shifted = np.roll(x_sc, shift)
        _, S1_s, S2_s = scattering_1d(x_shifted)

        err1 = np.linalg.norm(S1_s - S1_ref) / (np.linalg.norm(S1_ref) + 1e-10)
        err2 = np.linalg.norm(S2_s - S2_ref) / (np.linalg.norm(S2_ref) + 1e-10)
        print(f"{shift:<8} {err1:>16.6f} {err2:>16.6f}")

    print()
    print('Small relative errors indicate near-translation-invariance of the scattering features')

print("audit output: 8.1 Scattering Transform (Order 1 and 2) === complete or optional branch skipped.")

8.5 Wavelet Denoising: Donoho-Johnstone

Soft thresholding of wavelet coefficients is the proximal operator for $\ell^1$ regularization:

$$\hat{d}_{j,k} = \operatorname{sign}(d_{j,k})\max(|d_{j,k}| - \lambda,\, 0)$$

Universal threshold: $\lambda = \hat{\sigma}\sqrt{2\log N}$, where $\hat{\sigma} = \operatorname{median}(|d_1|)/0.6745$ is the MAD estimate of the noise level from the finest detail band.

Near-optimal: the resulting risk is within a $2\log N$ factor of the minimax risk over Besov function classes.
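The proximal-operator claim can be checked directly: the closed-form soft-threshold rule should match a brute-force minimizer of $\tfrac{1}{2}(x - d)^2 + \lambda|x|$ evaluated on a fine grid. A minimal sketch (helper names are illustrative, not part of any library API):

```python
import numpy as np

def soft_threshold(d, lam):
    """Closed-form soft threshold: prox of lam*|x| at d."""
    return np.sign(d) * np.maximum(np.abs(d) - lam, 0.0)

def prox_l1_numeric(d, lam, grid=np.linspace(-5, 5, 200001)):
    """Brute-force argmin_x 0.5*(x - d)**2 + lam*|x| over a fine grid."""
    obj = 0.5 * (grid - d) ** 2 + lam * np.abs(grid)
    return grid[np.argmin(obj)]

lam = 0.7
for d in [-2.0, -0.5, 0.0, 0.3, 1.5]:
    analytic = soft_threshold(d, lam)
    numeric = prox_l1_numeric(d, lam)
    assert abs(analytic - numeric) < 1e-3  # grid step is 5e-5
    print(f"d={d:+.1f}  soft={analytic:+.4f}  argmin={numeric:+.4f}")
```

Coefficients inside the dead zone $|d| \leq \lambda$ map to zero; larger ones are shrunk toward zero by exactly $\lambda$, which is the source of soft thresholding's bias on large coefficients.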

Code cell 30

# === 8.5 Wavelet Denoising ===

if HAS_PYWT:
    def wavelet_denoise(y, wavelet='db4', level=5, mode='soft'):
        """Donoho-Johnstone wavelet thresholding denoiser."""
        N_d = len(y)
        coeffs_d = pywt.wavedec(y, wavelet, level=level)

        # Estimate noise from finest scale (robust via MAD)
        sigma_est = np.median(np.abs(coeffs_d[-1])) / 0.6745

        # Universal threshold
        threshold = sigma_est * np.sqrt(2 * np.log(N_d))

        # Threshold all detail coefficients
        denoised = [coeffs_d[0]]  # keep approximation unchanged
        for c in coeffs_d[1:]:
            denoised.append(pywt.threshold(c, threshold, mode=mode))

        return pywt.waverec(denoised, wavelet)[:N_d], sigma_est, threshold

    np.random.seed(42)
    N_dn = 512
    t_dn = np.linspace(0, 1, N_dn)

    # True signal: Doppler function
    f_true = np.sqrt(t_dn*(1-t_dn)) * np.sin(2*np.pi*1.05/(t_dn+0.05))

    noise_std = 0.2
    y_noisy = f_true + noise_std * np.random.randn(N_dn)

    f_soft, sigma_est, threshold = wavelet_denoise(y_noisy, 'db4', 5, 'soft')
    f_hard, _, _ = wavelet_denoise(y_noisy, 'db4', 5, 'hard')

    mse_noisy = np.mean((y_noisy - f_true)**2)
    mse_soft  = np.mean((f_soft  - f_true)**2)
    mse_hard  = np.mean((f_hard  - f_true)**2)

    snr = lambda mse: 10*np.log10(np.var(f_true)/mse)
    print(f'Noise sigma: {noise_std}, Estimated: {sigma_est:.4f}')
    print(f'Universal threshold: {threshold:.4f}')
    print(f"\n{'Method':<15} {'MSE':>12} {'SNR (dB)':>10}")
    print(f"{'Noisy input':<15} {mse_noisy:>12.6f} {snr(mse_noisy):>10.2f}")
    print(f"{'Soft thresh':<15} {mse_soft:>12.6f} {snr(mse_soft):>10.2f}")
    print(f"{'Hard thresh':<15} {mse_hard:>12.6f} {snr(mse_hard):>10.2f}")
    print(f'\nSNR improvement (soft): {snr(mse_soft)-snr(mse_noisy):.1f} dB')

    if HAS_MPL:
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        axes[0,0].plot(t_dn, f_true, color=COLORS['primary'], lw=2)
        axes[0,0].set_title('True Signal (Doppler)')

        axes[0,1].plot(t_dn, y_noisy, color=COLORS['neutral'], alpha=0.7, lw=0.8)
        axes[0,1].set_title(f'Noisy (σ={noise_std})')

        axes[1,0].plot(t_dn, f_true, color=COLORS['primary'], lw=1.5, alpha=0.4, label='True')
        axes[1,0].plot(t_dn, f_soft, color=COLORS['tertiary'], lw=2, label='Soft threshold')
        axes[1,0].set_title(f'Soft Thresholding (MSE={mse_soft:.4f})')
        axes[1,0].legend()

        axes[1,1].plot(t_dn, f_true, color=COLORS['primary'], lw=1.5, alpha=0.4, label='True')
        axes[1,1].plot(t_dn, f_hard, color=COLORS['error'], lw=2, label='Hard threshold')
        axes[1,1].set_title(f'Hard Thresholding (MSE={mse_hard:.4f})')
        axes[1,1].legend()

        plt.suptitle('Wavelet Denoising: Donoho-Johnstone', fontsize=13, fontweight='bold')
        plt.tight_layout()
        plt.savefig('/tmp/wav_denoise.png', dpi=100, bbox_inches='tight')
        plt.show()

print("audit output: 8.5 Wavelet Denoising === complete or optional branch skipped.")

9. Complexity and JPEG 2000 Compression

The fast DWT runs in $O(N)$ time — asymptotically faster than the FFT's $O(N \log N)$. This enables real-time processing of long signals and underlies JPEG 2000's compression pipeline.
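The $O(N)$ bound is a geometric series: level $j$ filters a length-$N/2^{j-1}$ approximation with two $K$-tap filters, costing about $KN/2^{j-1}$ multiply-accumulates, so the total stays below $2KN$ no matter how deep the cascade goes. A small tally (function name illustrative) makes this concrete:

```python
def dwt_mac_count(N, K, J):
    """Multiply-accumulate count for a J-level fast DWT of length N with a
    K-tap low-pass/high-pass filter pair: each level filters the current
    approximation with two K-tap filters, producing half-length outputs."""
    total, length = 0, N
    for _ in range(J):
        total += 2 * K * (length // 2)  # two filters, length//2 outputs each
        length //= 2
    return total

N, K = 65536, 8  # db4 uses 8-tap filters
for J in [1, 3, 5, 10]:
    macs = dwt_mac_count(N, K, J)
    print(f"J={J:>2}: {macs:>8} MACs  ({macs / (K * N):.3f} x K*N, bounded by 2)")
```

The ratio approaches but never reaches $2KN$, which is why adding more decomposition levels is essentially free.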

Code cell 32

# === 9. DWT vs FFT Timing + Image Compression ===

import time

# Timing comparison
sizes = [256, 1024, 4096, 16384, 65536]
t_fft_list = []
t_dwt_list = []

if HAS_PYWT:
    for N_size in sizes:
        x_time = np.random.randn(N_size)

        t0 = time.perf_counter()
        for _ in range(20): np.fft.rfft(x_time)
        t_fft_list.append((time.perf_counter()-t0)/20*1000)

        t0 = time.perf_counter()
        for _ in range(20): pywt.wavedec(x_time, 'db4', level=5)
        t_dwt_list.append((time.perf_counter()-t0)/20*1000)

    print('Timing comparison: FFT vs DWT')
    print(f"{'N':>8} {'FFT (ms)':>12} {'DWT (ms)':>12} {'DWT/FFT':>10}")
    for N_size, tf, td in zip(sizes, t_fft_list, t_dwt_list):
        print(f"{N_size:>8} {tf:>12.3f} {td:>12.3f} {td/tf:>10.3f}")

    print('\nDWT grows linearly; FFT grows as N*log(N)')

    if HAS_MPL:
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.loglog(sizes, t_fft_list, 'o-', color=COLORS['primary'], lw=2, label='FFT (O(N log N))')
        ax.loglog(sizes, t_dwt_list, 's-', color=COLORS['secondary'], lw=2, label='DWT (O(N))')

        # Reference lines
        N_ref = np.array(sizes, dtype=float)
        ax.loglog(N_ref, t_fft_list[0]*N_ref/sizes[0]*np.log2(N_ref)/np.log2(sizes[0]),
                  '--', color='gray', lw=1, label='O(N log N)')
        ax.loglog(N_ref, t_dwt_list[0]*N_ref/sizes[0],
                  ':', color='gray', lw=1, label='O(N)')

        ax.set_xlabel('Signal length N')
        ax.set_ylabel('Time (ms)')
        ax.set_title('DWT O(N) vs FFT O(N log N)')
        ax.legend()
        plt.tight_layout()
        plt.savefig('/tmp/wav_timing.png', dpi=100, bbox_inches='tight')
        plt.show()

    # Image compression with wavelet thresholding
    N_img = 128
    x_img, y_img = np.meshgrid(np.linspace(0,1,N_img), np.linspace(0,1,N_img))
    test_img = np.sin(2*np.pi*4*x_img) * np.cos(2*np.pi*3*y_img) + 0.5*(x_img > 0.5)
    test_img = (test_img - test_img.min()) / (test_img.max() - test_img.min())

    keep_fracs = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
    psnrs = []

    print('\nImage compression (DWT thresholding):')
    print(f"{'Keep%':>8} {'PSNR (dB)':>12}")

    for kf in keep_fracs:
        coeffs_img = pywt.wavedec2(test_img, 'db2', level=3)
        all_c = np.concatenate([coeffs_img[0].ravel()] +
                               [c.ravel() for level in coeffs_img[1:] for c in level])
        thresh = np.percentile(np.abs(all_c), 100*(1-kf))

        thresh_coeffs = [pywt.threshold(coeffs_img[0], thresh, mode='soft')]
        for level in coeffs_img[1:]:
            thresh_coeffs.append(tuple(pywt.threshold(c, thresh, mode='soft') for c in level))

        rec_img = pywt.waverec2(thresh_coeffs, 'db2')[:N_img, :N_img]
        mse_img = np.mean((rec_img - test_img)**2)
        psnr = -10*np.log10(mse_img + 1e-20)  # peak = 1 after normalization to [0, 1]
        psnrs.append(psnr)
        print(f"{kf*100:>7.0f}% {psnr:>12.2f}")

print("audit output: 9. DWT vs FFT Timing + Image Compression === complete or optional branch skipped.")

10. Summary: Wavelet Properties Verification

A comprehensive verification suite confirming the key theoretical properties of wavelets.

Code cell 34

# === 10. Comprehensive Wavelet Verification Suite ===

if HAS_PYWT:
    np.random.seed(42)
    x_v = np.random.randn(256)
    wavelet_v = 'db4'
    J_v = 5

    results = {}

    # 1. Perfect Reconstruction
    coeffs_v = pywt.wavedec(x_v, wavelet_v, level=J_v)
    x_rec_v = pywt.waverec(coeffs_v, wavelet_v)[:len(x_v)]
    err_pr = np.max(np.abs(x_v - x_rec_v))
    results['Perfect Reconstruction'] = err_pr < 1e-10

    # 2. Parseval's Theorem
    energy_sig = np.sum(x_v**2)
    energy_wav = sum(np.sum(c**2) for c in coeffs_v)
    err_parseval = abs(energy_sig - energy_wav) / energy_sig
    results['Parseval Theorem'] = err_parseval < 1e-10

    # 3. QMF Condition
    w_v = pywt.Wavelet(wavelet_v)
    h_v = np.array(w_v.dec_lo)
    g_v = np.array(w_v.dec_hi)
    K_v = len(h_v)
    g_qmf_v = np.array([(-1)**k * h_v[K_v-1-k] for k in range(K_v)])
    results['QMF Relation'] = np.max(np.abs(g_v - g_qmf_v)) < 1e-10

    # 4. Vanishing Moments
    k_v = np.arange(K_v)
    vm_check = all(
        abs(np.sum((k_v**m) * g_v)) < 1e-8
        for m in range(w_v.vanishing_moments_psi)
    )
    results['Vanishing Moments'] = vm_check

    # 5. Translation equivariance (DWT is NOT translation invariant, just equivariant)
    # Circular shift by 2 (power of 2) should permute coefficients
    x_shift = np.roll(x_v, 2)
    coeffs_shift = pywt.wavedec(x_shift, wavelet_v, level=1, mode='periodization')
    coeffs_orig  = pywt.wavedec(x_v,     wavelet_v, level=1, mode='periodization')
    # With circular (periodization) boundaries, a shift by 2 in the signal
    # circularly shifts the level-1 detail coefficients by exactly 1
    err_equiv = np.max(np.abs(np.roll(coeffs_orig[1], 1) - coeffs_shift[1]))
    results['Shift-by-2 equivariance'] = err_equiv < 1e-8

    # 6. Normalization: phi integrates to 1
    phi_vals, psi_vals, x_wf = w_v.wavefun(level=10)
    dx = x_wf[1] - x_wf[0]
    phi_int = np.sum(phi_vals) * dx
    results['phi integrates to 1'] = abs(phi_int - 1.0) < 1e-3

    # 7. psi integrates to 0 (admissibility)
    psi_int = np.sum(psi_vals) * dx
    results['psi integrates to 0'] = abs(psi_int) < 1e-3

    print('Wavelet Theory Verification Suite')
    print('=' * 50)
    for name, passed in results.items():
        status = 'PASS' if passed else 'FAIL'
        print(f'  {status}  {name}')
    print('=' * 50)
    all_pass = all(results.values())
    print(f'\n{"ALL TESTS PASS" if all_pass else "SOME TESTS FAILED"}')

    print(f'\nSummary: {wavelet_v} wavelet')
    print(f'  Vanishing moments: {w_v.vanishing_moments_psi}')
    print(f'  Filter length:     {K_v} taps')
    print(f'  Support:           [0, {K_v-1}]')
    print(f'  Energy:            {np.sum(h_v**2):.6f} (should be 1)')
    print(f'  Sum:               {np.sum(h_v):.6f} (should be sqrt(2) = {np.sqrt(2):.6f})')

print("audit output: 10. Comprehensive Wavelet Verification Suite === complete or optional branch skipped.")

References and Further Reading

Foundational texts:

  • Mallat, S. (1998). A Wavelet Tour of Signal Processing. Academic Press.
  • Daubechies, I. (1992). Ten Lectures on Wavelets. SIAM.
  • Vetterli, M. & Kovačević, J. (1995). Wavelets and Subband Coding. Prentice-Hall.

Software:

  • PyWavelets (`pywt`) — used throughout this notebook for the 1D/2D DWT and coefficient thresholding.

AI papers:

  • Mallat (2012). Group Invariant Scattering. Communications on Pure and Applied Mathematics.
  • Bruna & Mallat (2013). Invariant Scattering Convolution Networks. IEEE T-PAMI.
  • Yao et al. (2021). WaveBERT: wavelet token compression for long-range Transformers.
  • Liu et al. (2022). WaveMix: 2D wavelet mixing for vision.
