Theory Notebook
Converted from theory.ipynb for web reading.
Graph Neural Networks — Theory Notebook
"A graph neural network is a machine that reads a graph and learns by listening to its neighbors."
Interactive derivations covering: GCN propagation, over-smoothing dynamics, WL color refinement, GAT attention, GIN expressiveness, graph pooling, positional encodings, and training at scale.
Companion: notes.md | exercises.ipynb
Code cell 2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
import seaborn as sns
sns.set_theme(style="whitegrid", palette="colorblind")
HAS_SNS = True
except ImportError:
plt.style.use("seaborn-v0_8-whitegrid")
HAS_SNS = False
mpl.rcParams.update({
"figure.figsize": (10, 6),
"figure.dpi": 120,
"font.size": 13,
"axes.titlesize": 15,
"axes.labelsize": 13,
"xtick.labelsize": 11,
"ytick.labelsize": 11,
"legend.fontsize": 11,
"legend.framealpha": 0.85,
"lines.linewidth": 2.0,
"axes.spines.top": False,
"axes.spines.right": False,
"savefig.bbox": "tight",
"savefig.dpi": 150,
})
np.random.seed(42)
print("Plot setup complete.")
Code cell 3
import numpy as np
import scipy.linalg as la
import scipy.sparse as sp
try:
import matplotlib.pyplot as plt
import matplotlib as mpl
try:
import seaborn as sns
sns.set_theme(style='whitegrid', palette='colorblind')
HAS_SNS = True
except ImportError:
plt.style.use('seaborn-v0_8-whitegrid')
HAS_SNS = False
mpl.rcParams.update({
'figure.figsize': (10, 6), 'figure.dpi': 120,
'font.size': 13, 'axes.titlesize': 15, 'axes.labelsize': 13,
'xtick.labelsize': 11, 'ytick.labelsize': 11,
'legend.fontsize': 11, 'lines.linewidth': 2.0,
'axes.spines.top': False, 'axes.spines.right': False,
})
HAS_MPL = True
except ImportError:
HAS_MPL = False
COLORS = {
'primary': '#0077BB',
'secondary': '#EE7733',
'tertiary': '#009988',
'error': '#CC3311',
'neutral': '#555555',
'highlight': '#EE3377',
}
np.set_printoptions(precision=6, suppress=True)
np.random.seed(42)
print('Setup complete.')
1. Graph Data Structures
We build graphs as adjacency matrices and edge lists, then verify the key representations used in GNNs.
Code cell 5
# === 1. Graph Data Structures ===
def make_graph(n, edges):
"""Build adjacency matrix from edge list (undirected)."""
A = np.zeros((n, n))
for u, v in edges:
A[u, v] = 1
A[v, u] = 1
return A
# Karate-club-like small graph: 8 nodes
n = 8
edges = [(0,1),(0,2),(0,3),(1,4),(2,4),(3,5),(4,6),(5,6),(6,7),(1,2)]
A = make_graph(n, edges)
degrees = A.sum(axis=1)
print('Adjacency matrix A:')
print(A.astype(int))
print(f'\nDegrees: {degrees.astype(int)}')
print(f'Edges: {int(A.sum()//2)}')
# Node feature matrix: one-hot encoding
X = np.eye(n)
print(f'\nNode features X shape: {X.shape}')
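For reference, here is the same graph in the sparse formats most GNN implementations consume: a COO edge-index array (each undirected edge stored in both directions) and a scipy sparse matrix. This is an added sketch, not part of the original cell; the name edge_index just follows common library convention.
src, dst = np.nonzero(A)                          # both directions of every undirected edge
edge_index = np.vstack([src, dst])                # shape (2, 2*|E|), COO edge list
A_sparse = sp.coo_matrix((np.ones(len(src)), (src, dst)), shape=A.shape)
print(f'edge_index shape: {edge_index.shape}')
print(f'Dense and sparse agree: {np.allclose(A, A_sparse.toarray())}')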
2. GCN Propagation Matrix
Derive the propagation matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ (with self-loops, $\tilde{A} = A + I$) and study its spectral properties.
Code cell 7
# === 2. GCN Propagation Matrix ===
def gcn_propagation_matrix(A):
"""Compute A_hat = D_tilde^{-1/2} A_tilde D_tilde^{-1/2}."""
n = A.shape[0]
A_tilde = A + np.eye(n) # Add self-loops
D_tilde = np.diag(A_tilde.sum(axis=1)) # Degree matrix
D_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(D_tilde)))
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt
return A_hat, A_tilde, D_tilde
A_hat, A_tilde, D_tilde = gcn_propagation_matrix(A)
# Verify: eigenvalues should be in (-1, 1]
eigvals = np.linalg.eigvalsh(A_hat)
print('Eigenvalues of A_hat:')
print(np.sort(eigvals)[::-1].round(4))
print(f'\nMax eigenvalue: {eigvals.max():.6f} (should be <= 1.0)')
print(f'Min eigenvalue: {eigvals.min():.6f} (should be > -1.0)')
ok_max = eigvals.max() <= 1.0 + 1e-10
ok_min = eigvals.min() > -1.0 - 1e-10
print(f"\n{'PASS' if (ok_max and ok_min) else 'FAIL'} - eigenvalues lie in (-1, 1]")
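A quick cross-check (an added sketch using the matrices from this cell): A_hat equals I minus the symmetric normalized Laplacian of the self-loop graph, so its eigenvalues are 1 - mu with mu in [0, 2); that is where the (-1, 1] range comes from.
D_tilde_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(D_tilde)))
L_sym_tilde = np.eye(n) - D_tilde_inv_sqrt @ A_tilde @ D_tilde_inv_sqrt  # normalized Laplacian of the self-loop graph
mu = np.linalg.eigvalsh(L_sym_tilde)
print('Self-loop Laplacian eigenvalues mu:', np.sort(mu).round(4))
print('1 - mu reproduces the A_hat spectrum:', np.allclose(np.sort(1 - mu), np.sort(eigvals)))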
Code cell 8
# === 2.1 Two-Layer GCN Forward Pass ===
def relu(x):
return np.maximum(0, x)
def softmax(x):
e = np.exp(x - x.max(axis=1, keepdims=True))
return e / e.sum(axis=1, keepdims=True)
np.random.seed(42)
d0, d1, d2 = n, 4, 3 # input_dim, hidden_dim, output_dim
W0 = np.random.randn(d0, d1) * 0.5
W1 = np.random.randn(d1, d2) * 0.5
# Two-layer GCN
H1 = relu(A_hat @ X @ W0) # Shape: (n, d1)
H2 = softmax(A_hat @ H1 @ W1) # Shape: (n, d2)
print('H1 (hidden layer):')
print(H1.round(4))
print(f'\nH2 (output, softmax): shape {H2.shape}')
print(H2.round(4))
print(f'\nRow sums of H2 (should be 1.0): {H2.sum(axis=1).round(6)}')
ok = np.allclose(H2.sum(axis=1), 1.0)
print(f"{'PASS' if ok else 'FAIL'} - softmax rows sum to 1")
2.2 Permutation Equivariance Verification
Verify that the GCN satisfies $f(PAP^\top, PX) = P\,f(A, X)$ for any permutation matrix $P$.
Code cell 10
# === 2.2 Permutation Equivariance ===
# Create a random permutation
perm = np.random.permutation(n)
P = np.eye(n)[perm] # Permutation matrix
# Permuted graph
A_perm = P @ A @ P.T
X_perm = P @ X
# GCN on permuted graph
A_hat_perm, _, _ = gcn_propagation_matrix(A_perm)
H1_perm = relu(A_hat_perm @ X_perm @ W0)
H2_perm = softmax(A_hat_perm @ H1_perm @ W1)
# Check: H2_perm should equal P @ H2
H2_expected = P @ H2
diff = np.abs(H2_perm - H2_expected).max()
print(f'Max difference |H2_perm - P@H2|: {diff:.2e}')
ok = np.allclose(H2_perm, H2_expected, atol=1e-10)
print(f"{'PASS' if ok else 'FAIL'} - GCN is permutation equivariant")
3. Over-Smoothing: Dirichlet Energy Decay
Visualize how the Dirichlet energy decays to zero as GCN depth increases.
Code cell 12
# === 3. Over-Smoothing Analysis ===
# Unnormalized Laplacian L = D - A
L = np.diag(A.sum(axis=1)) - A
def dirichlet_energy(H, L):
return np.trace(H.T @ L @ H)
# Apply pure smoothing (A_hat, no weight matrix) for increasing depths
depths = list(range(0, 33))
energies = []
H_iter = X.copy()  # Start from identity (one-hot) features
for d in range(depths[-1] + 1):
energies.append(dirichlet_energy(H_iter, L))
H_iter = A_hat @ H_iter # One smoothing step
print('Dirichlet energy at selected depths:')
for d in [0, 1, 2, 4, 8, 16, 32]:
print(f' Depth {d:2d}: E = {energies[d]:.6f}')
ok = energies[32] < energies[0] * 0.01
print(f"\n{'PASS' if ok else 'FAIL'} - energy decays by 99%+ at depth 32")
Code cell 13
# === 3.1 Over-Smoothing Visualization ===
if HAS_MPL:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Left: Dirichlet energy vs depth
ax = axes[0]
ax.semilogy(depths, energies, color=COLORS['primary'], linewidth=2)
ax.axvline(2, color=COLORS['neutral'], linestyle='--', alpha=0.7, label='Depth 2 (typical GCN)')
ax.set_title('Over-smoothing: Dirichlet energy vs depth')
ax.set_xlabel('Number of propagation steps')
ax.set_ylabel('Dirichlet energy $E(H)$ (log scale)')
ax.legend()
# Right: Node representations at depth 0 vs 16
ax = axes[1]
H_shallow = np.linalg.matrix_power(A_hat, 2) @ X
H_deep = np.linalg.matrix_power(A_hat, 16) @ X
# Project to 2D via first two columns
ax.scatter(H_shallow[:, 0], H_shallow[:, 1],
color=COLORS['primary'], s=120, label='Depth 2', zorder=5)
ax.scatter(H_deep[:, 0], H_deep[:, 1],
color=COLORS['error'], s=120, marker='x', linewidths=2,
label='Depth 16 (over-smoothed)', zorder=5)
ax.set_title('Node representations: shallow vs deep')
ax.set_xlabel('Feature dim 1')
ax.set_ylabel('Feature dim 2')
ax.legend()
fig.tight_layout()
plt.show()
print('Plot displayed.')
Code cell 14
# === 3.2 Convergence to Stationary Distribution ===
# Compute stationary distribution: pi_v = d_tilde_v / sum(d_tilde)
d_tilde = A_tilde.sum(axis=1)  # degrees of the self-loop graph
# For the random walk on A_tilde: stationary distribution = degree / vol(G)
A_rw = (A_tilde.T / d_tilde).T  # row-normalized (random walk) transition matrix
pi = d_tilde / d_tilde.sum()
# After many steps, each row of A_rw^k should converge to pi
A_rw_k = np.linalg.matrix_power(A_rw, 50)
print('First 3 rows of A_rw^50 (should all be approx. pi):')
print(A_rw_k[:3].round(6))
print(f'\nStationary distribution pi:')
print(pi.round(6))
ok = np.allclose(A_rw_k[:3], pi[np.newaxis, :], atol=1e-4)
print(f"\n{'PASS' if ok else 'FAIL'} - rows converge to stationary distribution")
4. Weisfeiler-Leman Color Refinement
Implement 1-WL and test it on pairs of graphs to understand GNN expressiveness limits.
Code cell 16
# === 4. Weisfeiler-Leman Color Refinement ===
from collections import Counter
def wl_refinement(adj, max_iter=10):
"""Run 1-WL color refinement. adj: dict {node: set_of_neighbors}."""
n = len(adj)
# Initial: all same color
colors = {v: 0 for v in adj}
color_history = [dict(colors)]
for t in range(max_iter):
new_colors = {}
color_map = {} # (color, sorted nbr colors) -> new color
counter = [0]
def get_color(key):
if key not in color_map:
color_map[key] = counter[0]
counter[0] += 1
return color_map[key]
for v in adj:
nbr_colors = tuple(sorted(colors[u] for u in adj[v]))
key = (colors[v], nbr_colors)
new_colors[v] = get_color(key)
if new_colors == colors:
break
colors = new_colors
color_history.append(dict(colors))
return colors, color_history
def adj_from_edges(n, edges):
adj = {i: set() for i in range(n)}
for u, v in edges:
adj[u].add(v)
adj[v].add(u)
return adj
# C6 (6-cycle) vs C3+C3 (two disjoint triangles)
C6_adj = adj_from_edges(6, [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0)])
C3C3_adj = adj_from_edges(6, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3)])
c_C6, hist_C6 = wl_refinement(C6_adj)
c_C3C3, hist_C3C3 = wl_refinement(C3C3_adj)
print('C6 final colors:', c_C6)
print('C3+C3 final colors:', c_C3C3)
print(f'C6 color histogram: {dict(Counter(c_C6.values()))}')
print(f'C3+C3 color histogram: {dict(Counter(c_C3C3.values()))}')
# WL says they are isomorphic if histograms match
wl_distinguishes = Counter(c_C6.values()) != Counter(c_C3C3.values())
print(f'\n1-WL distinguishes C6 vs C3+C3: {wl_distinguishes}')
print('(Expected: False — 1-WL cannot distinguish them)')
Code cell 17
# === 4.1 WL on Distinguishable Graphs ===
# K_{1,3} (star with 3 leaves) vs P4 (path of 4 nodes)
K13_adj = adj_from_edges(4, [(0,1),(0,2),(0,3)]) # node 0 is hub
P4_adj = adj_from_edges(4, [(0,1),(1,2),(2,3)]) # path
c_K13, _ = wl_refinement(K13_adj)
c_P4, _ = wl_refinement(P4_adj)
print('K_{1,3} colors:', c_K13)
print('P4 colors:', c_P4)
hist_K13 = Counter(c_K13.values())
hist_P4 = Counter(c_P4.values())
distinguishes = hist_K13 != hist_P4
print(f'\n1-WL distinguishes K_{{1,3}} vs P4: {distinguishes}')
print(f'K_{{1,3}} histogram: {dict(hist_K13)}')
print(f'P4 histogram: {dict(hist_P4)}')
ok = distinguishes
print(f"\n{'PASS' if ok else 'FAIL'} - WL correctly distinguishes K_{{1,3}} vs P4")
5. Graph Attention Network (GAT)
Implement GAT and GATv2 attention from scratch and visualize attention patterns.
Code cell 19
# === 5. GAT Attention ===
np.random.seed(42)
def gat_layer(H, A, W, a, leaky_slope=0.2):
"""
GAT layer (single head).
H: (n, d) node features
A: (n, n) adjacency (with self-loops)
W: (d, d') linear transform
a: (2*d',) attention vector
Returns: H_new (n, d'), alpha (n, n)
"""
n, d = H.shape
Z = H @ W # (n, d')
d_prime = Z.shape[1]
# Compute attention scores for all edges
a_left = a[:d_prime] # (d',)
a_right = a[d_prime:] # (d',)
# e_ij = LeakyReLU(a^T [z_i || z_j])
score_left = Z @ a_left # (n,)
score_right = Z @ a_right # (n,)
# e[i,j] = score_left[i] + score_right[j] (for neighbors)
E = score_left[:, np.newaxis] + score_right[np.newaxis, :] # (n, n)
# LeakyReLU
E = np.where(E >= 0, E, leaky_slope * E)
# Mask: only attend to neighbors (where A > 0)
mask = (A > 0).astype(float)
E_masked = np.where(mask > 0, E, -1e9)
# Softmax over neighbors
E_exp = np.exp(E_masked - E_masked.max(axis=1, keepdims=True))
alpha = E_exp * mask
alpha = alpha / (alpha.sum(axis=1, keepdims=True) + 1e-10)
# Aggregate
H_new = np.tanh(alpha @ Z) # (n, d')
return H_new, alpha
# Setup
n_small = 6
edges_small = [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0),(0,2),(1,4)]
A_small = make_graph(n_small, edges_small)
A_small_sl = A_small + np.eye(n_small) # with self-loops
d_in, d_out = 4, 3
H0 = np.random.randn(n_small, d_in)
W_gat = np.random.randn(d_in, d_out) * 0.5
a_gat = np.random.randn(2 * d_out) * 0.5
H_gat, alpha = gat_layer(H0, A_small_sl, W_gat, a_gat)
print('GAT output shape:', H_gat.shape)
print('\nAttention matrix alpha (rows = target nodes):')
print(alpha.round(3))
print('\nRow sums (should be 1 for each node):', alpha.sum(axis=1).round(4))
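In practice GAT uses several attention heads whose outputs are concatenated in hidden layers (or averaged in the output layer). A small added sketch on top of gat_layer, with randomly initialized per-head parameters standing in for learned ones:
def gat_multihead(H, A, n_heads=3, d_out=3, concat=True, seed=0):
    """Run gat_layer once per head and combine the head outputs."""
    rng = np.random.default_rng(seed)
    outs = []
    for _ in range(n_heads):
        W_h = rng.standard_normal((H.shape[1], d_out)) * 0.5
        a_h = rng.standard_normal(2 * d_out) * 0.5
        H_h, _ = gat_layer(H, A, W_h, a_h)
        outs.append(H_h)
    return np.concatenate(outs, axis=1) if concat else np.mean(outs, axis=0)
H_multi = gat_multihead(H0, A_small_sl, n_heads=3, d_out=3, concat=True)
print('Multi-head GAT output shape (3 heads, concatenated):', H_multi.shape)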
Code cell 20
# === 5.1 GATv2 vs GAT: Static vs Dynamic Attention ===
def gatv2_layer(H, A, W, a, leaky_slope=0.2):
"""GATv2: dynamic attention e_ij = a^T LeakyReLU(W[h_i || h_j])."""
n, d = H.shape
n_out = W.shape[1]
# Concatenate all pairs (broadcast)
# h_i repeated across columns, h_j repeated across rows
H_i = np.repeat(H[:, np.newaxis, :], n, axis=1) # (n, n, d)
H_j = np.repeat(H[np.newaxis, :, :], n, axis=0) # (n, n, d)
H_cat = np.concatenate([H_i, H_j], axis=2) # (n, n, 2d)
# W in R^{2d x d_out} maps concat to features
# For simplicity, use W and a directly on concatenated
scores = H_cat @ W # (n, n, n_out) -- W should be (2d, n_out)
# Apply LeakyReLU
scores = np.where(scores >= 0, scores, leaky_slope * scores)
# Dot with a: (n, n, n_out) . (n_out,) -> (n, n)
E = scores @ a[:n_out] # (n, n)
mask = (A > 0).astype(float)
E_masked = np.where(mask > 0, E, -1e9)
E_exp = np.exp(E_masked - E_masked.max(axis=1, keepdims=True))
alpha_v2 = E_exp * mask
alpha_v2 = alpha_v2 / (alpha_v2.sum(axis=1, keepdims=True) + 1e-10)
W_feat = np.random.randn(d, n_out) * 0.3
H_new = np.tanh(alpha_v2 @ (H @ W_feat))
return H_new, alpha_v2
W_v2 = np.random.randn(2 * d_in, d_out) * 0.3
a_v2 = np.random.randn(d_out) * 0.3
H_v2, alpha_v2 = gatv2_layer(H0, A_small_sl, W_v2, a_v2)
# Compare attention patterns for node 0
print('Node 0 neighbor attention weights:')
nbrs_0 = [i for i in range(n_small) if A_small_sl[0, i] > 0]
print(f' GAT neighbors {nbrs_0}: {alpha[0, nbrs_0].round(4)}')
print(f' GATv2 neighbors {nbrs_0}: {alpha_v2[0, nbrs_0].round(4)}')
print('\n(Different weights = dynamic attention working in GATv2)')
Code cell 21
# === 5.2 Visualize Attention Matrix ===
if HAS_MPL:
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
for ax, mat, title in zip(axes,
[alpha, alpha_v2],
['GAT (static attention)', 'GATv2 (dynamic attention)']):
im = ax.imshow(mat, cmap='viridis', vmin=0, vmax=mat.max())
fig.colorbar(im, ax=ax, label='Attention weight $\\alpha_{uv}$')
ax.set_title(title)
ax.set_xlabel('Source node $u$')
ax.set_ylabel('Target node $v$')
ax.set_xticks(range(n_small))
ax.set_yticks(range(n_small))
fig.suptitle('Attention patterns: GAT vs GATv2', fontsize=15)
fig.tight_layout()
plt.show()
print('Attention heatmaps displayed.')
6. GIN: Graph Isomorphism Network
Demonstrate that sum aggregation (GIN) distinguishes graphs that mean aggregation (GCN) cannot.
Code cell 23
# === 6. Sum vs Mean Aggregation Expressiveness ===
# Multiset counterexample: {1,1} vs {1,1,1}
M1 = np.array([1.0, 1.0])
M2 = np.array([1.0, 1.0, 1.0])
print('Multiset expressiveness comparison:')
print(f' M1 = {M1}, M2 = {M2}')
print(f' sum(M1) = {M1.sum():.1f}, sum(M2) = {M2.sum():.1f} -> distinguishable')
print(f' mean(M1)= {M1.mean():.1f}, mean(M2)= {M2.mean():.1f} -> INDISTINGUISHABLE')
print(f' max(M1) = {M1.max():.1f}, max(M2) = {M2.max():.1f} -> INDISTINGUISHABLE')
# Graph-level: C6 vs C3+C3 with degree as feature
# C6: all nodes degree 2
# C3+C3: all nodes degree 2
# With one-hot initial features based on degree:
print('\nGIN sum-readout on C6 vs C3+C3 (with degree feature):')
def gin_layer(H, adj_dict, eps=0.0):
"""One GIN layer: MLP((1+eps)*h_v + sum_{u in N(v)} h_u)."""
H_new = np.zeros_like(H)
for v in adj_dict:
nbr_sum = sum(H[u] for u in adj_dict[v])
H_new[v] = (1 + eps) * H[v] + nbr_sum
return np.tanh(H_new) # MLP approximated by tanh
# Initialize with degree as feature
C6_feats = np.array([[d] for d in [2,2,2,2,2,2]], dtype=float)
C3C3_feats = np.array([[d] for d in [2,2,2,2,2,2]], dtype=float)
H_C6 = gin_layer(C6_feats, C6_adj)
H_C3C3 = gin_layer(C3C3_feats, C3C3_adj)
print(f' C6 after 1 GIN layer (sum readout): {H_C6.sum():.6f}')
print(f' C3C3 after 1 GIN layer (sum readout): {H_C3C3.sum():.6f}')
# Add second layer
H_C6_2 = gin_layer(H_C6, C6_adj)
H_C3C3_2 = gin_layer(H_C3C3, C3C3_adj)
print(f'\n C6 after 2 GIN layers (sum readout): {H_C6_2.sum():.6f}')
print(f' C3C3 after 2 GIN layers (sum readout): {H_C3C3_2.sum():.6f}')
print('\n(With uniform initial features, GIN still cannot distinguish them!)')
print('WL bound: both have identical 1-WL color multisets at each iteration.')
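For contrast, on a pair that 1-WL does separate (the K_{1,3} vs P4 pair from section 4.1), the same GIN layer with degree features already yields different sum readouts; tanh saturation keeps the gap small here, which is one reason real GINs use an MLP update. Added sketch:
K13_feats = np.array([[3.0], [1.0], [1.0], [1.0]])   # degrees of K_{1,3}
P4_feats = np.array([[1.0], [2.0], [2.0], [1.0]])    # degrees of P4
r_K13 = gin_layer(K13_feats, K13_adj).sum()
r_P4 = gin_layer(P4_feats, P4_adj).sum()
print(f' K13 sum readout: {r_K13:.6f}   P4 sum readout: {r_P4:.6f}')
print(f' Different: {not np.isclose(r_K13, r_P4)}')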
Code cell 24
# === 6.1 Adding Structural Features Breaks the Tie ===
# Add RWSE-like feature: self-loop probability at step 2
def rwse(adj_matrix, p=2):
"""Random walk structural encoding: return prob at step p."""
n = adj_matrix.shape[0]
D_inv = np.diag(1.0 / (adj_matrix.sum(axis=1) + 1e-10))
P = D_inv @ adj_matrix # Row-stochastic
Pp = np.linalg.matrix_power(P, p)
return np.diag(Pp)
# Build adjacency matrices for C6 and C3+C3
C6_A = make_graph(6, [(0,1),(1,2),(2,3),(3,4),(4,5),(5,0)])
C3C3_A = make_graph(6, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3)])
rwse_C6 = rwse(C6_A, p=3)
rwse_C3C3 = rwse(C3C3_A, p=3)
print('RWSE (3-step return probability) for C6 and C3+C3:')
print(f' C6 (all values): {rwse_C6.round(4)}')
print(f' C3+C3 (all values): {rwse_C3C3.round(4)}')
print(f'\n C6 unique RWSE: {set(rwse_C6.round(4))}')
print(f' C3+C3 unique RWSE: {set(rwse_C3C3.round(4))}')
distinguishable = not np.allclose(sorted(rwse_C6), sorted(rwse_C3C3))
print(f'\nRWSE distinguishes C6 vs C3+C3: {distinguishable}')
print('PASS: structural features (RWSE) break the 1-WL tie' if distinguishable
else 'FAIL: RWSE cannot distinguish them')
7. GraphSAGE: Inductive Learning and Neighbor Sampling
Implement neighbor sampling and the GraphSAGE aggregation step.
Code cell 26
# === 7. GraphSAGE Neighbor Sampling ===
np.random.seed(42)
def build_random_graph(n, p_edge=0.15):
"""Erdos-Renyi G(n,p) graph."""
adj = {i: set() for i in range(n)}
for i in range(n):
for j in range(i+1, n):
if np.random.rand() < p_edge:
adj[i].add(j)
adj[j].add(i)
return adj
def sample_neighbors(adj, node, k):
"""Sample k neighbors of node uniformly (with replacement if needed)."""
nbrs = list(adj[node])
if len(nbrs) == 0:
return []
if len(nbrs) <= k:
return nbrs
return list(np.random.choice(nbrs, k, replace=False))
def sage_mean_layer(H, adj, target_nodes, S=10):
"""
GraphSAGE mean aggregation layer.
Returns embeddings only for target_nodes.
"""
d = H.shape[1]
W = np.random.randn(2*d, d) * 0.3
H_new = np.zeros((len(target_nodes), d))
for idx, v in enumerate(target_nodes):
sampled = sample_neighbors(adj, v, S)
if sampled:
nbr_mean = H[sampled].mean(axis=0)
else:
nbr_mean = np.zeros(d)
concat = np.concatenate([H[v], nbr_mean])
H_new[idx] = np.tanh(W.T @ concat)
return H_new
n_large = 100
adj_large = build_random_graph(n_large, p_edge=0.08)
X_large = np.random.randn(n_large, 8)
# Compute embeddings for a mini-batch of 10 target nodes
batch = list(range(10))
H_sage = sage_mean_layer(X_large, adj_large, batch, S=5)
degrees = [len(adj_large[v]) for v in range(n_large)]
print(f'Graph: {n_large} nodes, avg degree {np.mean(degrees):.1f}')
print(f'Mini-batch size: {len(batch)} nodes')
print(f'GraphSAGE output shape: {H_sage.shape}')
print(f'Output norms: {np.linalg.norm(H_sage, axis=1).round(4)}')
ok = H_sage.shape == (len(batch), 8)
print(f"\n{'PASS' if ok else 'FAIL'} - GraphSAGE output has correct shape")
Code cell 27
# === 7.1 Inductive Inference: New Node ===
# Simulate adding a new node to the graph
# New node features + connections to some existing nodes
new_node_features = np.random.randn(8)
new_node_neighbors = [0, 5, 12, 23] # Indices of existing nodes
# Compute embedding for new node using LEARNED aggregation function
# (In practice, W would be trained — here we just demonstrate the pipeline)
d = 8
W_inductive = np.random.randn(2*d, d) * 0.3
# Aggregate neighbor features
nbr_feats = X_large[new_node_neighbors]
nbr_mean = nbr_feats.mean(axis=0)
concat = np.concatenate([new_node_features, nbr_mean])
new_node_embedding = np.tanh(W_inductive.T @ concat)
print('Inductive embedding for new node:')
print(f' Features: {new_node_features[:3].round(3)}...')
print(f' Neighbors: {new_node_neighbors}')
print(f' Embedding: {new_node_embedding[:4].round(4)}...')
print(f' Embedding norm: {np.linalg.norm(new_node_embedding):.4f}')
print('\nKey insight: embedding computed WITHOUT retraining the model!')
8. Graph Pooling Methods
Compare global pooling strategies and implement DiffPool soft cluster assignment.
Code cell 29
# === 8. Global Pooling Methods ===
np.random.seed(42)
# Simulate node embeddings for a small graph
n_mol = 12 # 12-atom molecule
d_emb = 8
H_mol = np.random.randn(n_mol, d_emb)
# Global sum pooling
h_sum = H_mol.sum(axis=0)
# Global mean pooling
h_mean = H_mol.mean(axis=0)
# Global max pooling
h_max = H_mol.max(axis=0)
print(f'Node embeddings shape: {H_mol.shape}')
print(f'\nSum pooling -> norm: {np.linalg.norm(h_sum):.4f}')
print(f'Mean pooling -> norm: {np.linalg.norm(h_mean):.4f}')
print(f'Max pooling -> norm: {np.linalg.norm(h_max):.4f}')
print(f'\nSum vs Mean ratio: {np.linalg.norm(h_sum) / np.linalg.norm(h_mean):.4f}')
print(f'(Expected: exactly n = {n_mol}, since mean pooling is sum pooling divided by n)')
# Expressiveness: two graphs with same avg but different sum
G1_nodes = np.ones((4, 2)) # 4 nodes, features all 1
G2_nodes = np.ones((8, 2)) # 8 nodes, features all 1
print(f'\nExpressiveness demo (all-ones graphs):')
print(f' G1 (4 nodes) mean: {G1_nodes.mean(axis=0)}, sum: {G1_nodes.sum(axis=0)}')
print(f' G2 (8 nodes) mean: {G2_nodes.mean(axis=0)}, sum: {G2_nodes.sum(axis=0)}')
print('Mean cannot distinguish G1 from G2; Sum can.')
Code cell 30
# === 8.1 DiffPool Soft Assignment ===
np.random.seed(42)
def diffpool_step(A, H, k):
"""
Simplified DiffPool: compute soft cluster assignments.
A: (n, n) adjacency
H: (n, d) node embeddings
k: target number of clusters
Returns: S (n, k), H_pooled (k, d), A_pooled (k, k)
"""
n, d = H.shape
# Soft assignment: S = softmax(H @ W_pool)
W_pool = np.random.randn(d, k) * 0.5
S_logits = H @ W_pool # (n, k)
S = np.exp(S_logits - S_logits.max(axis=1, keepdims=True))
S = S / S.sum(axis=1, keepdims=True) # (n, k) - soft assignments
# Coarsened embeddings: H_pooled = S^T @ H
H_pooled = S.T @ H # (k, d)
# Coarsened adjacency: A_pooled = S^T @ A @ S
A_pooled = S.T @ A @ S # (k, k)
# Auxiliary losses
lp_loss = np.linalg.norm(A - S @ S.T, 'fro') # Link prediction
ent_loss = -(S * np.log(S + 1e-10)).sum() / n # Entropy (min for crisp)
return S, H_pooled, A_pooled, lp_loss, ent_loss
# Run on our small graph
A_hat_dense = A_hat.copy()
H_demo = np.random.randn(n, 5) # n=8 nodes from earlier
k_clusters = 3
S, H_pool, A_pool, lp, ent = diffpool_step(A_hat_dense, H_demo, k_clusters)
print(f'DiffPool: {n} nodes -> {k_clusters} clusters')
print(f'Soft assignment matrix S shape: {S.shape}')
print(f'S row sums (should be 1): {S.sum(axis=1).round(4)}')
print(f'H_pooled shape: {H_pool.shape}')
print(f'A_pooled shape: {A_pool.shape}')
print(f'\nAuxiliary losses:')
print(f' Link prediction loss: {lp:.4f}')
print(f' Entropy loss: {ent:.4f} (lower = crisper assignments)')
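To read the soft assignment, take each node's argmax cluster (purely illustrative here, since W_pool is random rather than trained):
hard = S.argmax(axis=1)
print(f'\nHard cluster assignment (argmax of S): {hard}')
print(f'Cluster sizes: {np.bincount(hard, minlength=k_clusters)}')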
9. Positional Encodings for Graph Transformers
Compute Laplacian PE (LapPE) and Random Walk SE (RWSE) and visualize them.
Code cell 32
# === 9. Laplacian Positional Encodings ===
np.random.seed(42)
def compute_lapPE(A, k=4):
"""
Compute first k non-trivial eigenvectors of normalized Laplacian.
Returns: eigvecs (n, k), eigvals (k,)
"""
n = A.shape[0]
D = np.diag(A.sum(axis=1))
D_inv_sqrt = np.diag(1.0 / (np.sqrt(A.sum(axis=1)) + 1e-10))
L_sym = np.eye(n) - D_inv_sqrt @ A @ D_inv_sqrt
eigvals, eigvecs = np.linalg.eigh(L_sym)
# Skip first eigvec (constant, eigenvalue ~0)
# Take next k eigvecs
return eigvecs[:, 1:k+1], eigvals[1:k+1]
def compute_rwse(A, max_p=5):
"""RWSE: landing probability at each step p."""
n = A.shape[0]
D_inv = np.diag(1.0 / (A.sum(axis=1) + 1e-10))
P = D_inv @ A # Row-stochastic
rwse = np.zeros((n, max_p))
Pp = np.eye(n)
for p in range(1, max_p+1):
Pp = Pp @ P
rwse[:, p-1] = np.diag(Pp)
return rwse
# Use larger graph for interesting PEs
# Build a ring graph (C_12)
n_ring = 12
ring_edges = [(i, (i+1) % n_ring) for i in range(n_ring)]
A_ring = make_graph(n_ring, ring_edges)
lapPE, lapEigvals = compute_lapPE(A_ring, k=4)
rwse_ring = compute_rwse(A_ring, max_p=5)
print(f'Graph: C_{n_ring} (ring with {n_ring} nodes)')
print(f'LapPE shape: {lapPE.shape} (n x k)')
print(f'First 4 non-trivial Laplacian eigenvalues: {lapEigvals.round(4)}')
print(f'\nLapPE for first 4 nodes (each row = position encoding):')
print(lapPE[:4].round(4))
print(f'\nRWSE for first 4 nodes (5-step return probs):')
print(rwse_ring[:4].round(4))
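One practical wrinkle: Laplacian eigenvectors are only defined up to sign (and up to rotation inside repeated-eigenvalue blocks, which this ring has), so a common training-time augmentation is an independent random sign flip per eigenvector. A minimal added sketch:
def random_sign_flip(pe, rng):
    """Flip the sign of each LapPE column independently."""
    signs = rng.choice([-1.0, 1.0], size=pe.shape[1])
    return pe * signs  # broadcast over nodes
rng = np.random.default_rng(0)
lapPE_flipped = random_sign_flip(lapPE, rng)
print('Same encoding up to per-column sign:', np.allclose(np.abs(lapPE), np.abs(lapPE_flipped)))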
Code cell 33
# === 9.1 Visualize LapPE as 'Coordinates' ===
if HAS_MPL:
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
# Left: LapPE eigvec 1 vs eigvec 2 for ring graph
ax = axes[0]
x_coord = lapPE[:, 0]
y_coord = lapPE[:, 1]
sc = ax.scatter(x_coord, y_coord,
c=np.arange(n_ring), cmap='plasma', s=120, zorder=5)
for i in range(n_ring):
ax.annotate(str(i), (x_coord[i]+0.01, y_coord[i]+0.01), fontsize=10)
plt.colorbar(sc, ax=ax, label='Node index')
ax.set_title(f'LapPE: $C_{{12}}$ ring — eigvec 1 vs 2')
ax.set_xlabel('LapPE dim 1 ($\\mathbf{u}_2$)')
ax.set_ylabel('LapPE dim 2 ($\\mathbf{u}_3$)')
ax.set_aspect('equal')
# Right: RWSE for each node
ax = axes[1]
im = ax.imshow(rwse_ring.T, cmap='viridis', aspect='auto')
plt.colorbar(im, ax=ax, label='Return probability')
ax.set_title('RWSE for all nodes (C_{12} ring)')
ax.set_xlabel('Node index')
ax.set_ylabel('Walk length $p$')
ax.set_yticks(range(5))
ax.set_yticklabels([f'p={p+1}' for p in range(5)])
fig.tight_layout()
plt.show()
print('LapPE and RWSE visualizations displayed.')
10. Over-Squashing: Jacobian Analysis
Visualize how the GCN Jacobian decays with distance for bottleneck graphs.
Code cell 35
# === 10. Over-Squashing: Jacobian via A_hat powers ===
np.random.seed(42)
# Build a 'dumbbell' graph: two cliques connected by a single bridge edge
def make_dumbbell(k1=5, k2=5):
"""Two k-cliques connected by a bridge (bottleneck)."""
n = k1 + k2
A = np.zeros((n, n))
# Clique 1: nodes 0..k1-1
for i in range(k1):
for j in range(i+1, k1):
A[i,j] = A[j,i] = 1
# Clique 2: nodes k1..n-1
for i in range(k1, n):
for j in range(i+1, n):
A[i,j] = A[j,i] = 1
# Bridge: last node of clique 1 to first of clique 2
A[k1-1, k1] = A[k1, k1-1] = 1
return A
A_db = make_dumbbell(k1=5, k2=5)
A_hat_db, _, _ = gcn_propagation_matrix(A_db)
n_db = A_db.shape[0]
# Jacobian proxy: (A_hat^k)[v, u] = effective influence of u on v at depth k
depths_sq = [1, 2, 3, 4, 5]
# Nodes of interest: node 0 (clique 1) receiving info from node 9 (clique 2)
v_target = 0
u_source = n_db - 1
print(f'Dumbbell graph: {n_db} nodes, bridge at ({4},{5})')
print(f'Tracking influence of node {u_source} on node {v_target}:')
print()
A_hat_power = np.eye(n_db)
for k in depths_sq:
A_hat_power = A_hat_power @ A_hat_db
influence = A_hat_power[v_target, u_source]
print(f' Depth {k}: (A_hat^{k})[{v_target},{u_source}] = {influence:.6f}')
print()
print('Over-squashing: cross-bottleneck influence exponentially small.')
ok = A_hat_power[v_target, u_source] < 0.01
print(f"{'PASS' if ok else 'FAIL'} - influence < 0.01 after 5 steps")
Code cell 36
# === 10.1 Visualize Influence Matrix at Different Depths ===
if HAS_MPL:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('GCN influence matrix $(\\hat{A}^k)_{vu}$: dumbbell graph', fontsize=14)
for ax, k in zip(axes, [1, 2, 5]):
Ak = np.linalg.matrix_power(A_hat_db, k)
im = ax.imshow(Ak, cmap='plasma', vmin=0)
plt.colorbar(im, ax=ax)
ax.set_title(f'Depth $k={k}$')
ax.set_xlabel('Source node $u$')
ax.set_ylabel('Target node $v$')
# Mark the bottleneck
ax.axhline(4.5, color='white', linewidth=1.5, linestyle='--')
ax.axvline(4.5, color='white', linewidth=1.5, linestyle='--')
fig.tight_layout()
plt.show()
print('Influence matrices displayed. Note the dark cross-cluster region.')
11. Community Detection with Graph Spectral Clustering vs GCN
Compare spectral (Fiedler vector) and GCN-based node classification on a planted community graph.
Code cell 38
# === 11. Stochastic Block Model Experiment ===
np.random.seed(42)
def stochastic_block_model(n_per_block, n_blocks, p_in, p_out):
"""Generate SBM adjacency matrix."""
n = n_per_block * n_blocks
A = np.zeros((n, n))
labels = np.repeat(np.arange(n_blocks), n_per_block)
for i in range(n):
for j in range(i+1, n):
p = p_in if labels[i] == labels[j] else p_out
if np.random.rand() < p:
A[i,j] = A[j,i] = 1
return A, labels
n_per = 20
n_bl = 3
A_sbm, true_labels = stochastic_block_model(n_per, n_bl, p_in=0.4, p_out=0.02)
n_sbm = A_sbm.shape[0]
# Spectral clustering via Fiedler vector
D_sbm = np.diag(A_sbm.sum(axis=1))
L_sbm = D_sbm - A_sbm
D_inv_sqrt = np.diag(1.0 / (np.sqrt(A_sbm.sum(axis=1)) + 1e-10))
L_sym_sbm = np.eye(n_sbm) - D_inv_sqrt @ A_sbm @ D_inv_sqrt
eigvals_sbm, eigvecs_sbm = np.linalg.eigh(L_sym_sbm)
# Use first n_blocks non-trivial eigenvectors for clustering
spectral_coords = eigvecs_sbm[:, 1:n_bl+1]
# k-means on spectral coords (manual)
from scipy.spatial.distance import cdist
def kmeans_simple(X, k, max_iter=100):
np.random.seed(42)
centers = X[np.random.choice(len(X), k, replace=False)]
for _ in range(max_iter):
dists = cdist(X, centers)
assignments = dists.argmin(axis=1)
new_centers = np.array([X[assignments == c].mean(axis=0) for c in range(k)])
if np.allclose(centers, new_centers):
break
centers = new_centers
return assignments
spectral_pred = kmeans_simple(spectral_coords, n_bl)
# Accuracy via the best match over label permutations (brute force; fine for small k)
from itertools import permutations
def clustering_accuracy(true, pred, k):
best_acc = 0
for perm in permutations(range(k)):
mapped = np.array([perm[p] for p in pred])
acc = (mapped == true).mean()
best_acc = max(best_acc, acc)
return best_acc
acc_spectral = clustering_accuracy(true_labels, spectral_pred, n_bl)
print(f'SBM: {n_sbm} nodes, {n_bl} blocks ({n_per} nodes each)')
print(f'Spectral clustering accuracy: {acc_spectral:.3f}')
print(f'\nTrue labels: {true_labels}')
print(f'Predicted: {spectral_pred}')
ok = acc_spectral > 0.85
print(f"\n{'PASS' if ok else 'FAIL'} - spectral clustering >85% accuracy")
Code cell 39
# === 11.1 GCN Propagation on SBM ===
# Propagate the true one-hot label signal and track its Dirichlet energy
A_hat_sbm, _, _ = gcn_propagation_matrix(A_sbm)
# Compute Dirichlet energy for label signal
L_sbm_unnorm = D_sbm - A_sbm
H_onehot = np.zeros((n_sbm, n_bl))
H_onehot[np.arange(n_sbm), true_labels] = 1.0
print('Dirichlet energy of label signal at increasing GCN depth:')
for depth in [0, 1, 2, 3, 5, 10]:
H_iter_d = np.linalg.matrix_power(A_hat_sbm, depth) @ H_onehot
e = np.trace(H_iter_d.T @ L_sbm_unnorm @ H_iter_d)
print(f' Depth {depth:2d}: E = {e:.4f}')
print()
print('Key: at depth 2-3, label signal is smooth but discriminative.')
print('At depth 10, over-smoothing destroys cluster boundaries.')
12. Training Dynamics and Learning Curves
Simulate GNN training on a synthetic node classification task and visualize convergence.
Code cell 41
# === 12. Synthetic GCN Training Simulation ===
np.random.seed(42)
# Generate a simple 2-class node classification problem on SBM
n2 = 40
A2, labels2 = stochastic_block_model(n2//2, 2, p_in=0.4, p_out=0.04)
A_hat2, _, _ = gcn_propagation_matrix(A2)
# Features: class-correlated with noise
X2 = np.zeros((n2, 4))
X2[:n2//2, 0] = 1.0 # Class 0: feature 0 = 1
X2[n2//2:, 1] = 1.0 # Class 1: feature 1 = 1
X2 += np.random.randn(n2, 4) * 0.3 # Add noise
# Labels: 20% labeled (4 from each class)
labeled_idx = list(range(0, 4)) + list(range(n2//2, n2//2+4))
Y = labels2[labeled_idx]
# Two-layer GCN with SGD (manual backprop)
W1_train = np.random.randn(4, 4) * 0.5
W2_train = np.random.randn(4, 2) * 0.5
lr = 0.05
def forward(X, A_hat, W1, W2):
H1 = np.maximum(0, A_hat @ X @ W1) # ReLU
logits = A_hat @ H1 @ W2 # No activation
# Softmax
e = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = e / e.sum(axis=1, keepdims=True)
return H1, probs
losses = []
for step in range(300):
H1, probs = forward(X2, A_hat2, W1_train, W2_train)
# Cross-entropy loss on labeled nodes
loss = -np.log(probs[labeled_idx, Y] + 1e-10).mean()
losses.append(loss)
# Gradient for W2 (simplified)
dL = probs.copy()
dL[labeled_idx, Y] -= 1
dL /= len(labeled_idx)
dW2 = (A_hat2 @ H1).T @ dL
dH1 = A_hat2 @ dL @ W2_train.T  # backprop through the second propagation (A_hat2 is symmetric)
dH1_relu = dH1 * (H1 > 0).astype(float)
dW1 = X2.T @ (A_hat2 @ dH1_relu)
W2_train -= lr * dW2
W1_train -= lr * dW1
_, final_probs = forward(X2, A_hat2, W1_train, W2_train)
pred_all = final_probs.argmax(axis=1)
acc = (pred_all == labels2).mean()
print(f'Training loss after 300 steps: {losses[-1]:.4f}')
print(f'Final accuracy (all nodes): {acc:.3f}')
ok = acc > 0.7
print(f"{'PASS' if ok else 'FAIL'} - GCN achieves >70% accuracy on SBM")
Code cell 42
# === 12.1 Training Curve Visualization ===
if HAS_MPL:
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
# Left: Loss curve
ax = axes[0]
ax.plot(losses, color=COLORS['primary'], linewidth=2)
ax.set_title('GCN training loss (semi-supervised)')
ax.set_xlabel('Training step')
ax.set_ylabel('Cross-entropy loss')
ax.axhline(losses[-1], color=COLORS['neutral'], linestyle='--',
alpha=0.6, label=f'Final: {losses[-1]:.4f}')
ax.legend()
# Right: Final node embeddings colored by predicted label
ax = axes[1]
# Use the two class probabilities as 2D coordinates
_, probs_viz = forward(X2, A_hat2, W1_train, W2_train)
# Color by true label, shape by prediction
for cls, color, label in [(0, COLORS['primary'], 'Class 0'),
(1, COLORS['secondary'], 'Class 1')]:
mask = labels2 == cls
ax.scatter(probs_viz[mask, 0], probs_viz[mask, 1],
color=color, s=60, alpha=0.8, label=label)
# Mark labeled nodes with larger markers
ax.scatter(probs_viz[labeled_idx, 0], probs_viz[labeled_idx, 1],
c='none', edgecolors='black', s=150, linewidths=2,
label='Labeled (train)')
ax.set_title('Node prediction probabilities after training')
ax.set_xlabel('$P$(class 0)')
ax.set_ylabel('$P$(class 1)')
ax.legend(markerscale=1.2)
fig.tight_layout()
plt.show()
print('Training curves and classification plot displayed.')
13. Architecture Comparison Summary
Tabulate key properties of the major GNN architectures.
Code cell 44
# === 13. Architecture Comparison ===
architectures = [
('GCN', 'Fixed (degree-norm)', 'Sum/Mean', 'MLP', 'O(m*d)', '< 1-WL', 'Transductive'),
('GraphSAGE', 'Fixed (mean/max)', 'Mean/Max', 'MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
('GAT', 'Learned (static)', 'Attention','MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
('GATv2', 'Learned (dynamic)', 'Attention','MLP', 'O(m*d)', '< 1-WL', 'Inductive'),
('GIN', 'Fixed (none)', 'Sum', 'MLP', 'O(m*d)', '= 1-WL', 'Inductive'),
('GPS', 'Local MPNN + Global Attn','Hybrid','MLP', 'O((m+n^2)d)','> 1-WL w/ PE','Inductive'),
('Graphormer', 'Full attention + bias','Attention','MLP', 'O(n^2*d)','> 1-WL w/ PE','Inductive'),
]
header = ['Architecture', 'Aggregation', 'Agg. fn', 'Update', 'Complexity', 'Expressiveness', 'Setting']
col_widths = [13, 22, 10, 5, 15, 16, 12]
def row_str(row):
return ' | '.join(str(v).ljust(w) for v, w in zip(row, col_widths))
print(row_str(header))
print('-+-'.join('-'*w for w in col_widths))
for arch in architectures:
print(row_str(arch))
print()
print('Key takeaways:')
print(' 1. GIN is the only MPNN matching 1-WL expressiveness (sum agg + MLP)')
print(' 2. Graph Transformers exceed 1-WL through structural PEs')
print(' 3. GraphSAGE+Cluster-GCN scale to billion-node graphs; GPS does not')
print(' 4. GAT vs GATv2: same complexity, GATv2 has dynamic (stronger) attention')
Summary
This notebook has covered:
- GCN propagation matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ and its spectral properties (eigenvalues in $(-1, 1]$, permutation equivariance)
- Over-smoothing as Dirichlet energy decay: $E(H^{(k)})$ shrinks geometrically with propagation depth $k$
- Weisfeiler-Leman test: 1-WL cannot distinguish $C_6$ from $C_3 \cup C_3$; any MPNN without structural features faces the same limit
- GAT vs GATv2: static vs dynamic attention; GATv2's score $e_{uv} = a^\top \mathrm{LeakyReLU}(W[h_u \,\Vert\, h_v])$ creates genuine query-dependent attention
- GIN sum aggregation: distinguishes multisets that mean/max cannot
- RWSE breaks 1-WL symmetry by encoding local loop structure
- GraphSAGE: inductive inference on unseen nodes using learned aggregation functions
- DiffPool: soft cluster assignment $S = \mathrm{softmax}(H W_{\mathrm{pool}})$, coarsened graph $H' = S^\top H$, $A' = S^\top A S$
- LapPE and RWSE: positional encodings for graph transformers that capture relative structure
- Over-squashing: dumbbell graph shows cross-bottleneck influence decays exponentially
Companion materials:
- notes.md — Full mathematical treatment
- exercises.ipynb — Practice problems with solutions
14. Graph Rewiring Effect on Information Flow
Compare the Fiedler value and cross-cluster influence before and after adding a rewiring edge.
Code cell 47
# === 14. Graph Rewiring ===
import numpy as np
np.random.seed(42)
def make_graph_local(n, edges):
A = np.zeros((n, n))
for u, v in edges:
A[u,v]=1; A[v,u]=1
return A
def gcn_prop(A):
n = A.shape[0]
At = A + np.eye(n)
d = At.sum(axis=1)
D_inv_sqrt = np.diag(1.0/np.sqrt(d))
return D_inv_sqrt @ At @ D_inv_sqrt
def fiedler(A):
D = np.diag(A.sum(axis=1))
L = D - A
return np.sort(np.linalg.eigvalsh(L))[1]
# Dumbbell: two 5-cliques connected by one bridge edge
k = 5
n_db = 2*k
A_db = np.zeros((n_db, n_db))
for i in range(k):
for j in range(i+1,k):
A_db[i,j]=A_db[j,i]=1
for i in range(k, n_db):
for j in range(i+1,n_db):
A_db[i,j]=A_db[j,i]=1
A_db[k-1, k]=A_db[k, k-1]=1 # bridge
A_rew = A_db.copy()
A_rew[0, n_db-1]=A_rew[n_db-1,0]=1 # extra rewiring edge
lam2_orig = fiedler(A_db)
lam2_rew = fiedler(A_rew)
inf_orig = np.linalg.matrix_power(gcn_prop(A_db), 3)[0, n_db-1]
inf_rew = np.linalg.matrix_power(gcn_prop(A_rew), 3)[0, n_db-1]
print('Rewiring effect on dumbbell graph:')
print(f' lambda_2 original: {lam2_orig:.6f}')
print(f' lambda_2 rewired: {lam2_rew:.6f}')
print(f' Cross-cluster influence (A_hat^3)[0,9]: orig={inf_orig:.6f}, rewired={inf_rew:.6f}')
ok = inf_rew > inf_orig * 1.5
print(f"\n{'PASS' if ok else 'FAIL'} - rewiring improves cross-cluster influence")
15. MPNN Unification: GCN, SAGE, GIN Side-by-Side
Verify that all three fit the MPNN template $h_v^{(k+1)} = \mathrm{UPDATE}\big(h_v^{(k)},\, \mathrm{AGG}(\{h_u^{(k)} : u \in \mathcal{N}(v)\})\big)$ and differ only in aggregation and update design.
Code cell 49
# === 15. MPNN Unification ===
import numpy as np
np.random.seed(42)
def make_adj(n, edges):
adj = {i: set() for i in range(n)}
for u,v in edges:
adj[u].add(v); adj[v].add(u)
return adj
n_u = 5
edges_u = [(0,1),(1,2),(2,3),(3,4),(0,3),(1,4)]
adj_u = make_adj(n_u, edges_u)
H_u = np.random.randn(n_u, 3)
W1u = np.random.randn(3,3)*0.5
W2u = np.random.randn(6,3)*0.3
# GCN-style: mean(normalized) + linear update
def layer_gcn(H, adj):
H_new = np.zeros_like(H)
for v in adj:
all_nodes = list(adj[v]) + [v] # include self
norms = np.array([np.sqrt(len(adj[u])+1) for u in all_nodes])
dv = np.sqrt(len(adj[v])+1)
m = sum(H[u]/norms[i]/dv for i,u in enumerate(all_nodes))
H_new[v] = np.tanh(W1u.T @ m)
return H_new
# SAGE: mean + concat update
def layer_sage(H, adj):
H_new = np.zeros_like(H)
for v in adj:
nbrs = list(adj[v])
m = H[nbrs].mean(axis=0) if nbrs else np.zeros(H.shape[1])
H_new[v] = np.tanh(W2u.T @ np.concatenate([H[v], m]))
return H_new
# GIN: sum + MLP update
W3u = np.random.randn(3,3)*0.3
def layer_gin(H, adj, eps=0.0):
H_new = np.zeros_like(H)
for v in adj:
s = sum(H[u] for u in adj[v]) if adj[v] else np.zeros(H.shape[1])
H_new[v] = np.tanh(W3u.T @ np.tanh(W1u.T @ ((1+eps)*H[v] + s)))
return H_new
H_gcn2 = layer_gcn(H_u, adj_u)
H_sage2 = layer_sage(H_u, adj_u)
H_gin2 = layer_gin(H_u, adj_u)
print('One MPNN layer on 5-node graph:')
for name, H_out in [('GCN ', H_gcn2), ('SAGE', H_sage2), ('GIN ', H_gin2)]:
print(f' {name}: shape={H_out.shape}, norms={np.linalg.norm(H_out,axis=1).round(4)}')
print('\nAll fit MPNN template. Key differences:')
print(' GCN: mean-normalized agg, linear update (fixed by spectral derivation)')
print(' SAGE: mean agg, concat update (inductive, scalable)')
print(' GIN: sum agg, 2-layer MLP update (maximally expressive, = 1-WL)')
Code cell 50
# === 15.1 Expressiveness: Which Aggregator Distinguishes Most Graphs? ===
import numpy as np
np.random.seed(42)
# Three informative multisets
tests = [
('M={1,1} vs M={1,1,1}', [1.0,1.0], [1.0,1.0,1.0]),
('M={1,2} vs M={2}', [1.0,2.0], [2.0]),
('M={1,1,2} vs M={1,2,2}', [1.0,1.0,2.0], [1.0,2.0,2.0]),
]
print('Aggregation expressiveness on multisets:')
print(f'{"Test":<30} {"Sum?":<6} {"Mean?":<6} {"Max?"}')
print('-'*55)
for name, m1, m2 in tests:
m1, m2 = np.array(m1), np.array(m2)
s = not np.isclose(m1.sum(), m2.sum())
me = not np.isclose(m1.mean(), m2.mean())
mx = not np.isclose(m1.max(), m2.max())
print(f'{name:<30} {str(s):<6} {str(me):<6} {str(mx)}')
print()
print('PASS: sum is strictly more powerful than mean and max')
16. GPS Layer: Local MPNN + Global Attention
Implement a simplified GPS layer combining neighborhood aggregation with full pairwise attention.
Code cell 52
# === 16. Simplified GPS Layer ===
import numpy as np
np.random.seed(42)
def softmax_rows(X):
e = np.exp(X - X.max(axis=1, keepdims=True))
return e / e.sum(axis=1, keepdims=True)
def scaled_dot_product_attention(H, W_Q, W_K, W_V):
Q = H @ W_Q; K = H @ W_K; V = H @ W_V
dk = Q.shape[1]
scores = Q @ K.T / np.sqrt(dk)
alpha = softmax_rows(scores)
return alpha @ V, alpha
def local_mpnn(H, adj):
"""Simple mean aggregation."""
H_new = np.zeros_like(H)
for v in adj:
nbrs = list(adj[v])
m = H[nbrs].mean(axis=0) if nbrs else np.zeros(H.shape[1])
H_new[v] = np.tanh(0.5*(H[v] + m))
return H_new
def layer_norm(H, eps=1e-6):
mu = H.mean(axis=1, keepdims=True)
sigma = H.std(axis=1, keepdims=True) + eps
return (H - mu) / sigma
n_gps = 8; d_gps = 6
edges_gps = [(0,1),(1,2),(2,3),(3,4),(4,5),(5,6),(6,7),(0,4),(2,6)]
adj_gps = {i: set() for i in range(n_gps)}
for u,v in edges_gps:
adj_gps[u].add(v); adj_gps[v].add(u)
H_gps = np.random.randn(n_gps, d_gps)
W_Q = np.random.randn(d_gps, d_gps) * 0.3
W_K = np.random.randn(d_gps, d_gps) * 0.3
W_V = np.random.randn(d_gps, d_gps) * 0.3
# GPS layer: H' = LayerNorm(H + MPNN(H,A) + Attention(H))
mpnn_out = local_mpnn(H_gps, adj_gps)
attn_out, attn_weights = scaled_dot_product_attention(H_gps, W_Q, W_K, W_V)
H_gps_out = layer_norm(H_gps + mpnn_out + attn_out)
print(f'GPS layer: {n_gps} nodes, d={d_gps}')
print(f'MPNN output norm: {np.linalg.norm(mpnn_out, axis=1).round(4)}')
print(f'Attention output norm: {np.linalg.norm(attn_out, axis=1).round(4)}')
print(f'GPS output norm: {np.linalg.norm(H_gps_out, axis=1).round(4)}')
print(f'\nAttention matrix shape: {attn_weights.shape}')
print(f'Attention row sums (should be 1): {attn_weights.sum(axis=1).round(4)}')
ok = np.allclose(attn_weights.sum(axis=1), 1.0, atol=1e-6)
print(f"\n{'PASS' if ok else 'FAIL'} - GPS attention rows sum to 1")
print('\nGPS combines: local structure (MPNN) + global context (Attention)')
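The comparison table credits GPS with beyond-1-WL power only when positional encodings are added; the usual recipe feeds them into the node features. An added sketch reusing compute_rwse from Section 9, with W_pe as a random stand-in for a learned projection:
A_gps = np.zeros((n_gps, n_gps))
for u, v in edges_gps:
    A_gps[u, v] = A_gps[v, u] = 1
pe = compute_rwse(A_gps, max_p=4)                           # (n_gps, 4) structural encoding
W_pe = np.random.randn(d_gps + pe.shape[1], d_gps) * 0.3    # stand-in projection back to d_gps
H_gps_pe = np.concatenate([H_gps, pe], axis=1) @ W_pe       # PE-augmented node features
mpnn_pe = local_mpnn(H_gps_pe, adj_gps)
attn_pe, _ = scaled_dot_product_attention(H_gps_pe, W_Q, W_K, W_V)
H_out_pe = layer_norm(H_gps_pe + mpnn_pe + attn_pe)
print(f'GPS with RWSE-augmented inputs: output shape {H_out_pe.shape}')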