
Graph Neural Networks, Part 3: Training and Scaling GNNs through Appendix B: Implementation Notes

11. Training and Scaling GNNs

11.1 Full-Batch vs Mini-Batch Training

Full-batch training. Compute the GNN on the entire graph in each step. Requires the entire graph and all node features to fit in GPU memory. Practical for graphs with $n \leq 10^5$ nodes. Exact gradient computation.

Mini-batch training. Sample a set of target nodes (a "mini-batch") and compute the GNN only on those nodes and their $k$-hop neighborhoods. Much more memory-efficient, but requires careful sampling to produce unbiased gradient estimates.

The challenge of mini-batch training for GNNs is the neighborhood explosion: if each node has average degree $\bar{d}$, a $k$-hop neighborhood has roughly $\bar{d}^k$ nodes. For $\bar{d}=10$, $k=3$, this is 1000 nodes per target node - often larger than the desired mini-batch size.

Three approaches manage this explosion: neighbor sampling (11.2), graph partitioning (11.3), and subgraph sampling (11.4).

11.2 Neighbor Sampling

GraphSAGE neighbor sampling (Hamilton et al., 2017). For each target node at each layer, sample a fixed-size subset of $S_l$ neighbors:

  • Layer 1 (nearest neighbors): sample $S_1 = 25$ neighbors
  • Layer 2 (2-hop): sample $S_2 = 10$ neighbors of each sampled layer-1 neighbor

Total nodes per target: $S_1 \times S_2 = 250$. Constant cost per target node, regardless of graph degree.
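A minimal sketch of this layered fan-out in plain Python/NumPy (the adjacency-list dict and the function name build_batch are illustrative, not the reference GraphSAGE implementation):

import numpy as np

def build_batch(adj, targets, fanouts=(25, 10), seed=0):
    # adj: dict mapping node -> list of neighbors
    # fanouts: (S_1, S_2, ...) neighbors sampled per node at each layer
    rng = np.random.default_rng(seed)
    frontier = list(targets)
    layers = [frontier]
    for s in fanouts:
        nxt = []
        for v in frontier:
            # fixed fan-out, sampled with replacement (GraphSAGE-style)
            nxt.extend(rng.choice(adj[v], size=s, replace=True))
        layers.append(nxt)
        frontier = nxt
    return layers  # total work ~ |targets| * S_1 * S_2, independent of true degrees

# Example: 5-cycle, one target node, fan-outs (3, 2) -> layer sizes [1, 3, 6]
adj = {i: [(i - 1) % 5, (i + 1) % 5] for i in range(5)}
print([len(layer) for layer in build_batch(adj, [0], fanouts=(3, 2))])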

Variance reduction. Random sampling introduces variance in the gradient estimate. This can be reduced by importance sampling (weight each sampled neighbor by $d_v / S_l$) or by using multiple samples. In practice, $S_l = 5$ to $25$ is sufficient for most tasks.

VR-GCN (Chen et al., 2018). Control-variates method: maintain historical node embeddings $\bar{\mathbf{h}}_v^{[l]}$ as running averages. Use the historical embedding as a control variate for unsampled neighbors:

$$\hat{\mathbf{m}}_v = \frac{n}{|S|}\sum_{u \in S} \left(\mathbf{h}_u - \bar{\mathbf{h}}_u\right) + \sum_{u \in \mathcal{N}(v)} \bar{\mathbf{h}}_u$$

This gives an unbiased estimate of the full-batch gradient with lower variance than simple neighbor sampling.

11.3 Cluster-GCN

Cluster-GCN (Chiang et al., 2019) takes a graph-partitioning approach: partition the graph into $C$ balanced clusters $\mathcal{V}_1, \ldots, \mathcal{V}_C$ using METIS or another graph-partitioning algorithm. A mini-batch consists of one or more clusters.

Key insight. Within a cluster, most edges are intra-cluster (partitioning algorithms minimize the cut). Running a GCN on the induced subgraph of a cluster produces embeddings that see most of the relevant edges. Between-cluster edges (the cut) are ignored - this introduces approximation error, but the cut is small by design.

Memory: a cluster of size $\lvert\mathcal{V}_c\rvert = n/C$ requires $O(n/C)$ node features plus the induced adjacency. For $C=100$, memory is $100\times$ smaller than full-batch training.

Variance: the gradient approximation error comes from ignoring cross-cluster edges. To reduce this, Cluster-GCN uses multiple-cluster sampling: each mini-batch consists of $q$ randomly selected clusters $\bigcup_{i=1}^q \mathcal{V}_{c_i}$, retaining all edges within the union. This recovers some cross-cluster edges.
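A dense-NumPy sketch of the multiple-cluster batch construction (a real implementation would use METIS partitions and sparse matrices; the cluster assignment here is synthetic):

import numpy as np

def cluster_minibatch(A, clusters, q, rng):
    # Sample q clusters and induce the subgraph on their union; edges
    # between the sampled clusters are retained, all other cut edges dropped.
    chosen = rng.choice(len(clusters), size=q, replace=False)
    nodes = np.concatenate([clusters[c] for c in chosen])
    return nodes, A[np.ix_(nodes, nodes)]   # induced adjacency of the union

rng = np.random.default_rng(0)
A = (rng.random((12, 12)) < 0.3).astype(float)
A = np.triu(A, 1); A = A + A.T                   # symmetric adjacency
clusters = [np.arange(0, 4), np.arange(4, 8), np.arange(8, 12)]
nodes, A_batch = cluster_minibatch(A, clusters, q=2, rng=rng)
print(nodes.shape, A_batch.shape)                # (8,) (8, 8)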

11.4 GraphSAINT: Subgraph Sampling

GraphSAINT (Zeng et al., 2020) frames mini-batch training as subgraph sampling: each mini-batch is a subgraph $G[S]$ induced by a sampled node set $S \subseteq V$.

Three sampling strategies:

  1. Node sampler: sample $|S|$ nodes uniformly; include all edges between sampled nodes. Simple, but misses important edges.
  2. Edge sampler: sample $|S_E|$ edges uniformly; include all nodes incident to sampled edges. Better edge coverage.
  3. Random walk sampler: start $r$ random walks from random nodes; include all visited nodes (and their induced edges). Produces connected subgraphs with richer local structure.

Normalization. Since different nodes and edges are sampled with different probabilities, the gradient is biased without correction. GraphSAINT computes the sampling probability $p(v)$ for each node and $p(u,v)$ for each edge analytically (or by estimation), then reweights the loss:

$$\mathcal{L}(G[S]) = \frac{1}{\lvert S \rvert} \sum_{v \in S} \frac{1}{n \cdot p(v)} \ell(\hat{y}_v, y_v)$$

This gives an unbiased gradient estimate and removes systematic bias from subgraph sampling.
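A quick Monte Carlo check of this reweighting, assuming a degree-biased node sampler with replacement (the per-node losses and probabilities below are synthetic):

import numpy as np

rng = np.random.default_rng(0)
n = 100
losses = rng.random(n)                    # per-node losses l(y_hat_v, y_v)
deg = rng.integers(1, 20, size=n)
p = deg / deg.sum()                       # node inclusion probabilities p(v)

full = losses.mean()                      # full-graph objective (1/n) sum_v l_v
est = []
for _ in range(20000):
    S = rng.choice(n, size=10, replace=True, p=p)       # node sampler
    est.append((losses[S] / (n * p[S])).mean())         # reweighted mini-batch loss
print(round(full, 4), round(float(np.mean(est)), 4))    # agree up to MC error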

Scaling. GraphSAINT successfully trains GNNs on Flickr (89K nodes, 900K edges), Reddit (232K nodes, 114M edges), and Yelp (716K nodes, 13M edges) - orders of magnitude larger than full-batch training can handle.


12. Applications in Machine Learning

12.1 Molecular Property Prediction and Drug Discovery

Molecular property prediction is the "killer application" of GNNs. A molecule is naturally a graph: atoms are nodes (with features: atomic number, hybridization, formal charge, chirality), and bonds are edges (with features: bond type, aromaticity, stereo configuration).

MPNN for quantum chemistry (Gilmer et al., 2017). Applied to the QM9 dataset (134K small organic molecules, 12 quantum chemical properties including dipole moment, polarizability, and HOMO-LUMO gap). The MPNN with edge-network message functions achieves state-of-the-art on 11 of 12 properties, demonstrating that learned graph representations outperform hand-engineered molecular descriptors.

SchNet (Schütt et al., 2017). A rotationally invariant GNN for 3D molecular geometry. Nodes are atoms; edges connect atoms within a cutoff radius $r_c$, with edge features encoding pairwise distances. Uses continuous-filter convolutional layers: $\mathbf{m}_{uv} = \mathbf{h}_u \odot f_\text{filter}(\lVert\mathbf{r}_u - \mathbf{r}_v\rVert)$, where $f_\text{filter}$ is an MLP applied to the interatomic distance. Achieves chemical accuracy on QM9 for most properties.
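A toy sketch of one continuous-filter message under these definitions, assuming a Gaussian RBF expansion of the distance and a small two-layer filter network (SchNet itself uses shifted-softplus activations; tanh here is purely illustrative):

import numpy as np

def cf_message(h_u, r_u, r_v, W1, W2, centers, gamma=10.0):
    # Expand the interatomic distance in Gaussian radial basis functions,
    # pass through a small MLP, and gate the sender's features elementwise.
    d = np.linalg.norm(r_u - r_v)
    rbf = np.exp(-gamma * (d - centers) ** 2)
    filt = np.tanh(rbf @ W1) @ W2          # f_filter(||r_u - r_v||)
    return h_u * filt                      # m_uv = h_u (.) f_filter(...)

rng = np.random.default_rng(0)
d_feat, n_rbf = 8, 16
centers = np.linspace(0.0, 5.0, n_rbf)    # RBF centers spanning the cutoff r_c
W1, W2 = rng.normal(size=(n_rbf, 32)), rng.normal(size=(32, d_feat))
m = cf_message(rng.normal(size=d_feat), rng.normal(size=3),
               rng.normal(size=3), W1, W2, centers)
print(m.shape)                            # (8,)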

AlphaFold 2 (Jumper et al., 2021). The protein structure prediction breakthrough. Key GNN component: the Evoformer, which processes pairs of residues (edge features) and individual residues (node features) through an iterative process. The structure module then uses equivariant graph operations to predict 3D residue coordinates. AlphaFold 2 won the CASP14 protein structure prediction competition with $>90\%$ accuracy on most targets - a problem open for 50 years.

Drug discovery pipeline. Modern computational drug discovery uses GNNs at every stage:

  • Virtual screening: predict binding affinity between drug candidates (small molecules) and protein targets; GNNs on molecular graphs reduce wet-lab screening costs by $\sim 100\times$
  • ADMET prediction: predict absorption, distribution, metabolism, excretion, and toxicity properties from molecular structure
  • De novo design: generative GNNs (junction tree VAE, graph diffusion models) design novel molecules with desired properties

12.2 Knowledge Graph Reasoning

A knowledge graph (KG) is a directed, heterogeneous graph $G = (V, E, R)$ where $V$ is a set of entities, $R$ is a set of relation types, and $E \subseteq V \times R \times V$ is a set of triples (subject, relation, object): e.g., (Einstein, bornIn, Germany), (Germany, locatedIn, Europe).

Triple completion (link prediction). Given an incomplete KG, predict missing triples: is (Einstein, worksWith, Bohr) a true triple? Key methods:

  • TransE (Bordes et al., 2013): model relation $r$ as a translation in embedding space: $\mathbf{h}_s + \mathbf{r} \approx \mathbf{h}_o$ for true triples $(s,r,o)$. Simple and effective for 1-to-1 relations.
  • RotatE (Sun et al., 2019): model relation as a rotation in complex space: $\mathbf{h}_s \circ \mathbf{r} = \mathbf{h}_o$, where $\circ$ is element-wise complex multiplication. Captures symmetry, antisymmetry, and composition patterns.

R-GCN (Schlichtkrull et al., 2018). A GNN for relational (heterogeneous) graphs, with a separate weight matrix $W_r$ for each relation type:

$$\mathbf{h}_v^{[l+1]} = \sigma\!\left(W_0 \mathbf{h}_v^{[l]} + \sum_{r \in R} \sum_{u \in \mathcal{N}_r(v)} \frac{1}{c_{v,r}} W_r \mathbf{h}_u^{[l]}\right)$$

where $\mathcal{N}_r(v)$ are the neighbors of $v$ through relation $r$ and $c_{v,r}$ is a normalization constant (typically $|\mathcal{N}_r(v)|$). Used for entity classification and link prediction in the Freebase and AIFB knowledge graphs.
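A dense-matrix sketch of one R-GCN layer as defined above (dense relation adjacencies and $c_{v,r} = |\mathcal{N}_r(v)|$ for illustration; production implementations use sparse ops and basis-decomposed $W_r$):

import numpy as np

def rgcn_layer(H, A_r, W0, W_r):
    # H: (n, d) features; A_r[r]: (n, n) adjacency of relation r;
    # W0: self-loop weight; W_r[r]: per-relation weight matrix.
    out = H @ W0
    for r, A in A_r.items():
        c = np.maximum(A.sum(axis=1, keepdims=True), 1.0)   # c_{v,r} = |N_r(v)|
        out += (A @ (H @ W_r[r])) / c                       # normalized relation messages
    return np.maximum(out, 0.0)                             # sigma = ReLU

rng = np.random.default_rng(0)
n, d = 6, 4
H = rng.normal(size=(n, d))
A_r = {r: (rng.random((n, n)) < 0.3).astype(float) for r in range(2)}
W0 = rng.normal(size=(d, d))
W_r = {r: rng.normal(size=(d, d)) for r in range(2)}
print(rgcn_layer(H, A_r, W0, W_r).shape)                    # (6, 4)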

For AI: In 2024, knowledge graphs power retrieval-augmented generation systems. Microsoft's Graph RAG (Edge et al., 2024) builds a knowledge graph from documents, then uses graph traversal and community summaries to answer complex multi-hop questions that flat vector retrieval cannot handle.

12.3 Recommendation Systems

Recommendation is a bipartite graph problem: users $U$ and items $I$ form the nodes; interactions (clicks, purchases, ratings) form the edges. The task is link prediction: predict which unobserved $(u, i)$ pairs represent true preferences.

LightGCN (He et al., 2020). Simplifies GCN for collaborative filtering by removing the feature transformation and nonlinearity, keeping only the graph smoothing:

$$\mathbf{h}_u^{[l+1]} = \sum_{i \in \mathcal{N}_u} \frac{1}{\sqrt{d_u d_i}} \mathbf{h}_i^{[l]}, \qquad \mathbf{h}_i^{[l+1]} = \sum_{u \in \mathcal{N}_i} \frac{1}{\sqrt{d_i d_u}} \mathbf{h}_u^{[l]}$$

Final embeddings: $\mathbf{h}_v = \sum_{l=0}^L \alpha_l \mathbf{h}_v^{[l]}$ (a JK-style layer combination). Prediction: $\hat{y}_{ui} = \mathbf{h}_u^\top \mathbf{h}_i$.

LightGCN outperforms standard GCN for collaborative filtering, suggesting that the feature transformation is unnecessary (or harmful) when the only input feature is an ID embedding. The key contribution is the propagation of user-item signals through multi-hop paths.
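A compact NumPy sketch of this propagation on a synthetic interaction matrix (uniform $\alpha_l = 1/(L+1)$, as in the paper; all shapes are illustrative):

import numpy as np

def lightgcn_embeddings(R, E0, L=3):
    # R: (n_users, n_items) binary interactions; E0: stacked user+item ID embeddings.
    n_u, n_i = R.shape
    A = np.zeros((n_u + n_i, n_u + n_i))
    A[:n_u, n_u:] = R; A[n_u:, :n_u] = R.T     # bipartite adjacency
    d = np.maximum(A.sum(axis=1), 1.0)
    A_hat = A / np.sqrt(np.outer(d, d))        # 1/sqrt(d_u d_i) normalization
    layers = [E0]
    for _ in range(L):
        layers.append(A_hat @ layers[-1])      # no W, no nonlinearity: pure smoothing
    return np.mean(layers, axis=0)             # alpha_l = 1/(L+1)

rng = np.random.default_rng(0)
R = (rng.random((8, 12)) < 0.2).astype(float)
E = lightgcn_embeddings(R, rng.normal(size=(20, 16)))
scores = E[:8] @ E[8:].T                       # y_hat_ui = h_u . h_i
print(scores.shape)                            # (8, 12)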

PinSage (already discussed in 5.5) extends GraphSAGE to the bipartite user-pin graph with importance-based sampling and hard negative mining.

12.4 Code and Program Analysis

Source code can be represented as multiple graphs simultaneously:

  • AST (Abstract Syntax Tree): hierarchical tree showing syntactic structure
  • CFG (Control Flow Graph): nodes are basic blocks; edges show execution flow (if/else, loops)
  • DFG (Data Flow Graph): edges connect variable definitions to their uses
  • Call Graph: nodes are functions; edges connect callers to callees

code2vec (Alon et al., 2019). Represents code snippets as bags of AST paths; learns path embeddings and aggregates them for downstream tasks. While not a full GNN, it demonstrates the power of structural code representations.

GNNs for bug detection. Allamanis et al. (2018) build heterogeneous graphs from C# code with AST, data flow, and control flow edges; they train a GNN to classify whether a variable name is misused. This achieves high precision on detecting certain classes of bugs.

Program synthesis. GNNs operating on execution traces (program state as graph) guide neural program search. DeepCoder (Balog et al., 2017) and more recent systems use graph representations of input-output examples to synthesize programs.

12.5 LLM Integration: Graph RAG and Structure-Aware Language Models

The 2024-2026 frontier is combining GNNs with large language models - exploiting the complementary strengths of relational structure (GNNs) and natural language understanding (LLMs).

Graph RAG (Edge et al., 2024, Microsoft Research). Standard RAG (retrieval-augmented generation) retrieves flat text chunks. Graph RAG builds a knowledge graph from documents and uses graph community detection (Leiden algorithm) to generate hierarchical summaries. At query time, relevant communities are retrieved and their summaries are used to ground the LLM response.

Key insight: complex questions requiring synthesis across many documents (e.g., "What are the main themes in this corpus?") are better answered by traversing a knowledge graph than by retrieving isolated text chunks. Graph RAG achieves a $\sim 72\%$ win rate over naive RAG on these "global sensemaking" queries.

GNN+LLM for molecular generation. LLMs can generate SMILES strings (text representations of molecules), but they struggle with structural validity constraints (valence, aromaticity). GNN-based validity checkers or graph-structured decoders can enforce these constraints.

Graph-structured memory for agents. Recent AI agents (2025) maintain working memory as knowledge graphs updated with each observation. A GNN processes the graph state to inform action selection. This enables multi-hop reasoning ("I know A->B->C, so if I see C, I should infer A") that flat token sequences handle poorly.


13. Common Mistakes

| # | Mistake | Why It's Wrong | Fix |
|---|---------|----------------|-----|
| 1 | Using mean aggregation for graph classification when graph size varies | Mean normalizes by node count: two graphs with identical local structure but different sizes get identical embeddings, despite being different | Use sum aggregation (which preserves size information) or pair mean with a size feature |
| 2 | Forgetting self-loops in GCN | Without self-loops ($\tilde{A} = A + I$), a node does not include its own features in aggregation; it only receives neighbors' information, losing its own signal | Always add $I$ to $A$ before normalization; this is the renormalization trick from Kipf & Welling |
| 3 | Adding too many GCN layers and blaming model capacity | Deep GCNs fail due to over-smoothing, not lack of capacity; adding parameters won't help | Use 2-4 layers; add residual connections (GCNII), DropEdge, or PairNorm if more depth is needed |
| 4 | Treating GAT attention weights as feature importance | Attention weights tell you which neighbors were weighted more, not which features were important; this is the "attention is not explanation" problem | Use gradient-based attribution (GradCAM, Integrated Gradients) for feature importance; treat attention as an architectural choice, not an explanation |
| 5 | Using GIN with mean aggregation | GIN's theoretical power (matching 1-WL) requires sum aggregation with an injective MLP; with mean aggregation, GIN loses this guarantee and is no more expressive than GCN | Always use sum aggregation in GIN; verify that the MLP has sufficient depth ($\geq 2$ layers) |
| 6 | Forgetting that 1-WL cannot detect triangles | Any MPNN is bounded by 1-WL, and 1-WL cannot count triangles or cliques; if your task requires triangle detection (e.g., social network analysis), a standard GNN will fail | Add structural features (triangle count, RWSE) or use a higher-order GNN (NGNN, subgraph GNN) |
| 7 | Normalizing node features but not edge features | Unnormalized edge features with large variance can dominate the attention scores in GAT or the messages in an MPNN | Normalize edge features to zero mean, unit variance; use LayerNorm before message computation |
| 8 | Using transductive GCN for inductive tasks | GCN's normalized adjacency is computed on the training graph; new nodes at test time require recomputing the entire adjacency and rerunning the network | Use inductive methods (GraphSAGE, GAT) that learn aggregation functions, not fixed propagation matrices |
| 9 | Applying global pooling before sufficient local aggregation | With only 1 GNN layer before readout, each node only knows its immediate neighbors; graph-level representations lack structural context | Use 3-5 GNN layers before readout; consider hierarchical pooling (DiffPool) for hierarchically structured graphs |
| 10 | Ignoring the over-squashing problem for long-range tasks | If the task requires integrating information from nodes far apart, a shallow MPNN will fail due to over-squashing even without over-smoothing | Use graph rewiring (DIGL, EGP) or graph transformers (GPS, Graphormer) for long-range tasks; measure the effective receptive field |
| 11 | Training a GNN on the test graph for transductive semi-supervised learning | Using the test graph topology during training is allowed (transductive setting), but using test node labels is a data leakage error | Only mask the labels of test nodes; the full graph adjacency is legitimately used during both training and testing in the transductive setting |
| 12 | Confusing graph-level and node-level tasks in the readout | Using node embeddings directly for graph classification (without pooling) produces an embedding for each node, not for the graph; dimensions won't match | Always apply a readout function (sum/mean/attention pooling) after the final GNN layer for graph-level tasks |

14. Exercises

Exercise 1 * - MPNN Implementation

Implement a 2-layer MPNN on a small graph and compute node representations from scratch.

(a) Construct the adjacency matrix $A$ for the graph: nodes $\{0,1,2,3,4\}$, edges $\{(0,1),(1,2),(2,3),(3,4),(0,3)\}$.

(b) Initialize node features $X = I_5$ (identity matrix - each node has a one-hot representation).

(c) Implement one GCN layer: $H^{[1]} = \operatorname{ReLU}(\hat{A} X W^{[0]})$ where $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$, $\tilde{A} = A + I$, and $W^{[0]} \in \mathbb{R}^{5 \times 3}$ is a random matrix (fixed seed).

(d) Apply a second GCN layer to get $H^{[2]} \in \mathbb{R}^{5 \times 2}$.

(e) Verify that permuting the nodes (applying a permutation $P$ to the rows of $X$ and to both the rows and columns of $A$) produces permuted outputs $P H^{[2]}$.


Exercise 2 * - GCN Propagation Matrix

Analyze the spectral properties of the GCN propagation matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$.

(a) For a path graph $P_5$ (5 nodes in a line), compute $\hat{A}$ and find its eigenvalues.

(b) Show that all eigenvalues of $\hat{A}$ lie in $(-1, 1]$ by relating $\hat{A}$ to the normalized Laplacian $L_{\text{sym}} = I - \hat{A}$ and using the PSD property of $L_{\text{sym}}$.

(c) Compute $\hat{A}^k X$ for $k = 1, 2, 4, 8$ and visualize the row norms as a function of $k$. Observe over-smoothing: the rows converge.

(d) Show that the limiting matrix $\hat{A}^\infty X$ has all rows proportional to $\boldsymbol{\pi}$, the stationary distribution. Compute $\boldsymbol{\pi}$ for $P_5$.


Exercise 3 * - Aggregation Expressiveness

Demonstrate that sum aggregation is strictly more expressive than mean for multisets.

(a) Construct two multisets $\mathcal{M}_1 = \{1, 1\}$ and $\mathcal{M}_2 = \{1, 1, 1\}$ of node features (scalar, value 1). Compute sum($\mathcal{M}_1$), mean($\mathcal{M}_1$), sum($\mathcal{M}_2$), mean($\mathcal{M}_2$). Verify that sum distinguishes them but mean does not.

(b) Construct two multisets $\mathcal{M}_3 = \{1, 2\}$ and $\mathcal{M}_4 = \{2\}$. Show that max($\mathcal{M}_3$) = max($\mathcal{M}_4$) = 2, but sum($\mathcal{M}_3$) $\neq$ sum($\mathcal{M}_4$).

(c) For the two graphs $G_1 = K_{1,3}$ (star with 3 leaves) and $G_2 = P_4$ (path on 4 nodes), show that a GCN with mean aggregation assigns identical representations to some nodes, but a GIN with sum aggregation does not.

(d) Implement the WL color refinement algorithm for $G_1$ and $G_2$ with uniform initial colors. How many iterations until stable? Does WL distinguish $G_1$ from $G_2$?


Exercise 4 ** - Weisfeiler-Leman Test

Implement the 1-WL algorithm and apply it to pairs of graphs.

(a) Implement wl_isomorphism_test(G1, G2, max_iter=10): run 1-WL color refinement on both graphs simultaneously; return "MAYBE_ISOMORPHIC" if the final color histograms match, or "NOT_ISOMORPHIC" at the first iteration where they differ.

(b) Test on $G_1 = C_6$ (6-cycle) and $G_2 = K_{3,3}$ (complete bipartite, 6 nodes, 9 edges). What does WL conclude? Are they actually isomorphic?

(c) Construct two non-isomorphic 3-regular graphs on 6 nodes ($K_{3,3}$ and the prism graph $Y_3$). Does 1-WL distinguish them? (Hint: both are 3-regular, so the degree-based first iteration is identical.)

(d) Explain why any MPNN with mean aggregation would fail to distinguish these graphs, while GIN with sum aggregation would succeed (or also fail, if WL itself fails).


Exercise 5 ** - GAT Attention Mechanism

Implement a single GAT attention head from scratch.

(a) For a graph with 5 nodes and edge set $E = \{(0,1),(1,2),(2,3),(3,4),(4,0),(0,2)\}$, compute attention coefficients $\alpha_{uv}$ for all edges.

(b) Use $d=4$, $d'=3$, random $W \in \mathbb{R}^{3 \times 4}$ and $\mathbf{a} \in \mathbb{R}^6$ (fixed seed). Initialize node features $H \in \mathbb{R}^{5 \times 4}$ randomly.

(c) Compute the attention logits $e_{uv} = \operatorname{LeakyReLU}(\mathbf{a}^\top [\mathbf{z}_v \| \mathbf{z}_u])$ (slope 0.2, with $\mathbf{z}_v = W\mathbf{h}_v$) for all edges, then apply softmax over each node's neighborhood.

(d) Show the difference between GAT and GATv2: compute GATv2 attention scores $e_{uv}^{\text{v2}} = \mathbf{a}^\top \operatorname{LeakyReLU}(W[\mathbf{h}_v \| \mathbf{h}_u])$ and verify that the neighbor ranking for node 0 can differ between GAT and GATv2.

(e) Compute the updated node representations $H' = \sigma(\text{attention-weighted aggregation of } H)$.


Exercise 6 ** - Over-Smoothing Dynamics

Quantify over-smoothing as a function of GCN depth.

(a) Construct a Cora-like graph: generate a stochastic block model with 200 nodes, 4 blocks, within-block edge probability $p=0.15$, between-block probability $q=0.01$.

(b) Initialize node features as class one-hot vectors ($H^{[0]} =$ class assignments as a $200 \times 4$ matrix).

(c) Apply the GCN propagation $\hat{A}$ (no weight matrix, no nonlinearity - pure smoothing) for $L = 1, 2, 4, 8, 16, 32$ steps. Compute the Dirichlet energy $E(H^{[L]}) = \operatorname{tr}(H^{[L]\top} L H^{[L]})$ at each depth.

(d) Plot $E(H^{[L]})$ vs $L$ on a log scale. Fit an exponential $E \approx C \lambda^L$ and estimate $\lambda$.

(e) Compare with the theoretical rate $\lambda \approx \lambda_2(L_{\text{rw}})^2$, where $\lambda_2$ is the Fiedler value of the random-walk Laplacian. Compute $\lambda_2$ numerically and compare.


Exercise 7 *** - GIN Implementation and Expressiveness

Implement GIN and compare its expressiveness to GCN on graph-level classification.

(a) Implement a 3-layer GIN with sum aggregation and a 2-layer MLP update: $\mathbf{h}_v^{[l+1]} = \operatorname{MLP}^{[l]}\!\left((1+\varepsilon)\mathbf{h}_v^{[l]} + \sum_{u \in \mathcal{N}(v)}\mathbf{h}_u^{[l]}\right)$ with $\varepsilon = 0$.

(b) Construct two non-isomorphic graphs $G_1 = C_6$ and $G_2 = C_3 \cup C_3$ (as computed in 7.2, WL cannot distinguish them without initial node features). Initialize all nodes with the same feature vector $\mathbf{x} = \mathbf{1} \in \mathbb{R}^d$.

(c) Show that with uniform initial features, neither GIN nor GCN can distinguish $G_1$ and $G_2$ (both produce identical graph-level sum representations). Why?

(d) Now add degree as an initial node feature: $\mathbf{x}_v = [d_v] \in \mathbb{R}^1$. Rerun both GIN and GCN. Does GIN now distinguish the two graphs? Does GCN?

(e) Generate a dataset of 500 random graphs (a mix of Erdős-Rényi and stochastic block models); train GIN and GCN for graph-level binary classification. Report test accuracy. Verify that GIN accuracy $\geq$ GCN accuracy.


Exercise 8 *** - Graph Transformer with Positional Encoding

Implement a simplified graph transformer layer with Laplacian positional encoding.

(a) For a graph with $n=20$ nodes (Barabási-Albert model with $m=2$), compute the normalized Laplacian $L_{\text{sym}}$ and extract the first $k=4$ non-trivial eigenvectors $U_k \in \mathbb{R}^{20 \times 4}$.

(b) Initialize node features $X \in \mathbb{R}^{20 \times 8}$ randomly and augment with LapPE: $X^{\text{aug}} = [X \| U_k] \in \mathbb{R}^{20 \times 12}$.

(c) Implement one layer of scaled dot-product self-attention over all 20 nodes (fully connected, ignoring graph edges): $\operatorname{Attention}(Q, K, V) = \operatorname{softmax}(QK^\top / \sqrt{d_k})V$.

(d) Implement one GCN layer on the same graph. Visualize the attention matrices of both: which nodes attend to which? Does the graph transformer's attention recover some of the graph structure (do nearby nodes attend to each other more)?

(e) Compare: run 3 iterations of graph transformer attention vs 3 layers of GCN on the same random features. Compute the Dirichlet energy after each: which decreases faster? What does this reveal about over-smoothing in graph transformers vs GNNs?


15. Why This Matters for AI (2026 Perspective)

| Concept | AI Impact |
|---------|-----------|
| MPNN framework | The unifying abstraction behind AlphaFold 2 (protein structure), drug screening, and materials-design GNNs. Every new spatial GNN architecture published in 2024-2026 is a special case |
| GCN propagation rule | Directly computable from the sparse adjacency; used in LightGCN powering Netflix/Pinterest/TikTok recommendation at billion-node scale |
| GraphSAGE inductive learning | Enables daily-updated embeddings for dynamic graphs (social networks, e-commerce) without retraining the full model |
| GAT / GATv2 | Learned, sparse attention patterns over structured data; used in molecular GNNs for drug-target interaction prediction (AstraZeneca, Recursion Pharmaceuticals) |
| WL expressiveness theorem | Theoretical bound establishing what structure any MPNN can detect; determines when to add structural features or upgrade to subgraph GNNs; guides architecture search |
| GIN sum aggregation | The foundation for most molecular property prediction models; shows that mainstream GCN deployments are sub-optimal in expressiveness |
| Over-smoothing analysis | Explains why 2-layer GNNs outperform 8-layer GNNs in most production deployments; directly informs depth-tuning decisions |
| Over-squashing and graph rewiring | Motivates graph preprocessing in biochemistry (adding virtual bonds between distant atoms) and knowledge graph reasoning (adding transitive edges) |
| Graph Transformers (GPS, Graphormer) | State-of-the-art on quantum chemistry benchmarks (QM9, OGB-LSC); powering the next generation of molecular AI models |
| LapPE and RWSE | Standard input features for graph foundation models; analogous to positional encodings in LLMs - without them, graph models cannot distinguish relative node positions |
| Graph RAG | Microsoft's production system (Azure AI Search) for document-understanding queries requiring multi-hop reasoning; deployed across Office 365 and Microsoft 365 Copilot |
| Hierarchical pooling (DiffPool, MinCutPool) | Used in drug-candidate scoring for proteins and polymer graphs where hierarchical structure (residues -> domains -> whole protein) is essential |
| Neighbor sampling (GraphSAGE, Cluster-GCN) | The key to training GNNs on billion-node graphs; enables GPU training without the full graph fitting in memory |

16. Conceptual Bridge

Looking Backward: What This Section Builds On

Graph Neural Networks are the synthesis of all earlier mathematics in this curriculum. From Chapter 2 Linear Algebra, we use matrix multiplication (the propagation rule $\hat{A}H$), eigenvectors (the spectral derivation of GCN), and norms (Dirichlet energy). From Chapter 3 Advanced Linear Algebra, we use eigendecomposition (Laplacian positional encodings), PSD theory (the Laplacian is PSD, proven in 11-04), and the spectral theorem (graph Fourier transform). From Chapter 4 Calculus, we use gradients and the chain rule for backpropagation through GNN layers. From Chapter 8 Optimization, we use SGD variants and the mini-batch training strategies of Section 11. From Chapter 11 Graph Theory itself, we build on 11-04's spectral theory (the GCN derivation via Chebyshev polynomials, over-smoothing as diffusion, Laplacian eigenmaps as positional encodings) and 11-03's algorithms (BFS neighborhoods as the receptive field, max-flow and Cheeger's inequality as tools for bottleneck analysis).

Looking Forward: What This Section Enables

Chapter 12 Functional Analysis generalizes the spectral ideas of GNNs to infinite-dimensional Hilbert spaces. The graph Fourier transform is the finite-dimensional analogue of the Fourier transform on $L^2(\mathbb{R})$; the Laplacian eigenvectors are the finite-dimensional analogue of Fourier modes. Functional analysis provides the rigorous framework for this extension, enabling kernel methods on graphs and the theoretical analysis of continuous-limit GNNs (graphons).

Chapter 14 Math for Specific Models will revisit GNN architectures in the context of equivariant neural networks - networks that respect symmetries (rotation, reflection, permutation) by design. Geometric GNNs for 3D molecular data (DimeNet, SE(3)-Transformers, NequIP) extend the MPNN framework with $SO(3)$-equivariant message functions, enabling prediction of orientation-dependent molecular properties.

Chapter 21 Statistical Learning Theory will study GNN generalization bounds: how many training graphs are needed for a GNN to generalize? The WL expressiveness hierarchy (Section 7) directly informs these bounds - more expressive GNNs require more data to generalize.

Chapter 22 Causal Inference uses directed acyclic graphs (DAGs) - a special class of directed graphs studied in 11-01. The graph algorithms of 11-03 (topological sort for causal ordering) and the GNN architectures of 11-05 (for learning on causal graphs) connect these chapters.

GRAPH NEURAL NETWORKS - POSITION IN THE CURRICULUM
========================================================================

  Ch.02-03: Linear Algebra, Eigenvalues, SVD
       |
       v
  Ch.11-04: Spectral Graph Theory
       |    [Laplacian, GCN spectral derivation,
       |     over-smoothing preview, LapPE]
       v
  +==========================================+
  |  Ch.11-05: Graph Neural Networks         |  <- YOU ARE HERE
  |                                          |
  |  MPNN -> GCN -> GraphSAGE -> GAT -> GIN  |
  |  WL Expressiveness -> Over-smoothing     |
  |  Graph Transformers -> GPS/Graphormer    |
  |  Applications: AlphaFold 2, Graph RAG    |
  +==========================================+
       |                       |
       v                       v
  Ch.11-06:             Ch.12: Functional Analysis
  Random Graphs         [Graph limits (graphons),
  [Graph models          kernel methods on graphs,
   for GNN               infinite-dimensional
   benchmarks]           spectral theory]
       |
       v
  Ch.14: Math for Specific Models
  [Equivariant GNNs, SE(3)-Transformers,
   3D molecular AI, geometric deep learning]
       |
       v
  Ch.21: Statistical Learning Theory
  [GNN generalization bounds,
   WL-based complexity, PAC learning on graphs]

========================================================================



Appendix A: Mathematical Derivations

A.1 Proof that the GCN Propagation Rule is Permutation Equivariant

Claim. The GCN layer $H' = \sigma(\hat{A}HW)$ is permutation equivariant: for any permutation matrix $P$,

$$\sigma\!\left(\widehat{PAP^\top} \cdot PH \cdot W\right) = P \cdot \sigma\!\left(\hat{A}HW\right)$$

Proof. Let $\tilde{A}' = PAP^\top + I = P(A+I)P^\top = P\tilde{A}P^\top$. Since $P$ is a permutation matrix, $(P\tilde{A}P^\top)_{ij} = \tilde{A}_{\pi^{-1}(i),\pi^{-1}(j)}$, where $\pi$ is the underlying permutation. The degree matrix of $\tilde{A}'$ therefore permutes the same way:

$$\tilde{D}'_{ii} = \sum_j \tilde{A}_{\pi^{-1}(i),\pi^{-1}(j)} = \tilde{D}_{\pi^{-1}(i),\pi^{-1}(i)} = (P\tilde{D}P^\top)_{ii}$$

So $\tilde{D}' = P\tilde{D}P^\top$, and:

$$\hat{A}' = \tilde{D}'^{-1/2}\tilde{A}'\tilde{D}'^{-1/2} = (P\tilde{D}P^\top)^{-1/2}(P\tilde{A}P^\top)(P\tilde{D}P^\top)^{-1/2} = P\tilde{D}^{-1/2}P^\top \cdot P\tilde{A}P^\top \cdot P\tilde{D}^{-1/2}P^\top = P(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2})P^\top = P\hat{A}P^\top$$

using $(P\tilde{D}P^\top)^{-1/2} = P\tilde{D}^{-1/2}P^\top$ and $P^\top P = I$.

Therefore:

$$\sigma\!\left(\hat{A}'(PH)W\right) = \sigma\!\left(P\hat{A}P^\top P H W\right) = \sigma\!\left(P\hat{A}HW\right) = P\,\sigma\!\left(\hat{A}HW\right) \quad \blacksquare$$

The last step uses the fact that $\sigma$ is applied element-wise and $P$ permutes rows: $\sigma(PZ) = P\sigma(Z)$ for any matrix $Z$ and element-wise $\sigma$.
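The claim is easy to verify numerically; a NumPy check on a random graph and a random permutation (ReLU stands in for $\sigma$):

import numpy as np

def gcn_layer(A, H, W):
    A_t = A + np.eye(len(A))
    d = A_t.sum(axis=1)
    A_hat = A_t / np.sqrt(np.outer(d, d))      # D^{-1/2} (A + I) D^{-1/2}
    return np.maximum(A_hat @ H @ W, 0.0)      # elementwise sigma = ReLU

rng = np.random.default_rng(0)
n, d_in, d_out = 6, 3, 2
A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1); A = A + A.T                 # symmetric, no self-loops
X = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))
P = np.eye(n)[rng.permutation(n)]              # random permutation matrix

print(np.allclose(gcn_layer(P @ A @ P.T, P @ X, W),
                  P @ gcn_layer(A, X, W)))     # True: equivariance holds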

A.2 Universal Approximation of Injective Multiset Functions

The theoretical foundation of GIN rests on the following characterization:

Theorem (Xu et al. 2019, following Zaheer et al. 2017). Let $\mathcal{X}$ be a countable set. A function $f: \mathcal{M}(\mathcal{X}) \to \mathbb{R}^d$ on the space of finite multisets over $\mathcal{X}$ is injective if and only if there exist $\varphi: \mathcal{X} \to \mathbb{R}^d$ and $g: \mathbb{R}^d \to \mathbb{R}^d$ such that:

$$f(\mathcal{M}) = g\!\left(\sum_{x \in \mathcal{M}} \varphi(x)\right)$$

Proof sketch (sufficiency). Assume $\mathcal{X}$ is countable and enumerate it as $\mathcal{X} = \{a_1, a_2, \ldots\}$. A multiset $\mathcal{M}$ over $\mathcal{X}$ is characterized by its multiplicity function $m: \mathcal{X} \to \mathbb{N}$, where $m(a_k)$ is the number of times $a_k$ appears. Choose $\varphi(a_k) = e^{-k}$ (a unique real value per element). Then:

$$\sum_{x \in \mathcal{M}} \varphi(x) = \sum_k m(a_k) \cdot e^{-k}$$

This is an injective mapping from the multiplicity function $m$ to $\mathbb{R}$ (by the uniqueness of representations in base $e$; this holds for bounded multiplicities). Then $g$ can be any function that recovers $f$ from this sum.

Necessity. If $f$ is injective and we use sum aggregation $\sum \varphi(x)$, then different multisets must map to different sums. The existence of $\varphi$ and $g$ satisfying this is guaranteed by injectivity and the separating power of sums over countable sets. $\blacksquare$

Consequence for GIN. With a sufficiently expressive $\varphi$ (a deep MLP, by universal approximation) and $g$ (another deep MLP), GIN can represent any injective function on multisets of node features. This gives GIN the maximum discriminative power achievable by any MPNN.

A.3 Dirichlet Energy Decay Rate

Theorem. Let $S = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ be the GCN propagation matrix with $\tilde{A} = A + I$. For any node feature matrix $H$:

$$E(SH) \leq \lambda_{\max}(S)^2 \cdot E(H)$$

where $E(H) = \operatorname{tr}(H^\top L_{\tilde{A}} H)$ and $L_{\tilde{A}} = I - S$ is the normalized Laplacian of $\tilde{A}$.

Proof. Write the eigendecomposition $S = U\Lambda U^\top$, where $U$ is orthonormal and $\Lambda = \operatorname{diag}(\lambda_1, \ldots, \lambda_n)$ with $|\lambda_i| \leq 1$ ($S$ is similar to the row-stochastic random-walk matrix $\tilde{D}^{-1}\tilde{A}$, so its spectral radius is 1). Then:

$$E(SH) = \operatorname{tr}((SH)^\top L_{\tilde{A}}(SH)) = \operatorname{tr}(H^\top S^\top L_{\tilde{A}} S H)$$

Since $L_{\tilde{A}} = I - S$ and $S$ is symmetric:

$$S^\top L_{\tilde{A}} S = S(I-S)S = S^2 - S^3$$

The eigenvalues of $S^2 - S^3$ are $\lambda_i^2 - \lambda_i^3 = \lambda_i^2(1-\lambda_i)$. Assuming $\lambda_i \in [0,1]$ (self-loop renormalization compresses the spectrum toward non-negative values), we have $\lambda_i^2(1-\lambda_i) \leq \lambda_{\max}^2(1-\lambda_i)$, which compares $S^2 - S^3$ eigenvalue-by-eigenvalue with $\lambda_{\max}^2 L_{\tilde{A}}$, whose eigenvalues are $\lambda_{\max}^2(1-\lambda_i)$.

Therefore $\operatorname{tr}(H^\top S^\top L_{\tilde{A}} S H) \leq \lambda_{\max}^2 \operatorname{tr}(H^\top L_{\tilde{A}} H) = \lambda_{\max}^2 E(H)$. $\blacksquare$

For a connected graph with self-loops, the only eigenvalue equal to 1 corresponds to the eigenvector $\tilde{D}^{1/2}\mathbf{1}$, which carries zero Dirichlet energy; every other eigenvalue satisfies $|\lambda_i| < 1$. Hence $E(H^{[L]}) \leq \lambda_2^{2L}\, E(H^{[0]}) \to 0$ exponentially fast, where $\lambda_2$ is the second-largest eigenvalue of $S$.
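A small numerical illustration of this decay on a random graph (NumPy; the graph and the features are synthetic, and the decay is per connected component):

import numpy as np

rng = np.random.default_rng(0)
n = 30
A = (rng.random((n, n)) < 0.2).astype(float)
A = np.triu(A, 1); A = A + A.T
A_t = A + np.eye(n)
d = A_t.sum(axis=1)
S = A_t / np.sqrt(np.outer(d, d))              # GCN propagation matrix
L_t = np.eye(n) - S                            # normalized Laplacian of A-tilde

H = rng.normal(size=(n, 4))
for k in [0, 1, 2, 4, 8, 16]:
    Hk = np.linalg.matrix_power(S, k) @ H
    print(k, float(np.trace(Hk.T @ L_t @ Hk))) # Dirichlet energy shrinks with k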

A.4 Jacobian Bound for Over-Squashing

Theorem (Alon & Yahav, 2021). For a GNN with $L$ layers, bounded weight matrices $\lVert W^{[l]} \rVert \leq \alpha$, and Lipschitz activation $|\sigma'(x)| \leq \beta$:

$$\left\lVert\frac{\partial \mathbf{h}_v^{[L]}}{\partial \mathbf{x}_u}\right\rVert \leq C \cdot (\alpha\beta)^L \cdot \left(\hat{A}^L\right)_{vu}$$

where $C$ is a constant depending on the architecture.

Key observation. The entry $(\hat{A}^L)_{vu}$ counts the weighted number of walks of length $L$ from $u$ to $v$. For nodes $u$ and $v$ separated by a bottleneck edge (an edge $(s,t)$ with high betweenness, whose removal disconnects most $u$-$v$ routes), all walks from $u$ to $v$ must pass through $(s,t)$. The bottleneck effect:

If node $s$ has degree $d_s$, then at most a $1/d_s$ fraction of the normalized walk weight reaching $s$ proceeds to $t$. Thus $(\hat{A}^L)_{vu} \leq (1/d_s)^{\lceil L / d(u,v) \rceil}$, which decays exponentially in $L$ when $d_s$ is large.

Practical consequence. Nodes separated by a high-degree hub (common in scale-free social networks) have near-zero Jacobian - the GNN cannot propagate useful gradient signal between them regardless of depth.

A.5 Graph Isomorphism Network: Full Algorithm

Algorithm: GIN for Graph Classification

Input: graph $G = (V, E, X)$ with $n$ nodes.
Output: graph embedding $\mathbf{h}_G \in \mathbb{R}^d$.

  1. Initialize: $\mathbf{h}_v^{[0]} = \mathbf{x}_v$ for all $v \in V$

  2. For $l = 1, \ldots, L$:

    • For each node $v$: $\mathbf{h}_v^{[l]} = \operatorname{MLP}^{[l]}\!\left((1 + \varepsilon^{[l]}) \mathbf{h}_v^{[l-1]} + \sum_{u \in \mathcal{N}(v)} \mathbf{h}_u^{[l-1]}\right)$
    • Apply batch normalization to $\{\mathbf{h}_v^{[l]}\}_{v \in V}$
  3. Readout: for each layer $l$, compute the layer-specific graph embedding:

$$\mathbf{h}_G^{[l]} = \sum_{v \in V} \mathbf{h}_v^{[l]}$$

  4. Concatenate: $\mathbf{h}_G = \left[\mathbf{h}_G^{[0]} \,\|\, \mathbf{h}_G^{[1]} \,\|\, \cdots \,\|\, \mathbf{h}_G^{[L]}\right]$

  5. Apply the MLP classifier: $\hat{y} = \operatorname{MLP}_{\text{pred}}(\mathbf{h}_G)$

Why concatenate all layers? Each layer captures patterns at a different structural scale: layer 0 is the raw node features; layer 1 is the immediate neighborhood; layer $l$ is the $l$-hop subgraph structure. Concatenating all layers allows the classifier to use patterns at all scales simultaneously, matching the JK-Net (Jumping Knowledge Networks) readout.

Choice of $\varepsilon$. In practice, $\varepsilon = 0$ works well - the MLP is expressive enough to learn the correct weighting. A learned $\varepsilon$ adds one scalar parameter per layer and sometimes improves performance on sparse graphs, where the self-feature has a very different scale from the aggregated neighborhood.
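A compact dense-adjacency PyTorch sketch of the algorithm above (batch normalization and the prediction MLP are omitted; $\varepsilon$ is fixed at 0 rather than learned):

import torch
import torch.nn as nn

class GINLayer(nn.Module):
    # h_v <- MLP((1 + eps) * h_v + sum_{u in N(v)} h_u)
    def __init__(self, d_in, d_out, eps=0.0):
        super().__init__()
        self.eps = eps
        self.mlp = nn.Sequential(nn.Linear(d_in, d_out), nn.ReLU(),
                                 nn.Linear(d_out, d_out))

    def forward(self, A, H):
        # A: (n, n) adjacency without self-loops; A @ H sums neighbor features
        return self.mlp((1.0 + self.eps) * H + A @ H)

def gin_graph_embedding(A, X, layers):
    hs = [X]
    for layer in layers:
        hs.append(layer(A, hs[-1]))
    # per-layer sum readout, concatenated across layers (steps 3-4)
    return torch.cat([h.sum(dim=0) for h in hs])

A = torch.tensor([[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]])  # 3-node star
X = torch.ones(3, 4)
h_G = gin_graph_embedding(A, X, [GINLayer(4, 8), GINLayer(8, 8)])
print(h_G.shape)   # torch.Size([20]) = 4 + 8 + 8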

A.6 GPS Layer: Formal Specification

The GPS (General, Powerful, Scalable) Graph Transformer layer (Rampasek et al., 2022) is specified as follows.

Input: node embeddings $H^{[l]} \in \mathbb{R}^{n \times d}$, adjacency $A$, positional encodings $\text{PE} \in \mathbb{R}^{n \times p}$.

MPNN sub-layer. For each node $v$:

$$\tilde{\mathbf{h}}_v^{[l]} = \mathbf{h}_v^{[l]} + \operatorname{MPNN}\!\left(\mathbf{h}_v^{[l]}, \left\{\mathbf{h}_u^{[l]} : u \in \mathcal{N}(v)\right\}, A_{:,v}\right)$$

The MPNN can be any of GCN, GINE, GAT, or GATv2. GINE (GIN with edge features) is:

$$\operatorname{GINE}(v) = \operatorname{MLP}\!\left((1+\varepsilon)\mathbf{h}_v + \sum_{u \in \mathcal{N}(v)} \operatorname{ReLU}\!\left(\mathbf{h}_u + \mathbf{e}_{uv}\right)\right)$$

Transformer sub-layer. Standard multi-head self-attention over all $n$ nodes:

$$\hat{H}^{[l]} = \tilde{H}^{[l]} + \operatorname{MultiHeadAttention}\!\left(\tilde{H}^{[l]} W_Q,\; \tilde{H}^{[l]} W_K,\; \tilde{H}^{[l]} W_V\right)$$

Feed-forward + LayerNorm:

$$H^{[l+1]} = \operatorname{LayerNorm}\!\left(\hat{H}^{[l]} + \operatorname{FFN}\!\left(\hat{H}^{[l]}\right)\right)$$

where $\operatorname{FFN}(\mathbf{h}) = W_2 \operatorname{ReLU}(W_1 \mathbf{h} + \mathbf{b}_1) + \mathbf{b}_2$ is a 2-layer MLP.

Complexity: $O(m \cdot d)$ for the MPNN sub-layer (sparse), $O(n^2 \cdot d)$ for the Transformer sub-layer (dense). Total: $O((m + n^2) \cdot d)$. For sparse graphs with $m \ll n^2$, the Transformer term dominates.

Practical approximation. For large graphs ($n > 10^4$), replace the full Transformer with a linear attention approximation (Performer, Longformer-style local+global attention, or Nyströmformer). This reduces the Transformer cost to $O(n \cdot d)$, making GPS scalable to large graphs while retaining the long-range attention benefit.
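A minimal PyTorch sketch of the two sub-layers composed as specified (a GCN-style linear transform stands in for the GINE module, and the residual/norm placement is simplified relative to the paper):

import torch
import torch.nn as nn

class GPSLayer(nn.Module):
    def __init__(self, d, heads=4):
        super().__init__()
        self.local = nn.Linear(d, d)                        # stand-in MPNN transform
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d, 2 * d), nn.ReLU(), nn.Linear(2 * d, d))
        self.norm = nn.LayerNorm(d)

    def forward(self, A_hat, H):
        # MPNN sub-layer: sparse neighborhood aggregation + residual
        H = H + self.local(A_hat @ H)
        # Transformer sub-layer: dense attention over all n nodes + residual
        attn_out, _ = self.attn(H[None], H[None], H[None])
        H = H + attn_out[0]
        # FFN + LayerNorm (residual)
        return self.norm(H + self.ffn(H))

n, d = 10, 16
A_hat = torch.rand(n, n); A_hat = (A_hat + A_hat.T) / 2     # placeholder normalized adjacency
print(GPSLayer(d)(A_hat, torch.randn(n, d)).shape)          # torch.Size([10, 16])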


Appendix B: Implementation Notes

B.1 GCN in Matrix Form: Step-by-Step

Given a graph with adjacency matrix $A \in \mathbb{R}^{n \times n}$, node features $X \in \mathbb{R}^{n \times d_{\text{in}}}$, and weight matrices $W^{[0]} \in \mathbb{R}^{d_{\text{in}} \times d_1}$, $W^{[1]} \in \mathbb{R}^{d_1 \times d_{\text{out}}}$:

Step 1: A_tilde = A + I_n                                 (add self-loops)
Step 2: D_tilde = diag(A_tilde 1)                         (diagonal of row sums)
Step 3: D_tilde^{-1/2}: take 1/sqrt of the diagonal entries
Step 4: A_hat = D_tilde^{-1/2} A_tilde D_tilde^{-1/2}     (sparse: O(m) entries)
Step 5: H^{[1]} = ReLU(A_hat X W^{[0]})                   shape: (n, d_1)
Step 6: H^{[2]} = softmax(A_hat H^{[1]} W^{[1]})          shape: (n, d_out)  [row-wise, for classification]
Step 7: Loss = -sum_{v in V_L} sum_c Y_{vc} log H^{[2]}_{vc}   (cross-entropy over labeled nodes V_L)
Step 8: Backpropagate through H^{[2]}, H^{[1]} to update W^{[1]}, W^{[0]}

Sparsity. The adjacency $A$ of a real-world graph is extremely sparse (average degree 10 with $n=10^6$ nodes means $\sim 10^{-5}$ fill). All operations involving $\hat{A}$ should use sparse matrix formats (CSR, COO). The computation $\hat{A}H$ for sparse $\hat{A}$ and dense $H$ costs $O(m \cdot d)$, not $O(n^2 \cdot d)$.
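The eight steps in dense NumPy for a toy graph (a sketch for clarity only: use scipy.sparse for Steps 1-4 in practice, and leave Steps 7-8 to an autodiff framework):

import numpy as np

def gcn_forward(A, X, W0, W1):
    n = A.shape[0]
    A_t = A + np.eye(n)                              # Step 1
    d_is = 1.0 / np.sqrt(A_t.sum(axis=1))            # Steps 2-3
    A_hat = A_t * np.outer(d_is, d_is)               # Step 4: D^{-1/2} A_tilde D^{-1/2}
    H1 = np.maximum(A_hat @ X @ W0, 0.0)             # Step 5: ReLU
    Z = A_hat @ H1 @ W1                              # Step 6: logits
    Z -= Z.max(axis=1, keepdims=True)                # numerically stable softmax
    P = np.exp(Z)
    return P / P.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, d_in, d1, d_out = 8, 5, 4, 3
A = (rng.random((n, n)) < 0.3).astype(float)
A = np.triu(A, 1); A = A + A.T
probs = gcn_forward(A, rng.normal(size=(n, d_in)),
                    rng.normal(size=(d_in, d1)), rng.normal(size=(d1, d_out)))
print(probs.shape, probs.sum(axis=1).round(6))       # (8, 3), each row sums to 1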

B.2 Edge List Representation and Scatter Operations

In practice, GNNs are implemented using an edge-list representation rather than adjacency matrices. The graph is stored as arrays:

edge_index: shape (2, m) - edge_index[0] = source nodes, edge_index[1] = target nodes
edge_attr:  shape (m, d_e) - edge features
x:          shape (n, d_v) - node features

The message-passing computation $\mathbf{m}_v = \sum_{u \in \mathcal{N}(v)} M(\mathbf{h}_u, \mathbf{e}_{uv})$ is implemented as:

# Gather source-node features for all edges and compute per-edge messages
messages = M(x[edge_index[0]], edge_attr)        # shape: (m, d_out)

# Scatter messages to target nodes (sum aggregation)
agg = torch.zeros(n, d_out).scatter_add(
    0,
    edge_index[1].unsqueeze(-1).expand(-1, d_out),  # row i goes to node edge_index[1][i]
    messages)                                    # result shape: (n, d_out)

This scatter-add operation (implemented in PyTorch Geometric and DGL) is the fundamental primitive of GPU-accelerated GNN training. It is equivalent to a sparse-dense matrix product $\hat{A}H$ but more flexible (it supports arbitrary message functions).

B.3 Mini-Batch Neighbor Sampling Implementation

GraphSAGE Training Step:
========================
1. Sample batch of target nodes: B \subseteq V, |B| = B_size
2. For l = L, L-1, ..., 1:
   - For each node v in current frontier:
     - Sample S_l neighbors: N_l(v) \subseteq N(v), |N_l(v)| = S_l
   - Current frontier = B \cup 1-hop(B) \cup ... \cup l-hop(B) [sampled]
3. Extract induced subgraph of all frontier nodes
4. Run GNN on induced subgraph (full-batch over small subgraph)
5. Compute loss only on target nodes B
6. Backpropagate

Total nodes per training step: $|B| \times \prod_{l=1}^L S_l$. For $|B|=256$, $L=2$, $S_1=S_2=10$: $256 \times 100 = 25{,}600$ nodes per step.

B.4 Practical GNN Hyperparameter Guide

| Hyperparameter | Typical Range | Notes |
|----------------|---------------|-------|
| Number of layers $L$ | 2-4 | Start with 2; increase only if the task requires long-range information |
| Hidden dimension $d$ | 64-512 | Match to dataset size; 256 is a strong default |
| Dropout rate | 0.0-0.5 | 0.3-0.5 on feature/edge dropout helps regularize |
| Aggregation | Sum (GIN), mean (GCN) | Sum for classification; mean for regression |
| Activation | ReLU, PReLU | PReLU slightly better on sparse graphs |
| Batch normalization | After each layer | Essential for deep GNNs (4+ layers) |
| Learning rate | 0.001-0.01 | Adam optimizer; cosine decay schedule |
| Neighbor sample sizes | $(25, 10)$ or $(15, 10, 5)$ | Larger $S_1$ improves accuracy, with diminishing returns |
| Weight decay | $10^{-5}$-$10^{-4}$ | L2 regularization on weight matrices |
| Attention heads $K$ | 4-8 | For GAT; each head uses $d/K$ features |
