Graph Neural Networks: Part 11: Training and Scaling GNNs to Appendix B: Implementation Notes
11. Training and Scaling GNNs
11.1 Full-Batch vs Mini-Batch Training
Full-batch training. Compute the GNN on the entire graph in each gradient step. Requires the full adjacency, all node features, and all intermediate activations to fit in GPU memory, so it is practical only for graphs of modest size. Gradient computation is exact.
Mini-batch training. Sample a set of target nodes (a "mini-batch") and compute the GNN only on those nodes and their $L$-hop neighborhoods. Much more memory-efficient, but requires careful sampling to produce unbiased gradient estimates.
The challenge of mini-batch training for GNNs is the neighborhood explosion: if each node has degree $d$, an $L$-hop neighborhood has on the order of $d^L$ nodes. For $d = 10$, $L = 3$, this is 1000 nodes per target node - often larger than the desired mini-batch size.
Three approaches manage this explosion: neighbor sampling (11.2), graph partitioning (11.3), and subgraph sampling (11.4).
11.2 Neighbor Sampling
GraphSAGE neighbor sampling (Hamilton et al., 2017). For each target node at each layer, sample a fixed-size subset of neighbors:
- Layer 1 (nearest neighbors): sample $S_1$ neighbors
- Layer 2 (2-hop): sample $S_2$ neighbors of each sampled layer-1 neighbor
Total nodes per target: at most $1 + S_1 + S_1 S_2$. Constant cost per target node, regardless of graph degree.
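Fixed-fanout sampling keeps the computation graph small: with per-layer fanouts $S_1, S_2$, each target touches at most $1 + S_1 + S_1 S_2$ nodes. A minimal, library-free sketch (the `sample_frontier` helper and the toy star graph are illustrative assumptions, not GraphSAGE's actual implementation):

```python
import random

def sample_frontier(adj, targets, fanouts, seed=0):
    """GraphSAGE-style fixed-fanout expansion: starting from the target
    nodes, sample at most fanouts[l] neighbors per frontier node at
    each layer, returning every node the GNN must compute on."""
    rng = random.Random(seed)
    frontier = set(targets)
    for s in fanouts:
        nxt = set(frontier)
        for v in frontier:
            nbrs = adj.get(v, [])
            nxt.update(rng.sample(nbrs, min(s, len(nbrs))))
        frontier = nxt
    return frontier

# Toy star graph: hub 0 with leaves 1..10.
adj = {0: list(range(1, 11)), **{i: [0] for i in range(1, 11)}}
frontier = sample_frontier(adj, [0], fanouts=[3, 3])
```

However large the hub's true degree, the sampled frontier stays bounded by the fanout product rather than by the graph.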
Variance reduction. Random sampling introduces variance in the gradient estimate. This can be reduced by importance sampling (weight each sampled neighbor by the inverse of its sampling probability) or by using multiple samples. In practice, modest sample sizes (on the order of 10-25 neighbors per layer) are sufficient for most tasks.
VR-GCN (Chen et al., 2018). Control-variates method: maintain a historical embedding $\bar{h}_u^{(l)}$ for every node as a running snapshot, and use it in place of the unsampled neighbors:
$$\sum_{u \in \mathcal{N}(v)} \hat{A}_{vu}\, h_u^{(l)} \;\approx\; \frac{|\mathcal{N}(v)|}{|\hat{\mathcal{N}}(v)|} \sum_{u \in \hat{\mathcal{N}}(v)} \hat{A}_{vu}\, \big(h_u^{(l)} - \bar{h}_u^{(l)}\big) \;+\; \sum_{u \in \mathcal{N}(v)} \hat{A}_{vu}\, \bar{h}_u^{(l)}$$
where $\hat{\mathcal{N}}(v)$ is the sampled neighbor set; the historical embeddings serve as a control variate for the unsampled neighbors.
This gives an unbiased estimate of the full-batch gradient with lower variance than simple neighbor sampling.
11.3 Cluster-GCN
Cluster-GCN (Chiang et al., 2019) takes a graph partitioning approach: partition the graph into $c$ balanced clusters $\mathcal{V}_1, \dots, \mathcal{V}_c$ using METIS or another graph partitioning algorithm. A mini-batch consists of one or more clusters.
Key insight. Within a cluster, most edges are intra-cluster (partitioning algorithms minimize the cut). Running GCN on the induced subgraph of a cluster produces embeddings that see most relevant edges. Between-cluster edges (the cut) are ignored - this introduces approximation error, but the cut is small by design.
Memory: a cluster of size $b$ requires only its $b$ rows of node features plus the induced adjacency. For $b \ll n$, memory is a factor of roughly $n/b$ smaller than full-batch training.
Variance: the gradient approximation error comes from ignoring cross-cluster edges. To reduce this, Cluster-GCN uses multiple cluster sampling: each mini-batch consists of $q$ randomly selected clusters $\mathcal{V}_{i_1}, \dots, \mathcal{V}_{i_q}$, retaining all edges within their union. This recovers some cross-cluster edges.
11.4 GraphSAINT: Subgraph Sampling
GraphSAINT (Zeng et al., 2020) frames mini-batch training as subgraph sampling: each mini-batch is the subgraph induced by a sampled node set $\mathcal{V}_s \subseteq \mathcal{V}$.
Three sampling strategies:
- Node sampler: sample nodes uniformly; include all edges between sampled nodes. Simple but misses important edges.
- Edge sampler: sample edges uniformly; include all nodes incident to sampled edges. Better edge coverage.
- Random walk sampler: start random walks from random nodes; include all visited nodes (and their induced edges). Produces connected subgraphs with richer local structure.
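The random-walk strategy is a few lines of code; a minimal sketch (the `random_walk_subgraph` helper is hypothetical - GraphSAINT's actual samplers additionally record visit statistics used for normalization):

```python
import random

def random_walk_subgraph(adj, num_roots, walk_len, seed=0):
    """GraphSAINT-style random-walk sampler (sketch): launch num_roots
    walks from uniformly chosen start nodes and return the set of all
    visited nodes; the induced subgraph is the mini-batch."""
    rng = random.Random(seed)
    nodes = list(adj)
    visited = set()
    for _ in range(num_roots):
        v = rng.choice(nodes)
        visited.add(v)
        for _ in range(walk_len):
            if not adj[v]:
                break
            v = rng.choice(adj[v])
            visited.add(v)
    return visited

# 6-cycle as a toy graph.
adj = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
sub = random_walk_subgraph(adj, num_roots=2, walk_len=3)
```

Because walks follow edges, the sampled node set tends to form a few connected patches rather than a scattered set.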
Normalization. Since different nodes and edges are sampled with different probabilities, the gradient is biased without correction. GraphSAINT computes the probability $p_v$ that node $v$ and $p_{u,v}$ that edge $(u,v)$ appear in a sampled subgraph, analytically or by pre-sampling estimation, then reweights: each node's loss term is scaled by $1/p_v$ and each edge's aggregation contribution by $1/p_{u,v}$.
This gives an unbiased gradient estimate and removes systematic bias from subgraph sampling.
Scaling. GraphSAINT successfully trains GNNs on Flickr (89K nodes, 900K edges), Reddit (233K nodes, 11.6M edges), and Yelp (717K nodes, 7M edges) - orders of magnitude larger than full-batch training can comfortably handle.
12. Applications in Machine Learning
12.1 Molecular Property Prediction and Drug Discovery
Molecular property prediction is the "killer application" of GNNs. A molecule is naturally a graph: atoms are nodes (with features: atomic number, hybridization, formal charge, chirality), and bonds are edges (with features: bond type, aromaticity, stereo configuration).
MPNN for quantum chemistry (Gilmer et al., 2017). Applied to the QM9 dataset (134K small organic molecules, 12 quantum chemical properties including dipole moment, polarizability, and HOMO-LUMO gap). The MPNN with edge-network message functions achieves state-of-the-art on 11 of 12 properties, demonstrating that learned graph representations outperform hand-engineered molecular descriptors.
SchNet (Schütt et al., 2017). A rotationally invariant GNN for 3D molecular geometry (it operates on pairwise distances). Nodes are atoms; edges connect atoms within a cutoff radius, with edge features encoding pairwise distances. Uses continuous-filter convolutional layers: $h_v' = \sum_{u \in \mathcal{N}(v)} h_u \odot W(\lVert r_v - r_u \rVert)$, where the filter-generating network $W$ is an MLP applied to the interatomic distance. Achieves chemical accuracy on QM9 for most properties.
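The continuous-filter idea can be sketched as follows; this is a toy NumPy version, not the SchNet implementation (the RBF featurization and the two-matrix filter MLP are simplified assumptions). Because the filter depends only on interatomic distances, the output is unchanged under rigid rotations of the atom positions:

```python
import numpy as np

def rbf_expand(dist, centers, gamma=10.0):
    # Radial-basis featurization of a scalar distance (SchNet-style).
    return np.exp(-gamma * (dist - centers) ** 2)

def cfconv(h, pos, cutoff, W1, W2, centers):
    """Continuous-filter convolution (toy sketch): every atom gathers
    neighbors within `cutoff`; each neighbor's features are filtered
    element-wise by a filter generated from the pairwise distance.
    W1, W2 are the weights of a hypothetical 2-layer filter MLP."""
    out = np.zeros_like(h)
    n = len(h)
    for v in range(n):
        for u in range(n):
            if u == v:
                continue
            d = np.linalg.norm(pos[v] - pos[u])
            if d < cutoff:
                filt = np.tanh(rbf_expand(d, centers) @ W1) @ W2
                out[v] += h[u] * filt   # element-wise filtering
    return out
```

Rotating `pos` by any rotation matrix leaves every pairwise distance, and hence the output, unchanged - the invariance that makes distance-based GNNs suitable for scalar molecular properties.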
AlphaFold 2 (Jumper et al., 2021). The protein structure prediction breakthrough. Key GNN component: the Evoformer, which processes pairs of residues (edge features) and individual residues (node features) through an iterative process. The structure module then uses equivariant graph operations to predict 3D residue coordinates. AlphaFold 2 won the CASP14 protein structure prediction competition with near-experimental accuracy on most targets - a problem open for 50 years.
Drug discovery pipeline. Modern computational drug discovery uses GNNs at every stage:
- Virtual screening: predict binding affinity between drug candidates (small molecules) and protein targets; GNNs on molecular graphs substantially reduce wet-lab screening costs
- ADMET prediction: predict absorption, distribution, metabolism, excretion, and toxicity properties from molecular structure
- De novo design: generative GNNs (junction tree VAE, graph diffusion models) design novel molecules with desired properties
12.2 Knowledge Graph Reasoning
A knowledge graph (KG) is a directed, heterogeneous graph $G = (\mathcal{E}, \mathcal{R}, \mathcal{T})$, where $\mathcal{E}$ is a set of entities, $\mathcal{R}$ is a set of relation types, and $\mathcal{T} \subseteq \mathcal{E} \times \mathcal{R} \times \mathcal{E}$ is a set of triples (subject, relation, object): e.g., (Einstein, bornIn, Germany), (Germany, locatedIn, Europe).
Triple completion (link prediction). Given an incomplete KG, predict missing triples: is (Einstein, worksWith, Bohr) a true triple? Key methods:
- TransE (Bordes et al., 2013): model relation $r$ as a translation in embedding space: $e_s + r \approx e_o$ for true triples $(s, r, o)$. Simple and effective for 1-to-1 relations.
- RotatE (Sun et al., 2019): model relation $r$ as a rotation in complex space: $e_o \approx e_s \circ r$ with $|r_i| = 1$, where $\circ$ is element-wise complex multiplication. Captures symmetry, antisymmetry, and composition patterns.
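Both scoring functions are one-liners; a hedged NumPy sketch (margin losses, negative sampling, and training are omitted - this only illustrates the geometric scoring):

```python
import numpy as np

def transe_score(e_s, r, e_o):
    """TransE: true triples satisfy e_s + r ≈ e_o, so a lower distance
    means a more plausible triple."""
    return np.linalg.norm(e_s + r - e_o)

def rotate_score(e_s, phase, e_o):
    """RotatE (sketch): the relation is an element-wise rotation in the
    complex plane, parameterized by phases so that |r_i| = 1."""
    return np.linalg.norm(e_s * np.exp(1j * phase) - e_o)

rng = np.random.default_rng(0)
e_s = rng.normal(size=4) + 1j * rng.normal(size=4)
phase = rng.uniform(0.0, 2.0 * np.pi, size=4)
e_o = e_s * np.exp(1j * phase)   # object = subject rotated by relation
```

A triple constructed to satisfy the relation scores (near) zero; corrupted triples score higher, which is what the ranking loss exploits during training.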
R-GCN (Schlichtkrull et al., 2018). GNN for relational (heterogeneous) graphs, with a separate weight matrix for each relation type:
$$h_v^{(l+1)} = \sigma\Big(W_0^{(l)} h_v^{(l)} + \sum_{r \in \mathcal{R}} \sum_{u \in \mathcal{N}_r(v)} \frac{1}{c_{v,r}}\, W_r^{(l)} h_u^{(l)}\Big)$$
where $\mathcal{N}_r(v)$ are the neighbors of $v$ through relation $r$ and $c_{v,r}$ (typically $|\mathcal{N}_r(v)|$) is a normalization constant. Used for entity classification and relation prediction in Freebase and AIFB knowledge graphs.
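A minimal dense sketch of one R-GCN layer (the paper's basis-decomposition trick, which shares parameters across relations, is omitted for brevity):

```python
import numpy as np

def rgcn_layer(h, edges_by_rel, W_self, W_rel):
    """One R-GCN layer (dense sketch): a separate weight matrix per
    relation plus a self-loop weight, normalized by the per-relation
    in-degree c_{v,r} = |N_r(v)|."""
    out = h @ W_self
    for r, edges in edges_by_rel.items():
        deg = np.zeros(len(h))
        for _, v in edges:
            deg[v] += 1.0               # c_{v,r}
        for u, v in edges:
            out[v] += (h[u] @ W_rel[r]) / deg[v]
    return np.maximum(out, 0.0)         # ReLU
```

With one relation and identity weights, a node with a single in-edge ends up with its own feature plus its neighbor's, while isolated nodes keep only the self-loop term.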
For AI: In 2024, knowledge graphs power retrieval-augmented generation systems. Microsoft's Graph RAG (Edge et al., 2024) builds a knowledge graph from documents, then uses GNNs (and graph traversal) to answer complex multi-hop questions that flat vector retrieval cannot handle.
12.3 Recommendation Systems
Recommendation is a bipartite graph problem: users and items form nodes; interactions (clicks, purchases, ratings) form edges. The task is link prediction: predict which unobserved pairs represent true preferences.
LightGCN (He et al., 2020). Simplifies GCN for collaborative filtering by removing the feature transformation and nonlinearity, keeping only the graph smoothing:
$$h_u^{(l+1)} = \sum_{i \in \mathcal{N}(u)} \frac{1}{\sqrt{|\mathcal{N}(u)|\,|\mathcal{N}(i)|}}\; h_i^{(l)}$$
Final embeddings: $h_u = \sum_{l=0}^{L} \alpha_l\, h_u^{(l)}$ with $\alpha_l = \frac{1}{L+1}$ (JK-style layer combination). Prediction: $\hat{y}_{ui} = h_u^\top h_i$.
LightGCN outperforms standard GCN for collaborative filtering, suggesting that the feature transformation is unnecessary (or harmful) when the only input feature is an ID embedding. The key contribution is the propagation of user-item signals through multi-hop paths.
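The whole model fits in a few lines; a NumPy sketch under the simplifying assumptions of a dense interaction matrix and uniform layer weights $\alpha_l = 1/(L+1)$:

```python
import numpy as np

def lightgcn_embeddings(R, E0, num_layers=3):
    """LightGCN propagation (sketch): no weight matrices, no
    nonlinearity. R is the user-item interaction matrix, E0 stacks the
    user and item ID embeddings; the final embedding averages all
    layers (alpha_l = 1/(L+1))."""
    n_u, n_i = R.shape
    A = np.zeros((n_u + n_i, n_u + n_i))   # bipartite adjacency
    A[:n_u, n_u:] = R
    A[n_u:, :n_u] = R.T
    d = A.sum(axis=1)
    d[d == 0] = 1.0                        # guard isolated nodes
    A_hat = A / np.sqrt(np.outer(d, d))    # D^{-1/2} A D^{-1/2}
    layers = [E0]
    for _ in range(num_layers):
        layers.append(A_hat @ layers[-1])
    return sum(layers) / len(layers)

R = np.array([[1.0, 0.0],                  # user 0 interacted with item 0
              [1.0, 1.0]])                 # user 1 with both items
E = lightgcn_embeddings(R, np.eye(4), num_layers=2)
scores = E[:2] @ E[2:].T                   # user-item dot products
```

Starting from pure ID embeddings (here the identity matrix), the multi-hop smoothing alone produces user-item affinity scores.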
PinSage (already discussed in 5.5) extends GraphSAGE to the bipartite user-pin graph with importance-based sampling and hard negative mining.
12.4 Code and Program Analysis
Source code can be represented as multiple graphs simultaneously:
- AST (Abstract Syntax Tree): hierarchical tree showing syntactic structure
- CFG (Control Flow Graph): nodes are basic blocks; edges show execution flow (if/else, loops)
- DFG (Data Flow Graph): edges connect variable definitions to their uses
- Call Graph: nodes are functions; edges connect callers to callees
code2vec (Alon et al., 2019). Represents code snippets as bags of AST paths; learns path embeddings and aggregates them for downstream tasks. While not a full GNN, it demonstrates the power of structural code representations.
GNNs for bug detection. Allamanis et al. (2018) build heterogeneous graphs from C# code with AST, data flow, and control flow edges; they train a GNN to classify whether a variable name is misused. This achieves high precision on detecting certain classes of bugs.
Program synthesis. GNNs operating on execution traces (program state as graph) guide neural program search. DeepCoder (Balog et al., 2017) and more recent systems use graph representations of input-output examples to synthesize programs.
12.5 LLM Integration: Graph RAG and Structure-Aware Language Models
The 2024-2026 frontier is combining GNNs with large language models - exploiting the complementary strengths of relational structure (GNNs) and natural language understanding (LLMs).
Graph RAG (Edge et al., 2024, Microsoft Research). Standard RAG (retrieval-augmented generation) retrieves flat text chunks. Graph RAG builds a knowledge graph from documents and uses graph community detection (Leiden algorithm) to generate hierarchical summaries. At query time, relevant communities are retrieved and their summaries are used to ground the LLM response.
Key insight: complex questions requiring synthesis across many documents (e.g., "What are the main themes in this corpus?") are better answered by traversing a knowledge graph than by retrieving isolated text chunks. Graph RAG achieves substantially higher win rates than naive RAG on these "global sensemaking" queries.
GNN+LLM for molecular generation. LLMs can generate SMILES strings (text representations of molecules), but they struggle with structural validity constraints (valence, aromaticity). GNN-based validity checkers or graph-structured decoders can enforce these constraints.
Graph-structured memory for agents. Recent AI agents (2025) maintain working memory as knowledge graphs updated with each observation. A GNN processes the graph state to inform action selection. This enables multi-hop reasoning ("I know A->B->C, so if I see C, I should infer A") that flat token sequences handle poorly.
13. Common Mistakes
| # | Mistake | Why It's Wrong | Fix |
|---|---|---|---|
| 1 | Using mean aggregation for graph classification when graph size varies | Mean normalizes by node count, so two graphs with identical local structure but different sizes get identical embeddings despite being different graphs | Use sum aggregation (which preserves size information) or pair mean with a size feature |
| 2 | Forgetting self-loops in GCN | Without self-loops (using $A$ instead of $\tilde{A} = A + I_n$), a node does not include its own features in aggregation - it only receives neighbors' information, losing its own signal | Always add $I_n$ to $A$ before normalization; this is the renormalization trick from Kipf & Welling |
| 3 | Adding too many GCN layers and blaming model capacity | Deep GCNs fail due to over-smoothing, not lack of capacity - adding parameters won't help | Use 2-4 layers; add residual connections (GCNII), DropEdge, or PairNorm if more depth is needed |
| 4 | Treating GAT attention weights as feature importance | Attention weights tell you which neighbors were weighted more, not which features were important - this is the "attention is not explanation" problem | Use gradient-based attribution (GradCAM, Integrated Gradients) for feature importance; treat attention as architectural choice, not explanation |
| 5 | Using GIN with mean aggregation | GIN's theoretical power (matching 1-WL) requires sum aggregation with an injective MLP; swapping in mean destroys injectivity, leaving GIN no more expressive than GCN | Always use sum aggregation in GIN; verify that the MLP has at least 2 layers |
| 6 | Forgetting that 1-WL cannot detect triangles | Any MPNN is bounded by 1-WL; 1-WL cannot count triangles or cliques. If your task requires triangle detection (e.g., social network analysis), a standard GNN will fail | Add structural features (triangle count, RWSE) or use a higher-order GNN (NGNN, subgraph GNN) |
| 7 | Normalizing node features but not edge features | Unnormalized edge features with large variance can dominate the attention scores in GAT or the message in MPNN | Normalize edge features to zero mean, unit variance; use LayerNorm before message computation |
| 8 | Using transductive GCN for inductive tasks | GCN's normalized adjacency is computed on the training graph; new nodes at test time require recomputing the entire adjacency and rerunning the network | Use inductive methods (GraphSAGE, GAT) that learn aggregation functions, not fixed propagation matrices |
| 9 | Applying global pooling before sufficient local aggregation | With only 1 GNN layer before readout, each node only knows its immediate neighbors; graph-level representations lack structural context | Use 3-5 GNN layers before readout; consider hierarchical pooling (DiffPool) for hierarchically structured graphs |
| 10 | Ignoring the over-squashing problem for long-range tasks | If the task requires integrating information from nodes far apart, a shallow MPNN will fail due to over-squashing even without over-smoothing | Use graph rewiring (DIGL, EGP) or graph transformers (GPS, Graphormer) for long-range tasks; measure effective receptive field |
| 11 | Training GNN on the test graph for transductive semi-supervised learning | Using the test graph topology during training is allowed (transductive setting), but using test node labels is a data leakage error | Only mask the labels of test nodes; the full graph adjacency is legitimately used during both training and testing in the transductive setting |
| 12 | Confusing graph-level and node-level tasks in the readout | Using node embeddings directly for graph classification (without pooling) produces an embedding for each node, not for the graph - dimensions won't match | Always apply a readout function (sum/mean/attention pooling) after the final GNN layer for graph-level tasks |
14. Exercises
Exercise 1 * - MPNN Implementation
Implement a 2-layer MPNN on a small graph and compute node representations from scratch.
(a) Construct the adjacency matrix for the graph: nodes , edges .
(b) Initialize node features $X = I_n$ (identity matrix - each node has a one-hot representation).
(c) Implement one GCN layer: $H^{(1)} = \mathrm{ReLU}(\hat{A} X W^{(0)})$, where $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$, $\tilde{A} = A + I$, and $W^{(0)}$ is a random matrix (fixed seed).
(d) Apply a second GCN layer to get $H^{(2)}$.
(e) Verify that permuting the nodes (applying a permutation matrix $P$ to the rows of $X$ and to both the rows and columns of $A$) produces permuted outputs $P H^{(2)}$.
Exercise 2 * - GCN Propagation Matrix
Analyze the spectral properties of the GCN propagation matrix $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$.
(a) For the path graph $P_5$ (5 nodes in a line), compute $\hat{A}$ and find its eigenvalues.
(b) Show that all eigenvalues of $\hat{A}$ lie in $(-1, 1]$ by relating $\hat{A}$ to the normalized Laplacian and using the PSD property of the Laplacian.
(c) Compute $\hat{A}^k$ for increasing $k$ (e.g. $k = 1, 2, 4, 8, 16$) and visualize the row norms as a function of $k$. Observe over-smoothing: the rows converge.
(d) Show that the limiting matrix $\lim_{k \to \infty} \hat{A}^k$ has rank 1, with all rows proportional to the dominant eigenvector (entries $\propto \sqrt{d_v + 1}$), the stationary distribution. Compute this limit for $P_5$.
Exercise 3 * - Aggregation Expressiveness
Demonstrate that sum aggregation is strictly more expressive than mean for multisets.
(a) Construct two multisets of scalar node features containing only the value 1, e.g. $S_1 = \{1\}$ and $S_2 = \{1, 1\}$. Compute sum($S_1$), mean($S_1$), sum($S_2$), mean($S_2$). Verify that sum distinguishes them but mean does not.
(b) Construct two multisets with the same maximum, e.g. $S_1 = \{1, 2\}$ and $S_2 = \{2, 2\}$. Show that max($S_1$) = max($S_2$) = 2, but sum($S_1$) $\ne$ sum($S_2$).
(c) For two graphs $G_1$ (star with 3 leaves) and $G_2$ (path of 4 nodes), show that a GCN with mean aggregation assigns identical representations to some nodes, but a GIN with sum aggregation does not.
(d) Implement the WL color refinement algorithm for $G_1$ and $G_2$ with uniform initial colors. How many iterations until the coloring is stable? Does WL distinguish $G_1$ from $G_2$?
Exercise 4 ** - Weisfeiler-Leman Test
Implement the 1-WL algorithm and apply it to pairs of graphs.
(a) Implement wl_isomorphism_test(G1, G2, max_iter=10): run 1-WL color refinement on both graphs simultaneously; return "MAYBE_ISOMORPHIC" if the final color histograms match, or "NOT_ISOMORPHIC" at the first iteration where they differ.
(b) Test on $C_6$ (the 6-cycle) and $K_{3,3}$ (complete bipartite, 6 nodes, 9 edges). What does WL conclude? Are they actually isomorphic?
(c) Construct two non-isomorphic 3-regular graphs on 6 nodes ($K_{3,3}$ and the triangular prism $C_3 \times K_2$). Does 1-WL distinguish them? (Hint: both are 3-regular, so the degree-based first iteration is identical.)
(d) Explain why any MPNN with mean aggregation would fail to distinguish these graphs, while GIN with sum aggregation would succeed (or also fail, if WL itself fails).
Exercise 5 ** - GAT Attention Mechanism
Implement a single GAT attention head from scratch.
(a) For a graph with 5 nodes and edge set , compute attention coefficients for all edges.
(b) Use a random weight matrix $W$ and a random attention vector $a$ (fixed seed). Initialize node features randomly.
(c) Compute the attention logits $e_{uv} = \mathrm{LeakyReLU}\big(a^\top [W h_u \,\|\, W h_v]\big)$ (negative slope 0.2) for all edges, then apply softmax over each node's neighborhood to obtain $\alpha_{uv}$.
(d) Show the difference between GAT and GATv2: compute GATv2 attention scores and verify that the neighbor ranking for node 0 can differ between GAT and GATv2.
(e) Compute the updated node representations $h_v' = \sigma\big(\sum_{u \in \mathcal{N}(v)} \alpha_{vu}\, W h_u\big)$.
Exercise 6 ** - Over-Smoothing Dynamics
Quantify over-smoothing as a function of GCN depth.
(a) Construct a Cora-like graph: generate a stochastic block model with 200 nodes, 4 blocks, and within-block edge probability much larger than the between-block probability.
(b) Initialize node features as class one-hot vectors (the class assignments as a $200 \times 4$ matrix).
(c) Apply the GCN propagation $H \mapsto \hat{A} H$ (no weight matrix, no nonlinearity - pure smoothing) for $k$ steps. Compute the Dirichlet energy $E(H^{(k)})$ at each depth.
(d) Plot $E(H^{(k)})$ vs $k$ on a log scale. Fit an exponential and estimate the decay rate.
(e) Compare with the theoretical rate $(1 - \lambda_2)^2$ per step, where $\lambda_2$ is the Fiedler value of the random-walk Laplacian. Compute $\lambda_2$ numerically and compare.
Exercise 7 *** - GIN Implementation and Expressiveness
Implement GIN and compare its expressiveness to GCN on graph-level classification.
(a) Implement a 3-layer GIN with sum aggregation and a 2-layer MLP update: $h_v^{(l+1)} = \mathrm{MLP}^{(l)}\big((1 + \epsilon)\, h_v^{(l)} + \sum_{u \in \mathcal{N}(v)} h_u^{(l)}\big)$ with $\epsilon = 0$.
(b) Construct two non-isomorphic graphs $G_1$ and $G_2$ that 1-WL cannot distinguish without initial node features (as in 7.2). Initialize all nodes with the same constant feature vector.
(c) Show that with uniform initial features, neither GIN nor GCN can distinguish $G_1$ and $G_2$ (both produce identical graph-level sum representations). Why?
(d) Now add degree as an initial node feature: $x_v = [\deg(v)]$. Rerun both GIN and GCN. Does GIN now distinguish the two graphs? Does GCN?
(e) Generate a dataset of 500 random graphs (a mix of Erdos-Renyi and stochastic block models); train GIN and GCN for graph-level binary classification. Report test accuracy. Verify that GIN accuracy ≥ GCN accuracy.
Exercise 8 *** - Graph Transformer with Positional Encoding
Implement a simplified graph transformer layer with Laplacian positional encoding.
(a) For a graph with $n = 20$ nodes (Barabasi-Albert model with a small attachment parameter), compute the normalized Laplacian and extract the first $k$ non-trivial eigenvectors $u_2, \dots, u_{k+1}$.
(b) Initialize node features randomly and augment with LapPE: $x_v \leftarrow [\,x_v \,\|\, u_2(v), \dots, u_{k+1}(v)\,]$.
(c) Implement one layer of scaled dot-product self-attention over all 20 nodes (fully connected, ignoring graph edges): $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\big(QK^\top / \sqrt{d_k}\big)\, V$.
(d) Implement one GCN layer on the same graph. Visualize the attention matrices of both: which nodes attend to which? Does the graph transformer's attention recover some of the graph structure (do nearby nodes attend to each other more)?
(e) Compare: run 3 iterations of graph transformer attention vs 3 layers of GCN on the same random features. Compute the Dirichlet energy after each: which decreases faster? What does this reveal about over-smoothing in graph transformers vs GNNs?
15. Why This Matters for AI (2026 Perspective)
| Concept | AI Impact |
|---|---|
| MPNN framework | The unifying abstraction behind AlphaFold 2 (protein structure), drug screening, and materials design GNNs. Nearly every new spatial GNN architecture published in 2024-2026 can be written as a special case |
| GCN propagation rule | Directly computable from sparse adjacency; used in LightGCN powering Netflix/Pinterest/TikTok recommendation at billion-node scale |
| GraphSAGE inductive learning | Enables daily-updated embeddings for dynamic graphs (social networks, e-commerce) without retraining the full model |
| GAT / GATv2 | Learned, sparse attention patterns over structured data; used in molecular GNNs for drug-target interaction prediction (AstraZeneca, Recursion Pharmaceuticals) |
| WL expressiveness theorem | Theoretical bound establishing what structure any MPNN can detect; determines when to add structural features or upgrade to subgraph GNNs; guides architecture search |
| GIN sum aggregation | The foundation for most molecular property prediction models; identifies that all mainstream GCN deployments are sub-optimal in expressiveness |
| Over-smoothing analysis | Explains why 2-layer GNNs outperform 8-layer GNNs in most production deployments; directly informs depth tuning decisions |
| Over-squashing and graph rewiring | Motivates graph preprocessing in biochemistry (adding virtual bonds between distant atoms) and knowledge graph reasoning (adding transitive edges) |
| Graph Transformers (GPS, Graphormer) | State-of-the-art on quantum chemistry benchmarks (QM9, OGB-LSC); powering the next generation of molecular AI models |
| LapPE and RWSE | Standard input features for all graph foundation models; analogous to positional encodings in LLMs - without them, graph models cannot distinguish relative node positions |
| Graph RAG | Microsoft's production system (Azure AI Search) for document understanding queries requiring multi-hop reasoning; deployed across Office 365 and Microsoft 365 Copilot |
| Hierarchical pooling (DiffPool, MinCutPool) | Used in drug candidate scoring for proteins and polymer graphs where hierarchical structure (residues -> domains -> whole protein) is essential |
| Neighbor sampling (GraphSAGE, Cluster-GCN) | The key to training GNNs on billion-node graphs; enables GPU training without the full graph fitting in memory |
16. Conceptual Bridge
Looking Backward: What This Section Builds On
Graph Neural Networks are the synthesis of all earlier mathematics in this curriculum. From Chapter 2 Linear Algebra, we use matrix multiplication (the propagation rule $\hat{A} H W$), eigenvectors (the spectral derivation of GCN), and norms (Dirichlet energy). From Chapter 3 Advanced Linear Algebra, we use eigendecomposition (Laplacian positional encodings), PSD theory (the Laplacian is PSD, proven in 11-04), and the spectral theorem (graph Fourier transform). From Chapter 4 Calculus, we use gradients and the chain rule for backpropagation through GNN layers. From Chapter 8 Optimization, we use SGD variants and the mini-batch training strategies of 11. From Chapter 11 Graph Theory itself, we build on 11-04's spectral theory (the GCN derivation via Chebyshev polynomials, over-smoothing as diffusion, Laplacian eigenmaps as positional encodings) and 11-03's algorithms (BFS neighborhoods as the receptive field, max-flow/min-cut as the combinatorial counterpart of Cheeger's inequality).
Looking Forward: What This Section Enables
Chapter 12 Functional Analysis generalizes the spectral ideas of GNNs to infinite-dimensional Hilbert spaces. The graph Fourier transform is the finite-dimensional analogue of the Fourier transform on $L^2(\mathbb{R})$; the Laplacian eigenvectors are the finite-dimensional analogue of Fourier modes. Functional analysis provides the rigorous framework for this extension, enabling kernel methods on graphs and the theoretical analysis of continuous-limit GNNs (graphons).
Chapter 14 Math for Specific Models will revisit GNN architectures in the context of equivariant neural networks - networks that respect symmetries (rotation, reflection, permutation) by design. Geometric GNNs for 3D molecular data (DimeNet, SE(3)-Transformers, NequIP) extend the MPNN framework with $SE(3)$-equivariant message functions, enabling prediction of orientation-dependent molecular properties.
Chapter 21 Statistical Learning Theory will study GNN generalization bounds: how many training graphs are needed for a GNN to generalize? The WL expressiveness hierarchy (7) directly informs these bounds - more expressive GNNs require more data to generalize.
Chapter 22 Causal Inference uses directed acyclic graphs (DAGs) - a special class of directed graphs studied in 11-01. The graph algorithms of 11-03 (topological sort for causal ordering) and the GNN architectures of 11-05 (for learning on causal graphs) connect these chapters.
GRAPH NEURAL NETWORKS - POSITION IN THE CURRICULUM
========================================================================
Ch.02-03: Linear Algebra, Eigenvalues, SVD
                  |
                  v
Ch.11-04: Spectral Graph Theory
                  |   [Laplacian, GCN spectral derivation,
                  |    over-smoothing preview, LapPE]
                  v
+==========================================+
| Ch.11-05: Graph Neural Networks          | <- YOU ARE HERE
|                                          |
| MPNN -> GCN -> GraphSAGE -> GAT -> GIN   |
| WL Expressiveness -> Over-smoothing      |
| Graph Transformers -> GPS/Graphormer     |
| Applications: AlphaFold 2, Graph RAG     |
+==========================================+
        |                     |
        v                     v
Ch.11-06:              Ch.12: Functional Analysis
Random Graphs            [Graph limits (graphons),
  [Graph models           kernel methods on graphs,
   for GNN                infinite-dimensional
   benchmarks]            spectral theory]
                              |
                              v
                       Ch.14: Math for Specific Models
                         [Equivariant GNNs, SE(3)-Transformers,
                          3D molecular AI, geometric deep learning]
                              |
                              v
                       Ch.21: Statistical Learning Theory
                         [GNN generalization bounds,
                          WL-based complexity, PAC learning on graphs]
========================================================================
Appendix A: Mathematical Derivations
A.1 Proof that the GCN Propagation Rule is Permutation Equivariant
Claim. The GCN layer $f(A, X) = \sigma(\hat{A} X W)$, with $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ and $\tilde{A} = A + I_n$, is permutation equivariant: for any permutation matrix $P$,
$$f(P A P^\top, P X) = P\, f(A, X)$$
Proof. Let $A' = P A P^\top$. Since $P I_n P^\top = I_n$, we have $\tilde{A}' = A' + I_n = P \tilde{A} P^\top$. Permuting a matrix's rows and columns permutes its row sums, so the degree matrix of $\tilde{A}'$ is $\tilde{D}' = P \tilde{D} P^\top$, and likewise $\tilde{D}'^{-1/2} = P \tilde{D}^{-1/2} P^\top$. Thus:
$$\hat{A}' = P \tilde{D}^{-1/2} P^\top \, P \tilde{A} P^\top \, P \tilde{D}^{-1/2} P^\top = P \hat{A} P^\top$$
using $P^\top P = I_n$. Therefore:
$$f(A', P X) = \sigma(P \hat{A} P^\top P X W) = \sigma(P \hat{A} X W) = P\, \sigma(\hat{A} X W) = P\, f(A, X)$$
The last step uses the fact that $\sigma$ is applied element-wise and so commutes with row permutation: $\sigma(P M) = P \sigma(M)$ for any matrix $M$.
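The claim is easy to verify numerically; a NumPy check on a small graph (the `gcn_layer` helper is a minimal re-implementation for this test, not a reference implementation):

```python
import numpy as np

def gcn_layer(A, X, W):
    """One GCN layer with the renormalization trick (test helper)."""
    A_t = A + np.eye(len(A))
    d = A_t.sum(axis=1)
    A_hat = A_t / np.sqrt(np.outer(d, d))   # D^{-1/2} (A+I) D^{-1/2}
    return np.maximum(A_hat @ X @ W, 0.0)   # ReLU

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
X = rng.normal(size=(3, 4))
W = rng.normal(size=(4, 2))
P = np.eye(3)[[2, 0, 1]]                    # a permutation matrix
lhs = gcn_layer(P @ A @ P.T, P @ X, W)      # permute, then propagate
rhs = P @ gcn_layer(A, X, W)                # propagate, then permute
```

The two orders of operations agree up to floating-point error, exactly as the proof predicts.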
A.2 Universal Approximation of Injective Multiset Functions
The theoretical foundation of GIN rests on the following characterization:
Theorem (Xu et al. 2019, following Zaheer et al. 2017). Let $\mathcal{X}$ be a countable set. A function $h$ on the space of finite multisets over $\mathcal{X}$ is injective if and only if there exist $f : \mathcal{X} \to \mathbb{R}$ and $\phi$ such that:
$$h(S) = \phi\Big(\sum_{x \in S} f(x)\Big)$$
Proof sketch (sufficiency). Since $\mathcal{X}$ is countable, enumerate it as $x_1, x_2, \dots$. A multiset $S$ over $\mathcal{X}$ is characterized by its multiplicity function $m_S$, where $m_S(x_i)$ is the number of times $x_i$ appears in $S$. Choose $f(x_i) = N^{-i}$ (a unique real value per element). Then:
$$\sum_{x \in S} f(x) = \sum_i m_S(x_i)\, N^{-i}$$
This is an injective mapping from the multiplicity function to $\mathbb{R}$ (by the uniqueness of representations in base $N$, this holds for multiplicities bounded by $N - 1$). Then $\phi$ can be any function that recovers $m_S$ from this sum.
Necessity. If $h$ is injective under sum aggregation, different multisets must map to different sums of $f$-values. The existence of such $f$ and $\phi$ is guaranteed by the countability of $\mathcal{X}$ and the separating power of sums over countable sets.
Consequence for GIN. With a sufficiently expressive network playing the role of $f$ and another playing the role of $\phi$ (deep MLPs, by universal approximation), GIN can represent any injective function on multisets of node features. This gives GIN the maximum discriminative power achievable by any MPNN.
A.3 Dirichlet Energy Decay Rate
Theorem. Let $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ be the GCN propagation matrix with self-loops. For any node feature matrix $H \in \mathbb{R}^{n \times d}$:
$$E(\hat{A} H) \le \lambda^2\, E(H)$$
where $E(H) = \mathrm{tr}(H^\top \tilde{L} H)$ is the Dirichlet energy, $\tilde{L} = I - \hat{A}$ is the normalized Laplacian of the self-loop-augmented graph, and $\lambda = \max_{i \ge 2} |\lambda_i|$ over the eigenvalues $1 = \lambda_1 > \lambda_2 \ge \dots \ge \lambda_n$ of $\hat{A}$.
Proof. Write the eigendecomposition $\hat{A} = U \Lambda U^\top$, where $U$ is orthonormal and $\Lambda = \mathrm{diag}(\lambda_1, \dots, \lambda_n)$. Then $\tilde{L} = U (I - \Lambda) U^\top$, and with spectral components $c_i = u_i^\top H$:
$$E(H) = \sum_i (1 - \lambda_i)\, \lVert c_i \rVert^2$$
Since applying $\hat{A}$ scales the $i$-th spectral component by $\lambda_i$:
$$E(\hat{A} H) = \sum_i \lambda_i^2 (1 - \lambda_i)\, \lVert c_i \rVert^2$$
The $i = 1$ term vanishes in both sums ($1 - \lambda_1 = 0$), and for $i \ge 2$ we have $\lambda_i^2 \le \lambda^2$. Therefore $E(\hat{A} H) \le \lambda^2 E(H)$.
For the GCN with self-loops on a connected graph, the eigenvalues satisfy $\lambda_n > -1$ and $\lambda_2 < 1$ (strictly), so $\lambda < 1$ and $E(\hat{A}^k H) \le \lambda^{2k} E(H) \to 0$ exponentially fast.
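The decay is easy to observe numerically; a NumPy sketch on the path graph $P_5$ (pure propagation with no weights or nonlinearity, matching the setting of the theorem):

```python
import numpy as np

def dirichlet_energy(H, L_sym):
    # E(H) = tr(H^T L H), the smoothness of the signal H on the graph.
    return float(np.trace(H.T @ L_sym @ H))

# Path graph P_5 with self-loops and symmetric normalization.
n = 5
A = np.diag(np.ones(n - 1), 1)
A = A + A.T
A_t = A + np.eye(n)
d = A_t.sum(axis=1)
A_hat = A_t / np.sqrt(np.outer(d, d))
L_sym = np.eye(n) - A_hat

rng = np.random.default_rng(0)
H = rng.normal(size=(n, 3))
energies = []
for _ in range(6):
    energies.append(dirichlet_energy(H, L_sym))
    H = A_hat @ H                       # one smoothing step
```

The recorded energies decrease monotonically toward zero, which is exactly the over-smoothing behavior the theorem quantifies.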
A.4 Jacobian Bound for Over-Squashing
Theorem (Alon & Yahav, 2021). For a GNN with $L$ layers, bounded weight matrices, and Lipschitz activation $\sigma$:
$$\Big\lVert \frac{\partial h_v^{(L)}}{\partial h_u^{(0)}} \Big\rVert \le C\, (\hat{A}^L)_{vu}$$
where $C$ is a constant depending on the architecture.
Key observation. The entry $(\hat{A}^L)_{vu}$ counts the weighted number of walks of length $L$ from $u$ to $v$. For nodes $u$ and $v$ separated by a bottleneck edge (an edge with large removal betweenness centrality), all walks from $u$ to $v$ must pass through that edge. The bottleneck effect:
If the bottleneck node $w$ has degree $d_w$, then at most a fraction $\approx 1/d_w$ of the normalized walk weight that reaches $w$ proceeds on toward $v$. Thus $(\hat{A}^L)_{vu}$ picks up a factor of order $1/d_w$ at each bottleneck crossing, and becomes vanishingly small when $d_w$ is large.
Practical consequence. Nodes separated by a high-degree hub (common in scale-free social networks) have near-zero Jacobian - the GNN cannot propagate useful gradient signal between them regardless of depth.
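The $1/d_w$ bottleneck factor can be computed exactly on a star graph: with hub degree $k$, the two-step influence between two leaves under the renormalized propagation is $1/(2(k+1))$, shrinking as the hub grows. A small NumPy check (helper functions are illustrative, not from the paper):

```python
import numpy as np

def a_hat(A):
    # Renormalized propagation matrix D^{-1/2} (A + I) D^{-1/2}.
    A_t = A + np.eye(len(A))
    d = A_t.sum(axis=1)
    return A_t / np.sqrt(np.outer(d, d))

def leaf_to_leaf(k):
    """Two-step influence (A_hat^2)_{uv} between two leaves of a star
    whose hub has degree k: every walk must cross the hub."""
    A = np.zeros((k + 1, k + 1))
    A[0, 1:] = 1.0                      # hub is node 0
    A[1:, 0] = 1.0
    P2 = np.linalg.matrix_power(a_hat(A), 2)
    return P2[1, 2]
```

Quadrupling the hub degree shrinks the leaf-to-leaf influence accordingly, illustrating why high-degree hubs squash gradient signal.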
A.5 Graph Isomorphism Network: Full Algorithm
Algorithm: GIN for Graph Classification
Input: graph $G = (V, E)$ with node features $x_v$. Output: graph embedding $h_G$ and prediction $\hat{y}$.
1. Initialize: $h_v^{(0)} = x_v$ for all $v \in V$
2. For $l = 1, \dots, L$:
   - For each node $v$: $h_v^{(l)} = \mathrm{MLP}^{(l)}\big((1 + \epsilon^{(l)})\, h_v^{(l-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(l-1)}\big)$
   - Apply batch normalization to $h_v^{(l)}$
3. Readout: for each layer $l$, compute the layer-specific graph embedding $h_G^{(l)} = \sum_{v \in V} h_v^{(l)}$
4. Concatenate: $h_G = \mathrm{CONCAT}\big(h_G^{(0)}, h_G^{(1)}, \dots, h_G^{(L)}\big)$
5. Apply MLP classifier: $\hat{y} = \mathrm{MLP}_{\mathrm{out}}(h_G)$
Why concatenate all layers? Each layer captures patterns at a different structural scale: layer 0 is node features; layer 1 is immediate neighborhood; layer $l$ is the $l$-hop subgraph structure. Concatenating all layers allows the classifier to use patterns at all scales simultaneously, matching the JK-Net (Jumping Knowledge Networks) readout.
Choice of $\epsilon$. In practice, $\epsilon = 0$ works well - the MLP is expressive enough to learn the correct weighting. Learned $\epsilon$ adds one scalar parameter per layer and sometimes improves performance on sparse graphs where the self-feature has a very different scale from the aggregated neighborhood.
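The algorithm above can be condensed into a short NumPy sketch (dense adjacency; the per-layer 2-layer MLPs are supplied by the caller as weight pairs; batch normalization is omitted for brevity):

```python
import numpy as np

def gin_graph_embedding(A, X, mlps, eps=0.0):
    """GIN with sum aggregation and all-layer readout (sketch). `mlps`
    is a list of (W1, W2) pairs, one 2-layer MLP per GIN layer."""
    H = X
    readouts = [H.sum(axis=0)]              # layer-0 graph embedding
    for W1, W2 in mlps:
        H = (1.0 + eps) * H + A @ H         # self + summed neighbors
        H = np.maximum(H @ W1, 0.0) @ W2    # 2-layer MLP (ReLU)
        readouts.append(H.sum(axis=0))
    return np.concatenate(readouts)         # CONCAT over all layers
```

On a triangle with all-ones features and identity MLPs, the layer-0 readout is the feature sum and the layer-1 readout triples it, so the concatenated embedding exposes both scales at once.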
A.6 GPS Layer: Formal Specification
The GPS (General, Powerful, Scalable) Graph Transformer layer (Rampasek et al., 2022) is specified as follows.
Input: node embeddings $X \in \mathbb{R}^{n \times d}$, adjacency $A$, positional encodings $p_v$ (concatenated into the input features).
MPNN sub-layer. Aggregate over graph neighbors: $X_M = \mathrm{MPNN}(X, A)$.
The MPNN can be any of: GCN, GINE, GAT, GATv2. GINE (GIN with Edge features) is: $x_v' = \mathrm{MLP}\big((1 + \epsilon)\, x_v + \sum_{u \in \mathcal{N}(v)} \mathrm{ReLU}(x_u + e_{uv})\big)$, where $e_{uv}$ is the embedded edge feature.
Transformer sub-layer. Standard multi-head self-attention over all nodes, ignoring the graph: $X_T = \mathrm{MHSA}(X)$.
Feed-forward + LayerNorm:
$$X' = \mathrm{LayerNorm}(X + X_M + X_T), \qquad X'' = \mathrm{LayerNorm}\big(X' + \mathrm{FFN}(X')\big)$$
where $\mathrm{FFN}$ is a 2-layer MLP.
Complexity: $O(m)$ for the MPNN sub-layer (sparse), $O(n^2)$ for the Transformer sub-layer (dense). Total: $O(m + n^2)$ per layer. For sparse graphs, where $m \ll n^2$, the Transformer term dominates.
Practical approximation. For large graphs, replace the full Transformer with a linear attention approximation (Performer, Longformer-style local+global attention, or Nystromformer). This reduces the Transformer cost to $O(n)$, making GPS scalable to large graphs while retaining the long-range attention benefit.
Appendix B: Implementation Notes
B.1 GCN in Matrix Form: Step-by-Step
Given a graph with adjacency matrix $A$, node features $X \in \mathbb{R}^{n \times d}$, and weight matrices $W^{(0)}$, $W^{(1)}$:
Step 1: Compute A_tilde = A + I_n
Step 2: Compute D_tilde = diag(A_tilde 1) = diag of row sums
Step 3: Compute D_tilde^{-1/2}: take 1/sqrt of the diagonal entries
Step 4: Compute A_hat = D_tilde^{-1/2} A_tilde D_tilde^{-1/2} (sparse: O(m) entries)
Step 5: H^1 = ReLU(A_hat X W^0) - shape: (n, d_1)
Step 6: H^2 = softmax(A_hat H^1 W^1) - shape: (n, d_out) [row-wise, for classification]
Step 7: Loss = -sum_{v in V_L} sum_c Y_{vc} log H^2_{vc} (cross-entropy over the labeled set V_L)
Step 8: Backpropagate through H^2, H^1, W^1, W^0
Sparsity. Real-world $\hat{A}$ is extremely sparse: a graph with average degree 10 has only a fraction $10/n$ of its adjacency entries nonzero. All operations involving $\hat{A}$ should use sparse matrix formats (CSR, COO). The product $\hat{A} X$ for sparse $\hat{A}$ and dense $X$ costs $O(m d)$, not $O(n^2 d)$.
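The eight steps translate directly into SciPy sparse code; a sketch of a two-layer forward pass (row-wise softmax, no training loop - the `gcn_forward` helper is illustrative):

```python
import numpy as np
from scipy import sparse

def gcn_forward(A, X, W0, W1):
    """Two-layer GCN forward pass following steps 1-6 (sketch), with
    CSR matrices so each propagation costs O(m d), not O(n^2 d)."""
    n = A.shape[0]
    A_t = A + sparse.identity(n, format="csr")       # step 1
    d = np.asarray(A_t.sum(axis=1)).ravel()          # step 2
    D_inv_sqrt = sparse.diags(1.0 / np.sqrt(d))      # step 3
    A_hat = D_inv_sqrt @ A_t @ D_inv_sqrt            # step 4
    H1 = np.maximum(A_hat @ X @ W0, 0.0)             # step 5 (ReLU)
    Z = A_hat @ H1 @ W1                              # step 6: logits
    e = np.exp(Z - Z.max(axis=1, keepdims=True))     # row-wise softmax
    return e / e.sum(axis=1, keepdims=True)
```

Each row of the output is a probability distribution over classes, ready for the cross-entropy loss of step 7.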
B.2 Edge List Representation and Scatter Operations
In practice, GNNs are implemented using an edge list representation rather than adjacency matrices. The graph is stored as two arrays:
edge_index: shape (2, m) - edge_index[0] = source nodes, edge_index[1] = target nodes
edge_attr: shape (m, d_e) - edge features
x: shape (n, d_v) - node features
The message-passing computation is implemented as:
# Gather source node features for all edges
messages = M(x[edge_index[0]], edge_attr)   # shape: (m, d_out)
# Scatter messages to target nodes (sum aggregation)
agg = torch.zeros(n, d_out).scatter_add(
    0,
    edge_index[1].unsqueeze(-1).expand(-1, d_out),
    messages)                               # shape: (n, d_out)
This scatter_add operation (implemented in PyTorch Geometric and DGL) is the fundamental primitive of GPU-accelerated GNN training. It is equivalent to a sparse matrix-vector product but more flexible (supports arbitrary message functions).
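The scatter-sum primitive has a direct NumPy analogue, `np.add.at`, which makes it easy to check against the equivalent dense product $A^\top X$ (the toy graph below is hypothetical):

```python
import numpy as np

n, d = 4, 3
edge_index = np.array([[0, 1, 2, 2],     # source nodes
                       [1, 2, 3, 3]])    # targets (node 3 appears twice)
x = np.arange(n * d, dtype=float).reshape(n, d)

messages = x[edge_index[0]]              # gather: identity message fn
agg = np.zeros((n, d))
np.add.at(agg, edge_index[1], messages)  # scatter-sum into targets

# Dense equivalent: agg == A^T x, where A counts (parallel) edges.
A = np.zeros((n, n))
np.add.at(A, (edge_index[0], edge_index[1]), 1.0)
```

Repeated target indices accumulate (node 3 receives node 2's message twice), which is precisely the sum-aggregation semantics of `scatter_add`.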
B.3 Mini-Batch Neighbor Sampling Implementation
GraphSAGE Training Step:
========================
1. Sample batch of target nodes: B \subseteq V, |B| = B_size
2. For l = L, L-1, ..., 1:
- For each node v in current frontier:
- Sample S_l neighbors: N_l(v) \subseteq N(v), |N_l(v)| = S_l
- Current frontier = B \cup 1-hop(B) \cup ... \cup l-hop(B) [sampled]
3. Extract induced subgraph of all frontier nodes
4. Run GNN on induced subgraph (full-batch over small subgraph)
5. Compute loss only on target nodes B
6. Backpropagate
Total nodes per training step: at most $B\,(1 + S_1 + S_1 S_2 + \dots)$. For example, with $B = 512$, $S_1 = 25$, $S_2 = 10$: at most $512 \times (1 + 25 + 250) \approx 1.4 \times 10^5$ nodes per step, independent of the size of the full graph.
B.4 Practical GNN Hyperparameter Guide
| Hyperparameter | Typical Range | Notes |
|---|---|---|
| Number of layers | 2-4 | Start with 2; increase only if task requires long-range |
| Hidden dimension | 64-512 | Match to dataset size; 256 is a strong default |
| Dropout rate | 0.0-0.5 | 0.3-0.5 on feature/edge dropout helps regularize |
| Aggregation | Sum (GIN), Mean (GCN) | Sum for classification; mean for regression |
| Activation | ReLU, PReLU | PReLU slightly better on sparse graphs |
| Batch normalization | After each layer | Essential for deep GNNs (4+ layers) |
| Learning rate | 0.001-0.01 | Adam optimizer; cosine decay schedule |
| Neighbor sample sizes | e.g. [25, 10] (2 layers) or [15, 10, 5] (3 layers) | Larger improves accuracy; diminishing returns |
| Weight decay | 1e-5 - 5e-4 | L2 regularization on weight matrices |
| Number of attention heads | 4-8 | For GAT; with $h$ heads of $d/h$ features each, the concatenated dimension stays $d$ |