Graph Neural Networks: Part 1: Intuition to 5. GraphSAGE: Inductive Learning on Large Graphs

1. Intuition

1.1 The Core Challenge: Learning on Non-Euclidean Data

Standard deep learning architectures - convolutional networks, recurrent networks, transformers - were designed for data that lives on regular, well-ordered domains. An image is a grid: every pixel has exactly the same number of neighbors, arranged in the same spatial pattern. A sentence is a sequence: every token has a left neighbor and a right neighbor, and positions are canonically ordered. These regular structures make convolution and positional encodings natural.

Real-world data is rarely so obliging. A molecule has atoms connected by bonds in a pattern determined by chemistry, not a grid. A social network has users connected by friendships in patterns determined by behavior and geography. A knowledge base has concepts connected by relations in patterns determined by the world's semantic structure. A protein's function is determined by the 3D arrangement of amino acid residues, which a graph can represent but a grid cannot. These are non-Euclidean data structures: they have no natural notion of translation, no canonical coordinate system, no uniform neighborhood size.

The challenge GNNs solve is: how do we define a learnable function on a graph that generalizes across different graphs, respects the graph's structure, and can be trained end-to-end by gradient descent? This is harder than it sounds. The same function applied to two isomorphic graphs (graphs with the same structure but differently labeled vertices) should produce the same output. The function must handle graphs of arbitrary size and arbitrary vertex degrees. And it must be expressive enough to detect structural patterns (triangles, paths of length $k$, specific subgraph motifs) that determine the graph's properties.

For AI: In 2026, graph-based neural models are production components in several systems:

  • AlphaFold 2 frames structure prediction as inference over a graph of amino acid residues, with features on residue pairs.
  • PinSage at Pinterest is trained on a graph of roughly 3 billion pin and board nodes and 18 billion edges.
  • Microsoft's GraphRAG extracts entities and relations from documents into a knowledge graph for retrieval-augmented generation.

Understanding GNNs is not optional for anyone working with structured, relational, or molecular data.

1.2 The Message-Passing Insight

The insight that makes GNNs possible is conceptually simple: a node's representation should depend on the representations of its neighbors, and this dependency should be computed iteratively.

Consider a social network node - a person. Their social identity depends on whom they know. But whom they know also depends on whom those people know, and so on. The first round of "listening to neighbors" gives each person a sense of their immediate social circle. The second round gives a sense of friends-of-friends. After $k$ rounds, each node's representation encodes the structure of its $k$-hop neighborhood.

This is message passing: in each layer, every node sends a message to its neighbors, every node aggregates the messages it receives, and every node updates its representation based on the aggregated messages. Formally, for node $v$ at layer $l$:

$$\mathbf{m}_v^{[l]} = \operatorname{AGGREGATE}^{[l]}\!\left(\left\{\mathbf{h}_u^{[l]} : u \in \mathcal{N}(v)\right\}\right)$$

$$\mathbf{h}_v^{[l+1]} = \operatorname{UPDATE}^{[l]}\!\left(\mathbf{h}_v^{[l]},\, \mathbf{m}_v^{[l]}\right)$$

where $\mathcal{N}(v)$ is the set of neighbors of $v$, $\mathbf{h}_v^{[l]} \in \mathbb{R}^d$ is node $v$'s representation at layer $l$, and $\mathbf{h}_v^{[0]} = \mathbf{x}_v$ is the initial node feature.
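Here is a minimal NumPy sketch of one message-passing layer, as a concrete reference point. It assumes sum as AGGREGATE and a linear map plus ReLU as UPDATE. The function and weight names (`message_passing_layer`, `W_self`, `W_neigh`) are illustrative, not taken from any particular library. Edges are stored as directed (source, target) index pairs, so each undirected edge is listed twice.

```python
import numpy as np

# Sketch: function and weight names are illustrative, not a library API.
def message_passing_layer(h, src, dst, W_self, W_neigh):
    """One message-passing round: sum AGGREGATE, linear + ReLU UPDATE.

    h:        (n, d) node representations h_v^{[l]}
    src, dst: (m,) int arrays; edge k sends h[src[k]] to node dst[k]
    """
    n, d = h.shape
    # AGGREGATE: m_v = sum of h_u over all u with an edge u -> v
    m = np.zeros((n, d))
    np.add.at(m, dst, h[src])
    # UPDATE: combine the node's own state with the aggregated message
    return np.maximum(0.0, h @ W_self + m @ W_neigh)

# Path graph 0 - 1 - 2, each undirected edge listed in both directions
src = np.array([0, 1, 1, 2])
dst = np.array([1, 0, 2, 1])
x = np.eye(3)                          # one-hot features, h^{[0]} = x
h1 = message_passing_layer(x, src, dst, np.eye(3), np.eye(3))
print(h1)
# [[1. 1. 0.]
#  [1. 1. 1.]
#  [0. 1. 1.]]
```

With identity weights and one-hot features, row $v$ of the output marks $v$ and its neighbors. After one layer, each node "sees" exactly its one-hop neighborhood.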

The connection to BFS: after step $k$, breadth-first search from node $v$ has discovered exactly the nodes within $k$ hops of $v$. A $k$-layer GNN computes a representation of $v$ that can depend only on the nodes within $k$ hops. BFS is the non-learned, binary-valued version of message passing: it answers "which nodes are in the $k$-hop neighborhood?" whereas a GNN answers "what is the learned summary of the $k$-hop neighborhood?"
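The BFS correspondence can be checked numerically. The sketch below uses a plain Python BFS plus NumPy, and its helper names are hypothetical. It compares the nodes BFS reaches within $k$ hops against the nonzero pattern of row $v$ of $(A + I)^k$. That pattern is the set of nodes that can influence $\mathbf{h}_v^{[k]}$ when each layer also uses the node's own state (the self-loop). It checks only the support of the receptive field, not the learned values.

```python
from collections import deque

import numpy as np

# Sketch: helper names are illustrative, not a library API.
def k_hop_bfs(adj_list, v, k):
    """Nodes within distance k of v (v included), found by breadth-first search."""
    dist = {v: 0}
    queue = deque([v])
    while queue:
        u = queue.popleft()
        if dist[u] == k:               # do not expand beyond k hops
            continue
        for w in adj_list[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return set(dist)

def receptive_field(A, v, k):
    """Nodes that can influence h_v after k layers that include a self-loop:
    the nonzero pattern of row v of (A + I)^k."""
    reach = np.linalg.matrix_power(A + np.eye(len(A)), k)
    return set(np.nonzero(reach[v])[0].tolist())

# Path graph 0 - 1 - 2 - 3 - 4
n = 5
A = np.zeros((n, n))
adj_list = {i: [] for i in range(n)}
for a, b in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    A[a, b] = A[b, a] = 1
    adj_list[a].append(b)
    adj_list[b].append(a)

for k in range(5):
    assert k_hop_bfs(adj_list, 0, k) == receptive_field(A, 0, k)
print(k_hop_bfs(adj_list, 0, 2))       # {0, 1, 2}
```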

For AI: This message-passing structure is the graph analogue of the convolutional receptive field. A pixel in a CNN "sees" a $(2k+1) \times (2k+1)$ patch after $k$ layers of $3 \times 3$ convolutions. A node in a GNN "sees" its $k$-hop neighborhood after $k$ message-passing layers. The critical difference: the CNN's receptive field grows quadratically in $k$ (for a 2D grid), while the GNN's receptive field can grow exponentially in $k$ (for a well-connected graph). This creates both expressiveness and scalability challenges, which we will study in §8.

1.3 Two Perspectives: Spectral vs Spatial

Graph convolution was first defined in the spectral domain. The key idea from 11-04 is that the graph Laplacian $L = D - A$ admits an eigendecomposition $L = U \Lambda U^\top$, and the columns of $U$ form a graph Fourier basis. A graph signal $\mathbf{x} \in \mathbb{R}^n$ (one scalar per node) is filtered by applying a function to its Fourier coefficients:

$$\hat{\mathbf{x}} = U^\top \mathbf{x}, \qquad \text{filtered: } \mathbf{y} = U \, g(\Lambda) \, \hat{\mathbf{x}}$$

where $g(\Lambda)$ is a diagonal matrix of filter coefficients. Full spectral convolution requires the full eigendecomposition, which costs $O(n^3)$ and is infeasible for large graphs.
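Here is a small NumPy sketch of spectral filtering, assuming a toy 4-node graph and hand-picked filters $g$. It computes the graph Fourier transform with `np.linalg.eigh`, applies $g(\Lambda)$ elementwise to the coefficients, and transforms back. The final check shows why polynomial filters are attractive: a polynomial $g$ can be applied as a polynomial in $L$ directly, with no eigendecomposition.

```python
import numpy as np

# Toy undirected graph: 4-cycle 0-1-2-3-0 plus the chord 0-2
A = np.array([[0, 1, 1, 1],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
L = np.diag(A.sum(axis=1)) - A          # combinatorial Laplacian L = D - A

# Graph Fourier basis: L = U diag(lam) U^T (eigh, since L is symmetric)
lam, U = np.linalg.eigh(L)

x = np.array([1.0, -2.0, 0.5, 3.0])     # one scalar per node
x_hat = U.T @ x                          # graph Fourier coefficients

def spectral_filter(g):
    """y = U g(Lambda) U^T x, with g applied elementwise to the eigenvalues.
    (Illustrative helper, not a library function.)"""
    return U @ (g(lam) * x_hat)

# g = 1 is the all-pass filter: it returns the signal unchanged
assert np.allclose(spectral_filter(np.ones_like), x)

# A low-pass (heat-kernel) filter damps the high graph frequencies
y_low = spectral_filter(lambda l: np.exp(-l))

# A polynomial filter needs no eigendecomposition:
# U (I - 0.5 Lambda) U^T x equals (I - 0.5 L) x
y_poly = spectral_filter(lambda l: 1.0 - 0.5 * l)
assert np.allclose(y_poly, (np.eye(4) - 0.5 * L) @ x)
```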

Recall from 11-04 §7.5: Kipf & Welling (2017) derived the GCN layer as a first-order Chebyshev polynomial approximation to spectral filtering. They simplified it further by setting $\lambda_{\max} \approx 2$ and adding self-loops. This gives the propagation rule $H^{[l+1]} = \sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{[l]}W^{[l]})$; see 11-04 §7.5 for the full derivation.

Once we arrive at the GCN layer, the spectral scaffolding can be discarded and the layer reinterpreted spatially: it is simply normalized aggregation of neighbor features followed by a linear transformation and nonlinearity. This spatial reinterpretation is what makes GNNs extensible: one can design new aggregation functions, learn attention weights, incorporate edge features, and handle directed graphs - none of which fit neatly into the spectral framework.

The two perspectives are complementary:

  • Spectral: principled derivation, frequency interpretation, connection to graph signal processing; limited to fixed graphs, requires eigendecomposition or polynomial approximation
  • Spatial: flexible, scalable, inductive, handles heterogeneous graphs; less theoretically constrained, harder to analyze convergence

Modern GNN research lives primarily in the spatial perspective, with spectral theory providing theoretical grounding and positional encodings (§10).

1.4 What Makes Graphs Different from Other Data

Four properties of graph-structured data require fundamentally different architectural choices:

1. Variable and irregular neighborhoods. Node $v$ in a molecule may have 1 bond (hydrogen) or 4 bonds (carbon). Node $u$ in a social network may have 3 friends or 3,000. Standard convolution assumes a fixed-size, fixed-pattern receptive field. GNNs must aggregate over neighborhoods of arbitrary size. This forces a choice of aggregation function (sum, mean, max, attention) that accepts any number of inputs and does not depend on their order.

2. No canonical node ordering. An image has a canonical pixel ordering (row-major). A sentence has a canonical word ordering (left-to-right). A graph has no such ordering. If we relabel the nodes of a graph, we get the same graph - just with different indices. This means a GNN must be permutation equivariant (the output node representations permute consistently with any relabeling of input nodes) and any graph-level prediction must be permutation invariant (the same under any relabeling). This rules out any architecture that depends on node indices.

3. Structural information is in the topology. Two molecules with identical atom composition but different bond structures (structural isomers) have very different chemical properties. The GNN must detect this topological difference from the adjacency structure alone. This is the expressiveness question of §7: which topological patterns can a GNN distinguish?

4. Heterogeneous graphs. Real graphs often have multiple types of nodes and edges (knowledge graphs: Person --worksAt--> Company, Person --bornIn--> City). Relational GNNs (R-GCN) and heterogeneous GNNs handle this with type-specific transformation matrices. This section focuses on homogeneous graphs; heterogeneous extensions are discussed in §12.2.

1.5 Historical Timeline (2009-2026)

GRAPH NEURAL NETWORK HISTORICAL TIMELINE
========================================================================

  2009  Scarselli et al.  - Original GNN: fixed-point iteration on graphs
  2013  Bruna et al.      - Spectral graph convolution (full eigendecomp.)
  2016  Defferrard et al. - ChebNet: Chebyshev polynomial spectral filters
  2016  Li et al.         - Gated GNN: GRU-based update functions
  2017  Kipf & Welling    - GCN: simplified spectral GNN for node classification
  2017  Hamilton et al.   - GraphSAGE: inductive learning with neighbor sampling
  2017  Gilmer et al.     - MPNN: unified message-passing framework
  2018  Velickovic et al. - GAT: attention-weighted neighborhood aggregation
  2018  Ying et al.       - DiffPool: differentiable hierarchical pooling
  2019  Xu et al.         - GIN + WL expressiveness theorem
  2019  Chiang et al.     - Cluster-GCN: scalable training by graph partitioning
  2020  Rong et al.       - DropEdge: data augmentation against over-smoothing
  2021  Ying et al.       - Graphormer: transformer for graphs (OGB-LSC winner)
  2021  Brody et al.      - GATv2: dynamic attention, fixing GAT's static attention
  2022  Rampasek et al.   - GPS: general, powerful, scalable graph transformer
  2023  Chen et al.       - Graph Mamba: state-space models on graphs
  2024  Microsoft         - GraphRAG: knowledge graphs for retrieval-augmented generation
  2024-2026  (various)    - LLM+GNN: language models with graph-structured memory

========================================================================

2. Formal Setup: Graphs as Data

2.1 Graph with Node and Edge Features

A featured graph is a tuple $G = (V, E, X, E_{\text{feat}})$ where:

  • $V = \{1, 2, \ldots, n\}$ - the vertex set, $n = |V|$ nodes
  • $E \subseteq V \times V$ - the edge set, $m = |E|$ edges
  • $X \in \mathbb{R}^{n \times d_v}$ - node feature matrix: row $X_{v,:} = \mathbf{x}_v \in \mathbb{R}^{d_v}$ is the feature vector of node $v$
  • $E_{\text{feat}} \in \mathbb{R}^{m \times d_e}$ - edge feature matrix: row $\mathbf{e}_{uv} \in \mathbb{R}^{d_e}$ is the feature of edge $(u,v)$

The adjacency structure is encoded in one of two ways. For small graphs it can be a dense matrix $A \in \{0,1\}^{n \times n}$. For large sparse graphs it is an edge list: a pair of index vectors $(\mathbf{s}, \mathbf{t}) \in \mathbb{Z}^m \times \mathbb{Z}^m$, where $s_k$ is the source and $t_k$ is the target of edge $k$.
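Here is a minimal NumPy sketch of converting between the two encodings. It uses `np.nonzero` to read off the edge list, and the helper names are illustrative. An undirected graph contributes one entry per direction. Storage is therefore proportional to the number of stored entries rather than to $n^2$.

```python
import numpy as np

# Sketch: helper names are illustrative, not a library API.
def dense_to_edge_list(A):
    """Return (s, t): entry k is an edge from s[k] to t[k] wherever A[s, t] != 0."""
    s, t = np.nonzero(A)
    return s, t

def edge_list_to_dense(s, t, n):
    A = np.zeros((n, n), dtype=int)
    A[s, t] = 1
    return A

# Undirected triangle 0-1-2 plus a pendant node 3 attached to 2
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]])
s, t = dense_to_edge_list(A)
print(s.tolist())  # [0, 0, 1, 1, 2, 2, 2, 3]
print(t.tolist())  # [1, 2, 0, 2, 0, 1, 3, 2]
assert np.array_equal(edge_list_to_dense(s, t, 4), A)   # round trip is lossless
```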

Examples:

| Domain | Node features $\mathbf{x}_v$ | Edge features $\mathbf{e}_{uv}$ | Task |
|---|---|---|---|
| Molecular graph | Atom type (one-hot), charge, hybridization | Bond type (single/double/aromatic), distance | Predict solubility |
| Social network | User demographics, activity history | Interaction frequency, timestamp | Link prediction |
| Knowledge graph | Entity type embedding | Relation type embedding | Triple completion |
| Citation network | Bag-of-words of paper abstract | None (unweighted) | Classify research area |
| Code AST | Token type, identifier embedding | Syntactic relation type | Bug detection |

Non-examples of graph data:

  • An image stored as a grid: regular structure makes CNN superior; GNNs are overkill and less efficient
  • A flat table of independent samples: no edges, no relational structure; standard ML applies
  • A time series: sequential structure with canonical ordering; RNN/Transformer preferred unless interactions are truly non-sequential

2.2 Learning Tasks on Graphs

GNNs support three levels of prediction:

Node-level tasks: predict a label $y_v$ for each node $v$ (or a subset). The GNN produces a node embedding $\mathbf{h}_v \in \mathbb{R}^d$ and a final classifier $\hat{y}_v = f(\mathbf{h}_v)$.

  • Semi-supervised node classification: most nodes are unlabeled; a small fraction have known labels. The GNN propagates information from labeled to unlabeled nodes through the graph structure. Classic example: Cora/CiteSeer citation networks (Kipf & Welling, 2017).
  • Node regression: predict a continuous value per node. Example: predicting traffic speed at road intersections.

Edge-level tasks: predict a property of a pair of nodes $(u, v)$, whether or not the edge exists.

  • Link prediction: does edge $(u, v)$ exist? Used in recommendation (predict user-item interaction) and knowledge graph completion (predict missing relation). The GNN produces $\mathbf{h}_u$ and $\mathbf{h}_v$; the prediction is $\hat{y}_{uv} = f(\mathbf{h}_u, \mathbf{h}_v)$ (e.g., inner product or MLP).
  • Edge classification: predict the type of an existing edge. Example: classify protein-protein interaction type (physical, genetic, etc.).

Graph-level tasks: predict a single label for an entire graph $G$.

  • Graph classification: classify molecular graphs as toxic/non-toxic; classify citation subgraphs by topic. Requires a readout function that aggregates all node embeddings into a single graph embedding $\mathbf{h}_G \in \mathbb{R}^d$.
  • Graph regression: predict a scalar property of the graph (e.g., binding affinity, quantum energy). Many molecular property prediction benchmarks are graph regression tasks (e.g., QM9). Others, such as ogbg-molhiv (predicting HIV inhibition), are binary graph classification tasks.

2.3 Permutation Invariance and Equivariance

Definition (Permutation Equivariance). Let $\Pi_n$ be the group of $n \times n$ permutation matrices. A function $f: \mathbb{R}^{n \times d} \times \{0,1\}^{n \times n} \to \mathbb{R}^{n \times d'}$ is permutation equivariant if for every permutation matrix $P \in \Pi_n$:

$$f(PX, PAP^\top) = P \, f(X, A)$$

That is, permuting the node ordering in the input permutes the output representations in the same way. Node-level GNNs must be permutation equivariant: if we relabel the nodes, the node embeddings should permute consistently.

Definition (Permutation Invariance). A function $g: \mathbb{R}^{n \times d} \times \{0,1\}^{n \times n} \to \mathbb{R}$ is permutation invariant if for every permutation $P$:

$$g(PX, PAP^\top) = g(X, A)$$

Graph-level GNNs must be permutation invariant: the prediction for a graph is the same regardless of how we number its vertices.

Why this constrains architecture. Consider a simple approach: concatenate all node features into a vector $[\mathbf{x}_1^\top, \mathbf{x}_2^\top, \ldots, \mathbf{x}_n^\top]^\top$ and pass it through an MLP. This is not permutation equivariant: shuffling the nodes changes the input vector and changes the output. Similarly, any function of the matrix $A$ that uses specific row/column indices (like the $(1,2)$ entry) is not permutation invariant.

A local neighborhood aggregation is compatible with permutation equivariance when it applies a fixed function to the multiset of neighbor features, which is precisely what GNN aggregation functions do. Functions of (multi)sets must be insensitive to the order of elements. Sum, mean, and max all satisfy this; concatenation does not.

Formal example. Let $\mathbf{h}_v^{[1]} = \sigma\big(W \cdot \frac{1}{|\mathcal{N}(v)|}\sum_{u \in \mathcal{N}(v)} \mathbf{x}_u + W_0 \mathbf{x}_v\big)$. This is mean aggregation plus a separate self term, in the spirit of GCN. It is equivariant for two reasons:

  • $\sum_{u \in \mathcal{N}(v)}$ sums over a set, so reordering the neighbors does not change the sum.
  • Relabeling the nodes only changes which row each result is written to.

By contrast, feeding the concatenation $[\mathbf{x}_{u_1}^\top, \mathbf{x}_{u_2}^\top, \ldots]^\top$ of an ordered list of neighbors into an MLP is not equivariant. Its output changes when the list order changes.
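The definitions can be tested numerically. The NumPy sketch below builds a random toy graph (names are illustrative) and runs three checks:

  • It implements the mean-aggregation layer from the formal example in row-vector convention and checks $f(PX, PAP^\top) = P f(X, A)$ for a random permutation matrix $P$.
  • It checks that a sum readout of the outputs is permutation invariant.
  • It prints a "flatten then linear layer" model under the same permutation for contrast; that output generically changes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sketch: names are illustrative; rows of X are nodes (row-vector convention).
def mean_agg_layer(X, A, W, W0):
    """h_v = ReLU(mean of neighbor features @ W + own features @ W0)."""
    deg = A.sum(axis=1, keepdims=True)
    neigh_mean = (A @ X) / np.maximum(deg, 1)   # guard against isolated nodes
    return np.maximum(0.0, neigh_mean @ W + X @ W0)

n, d, d_out = 5, 3, 4
A = np.triu((rng.random((n, n)) < 0.5).astype(float), 1)
A = A + A.T                                  # random undirected graph, no self-loops
X = rng.normal(size=(n, d))
W, W0 = rng.normal(size=(d, d_out)), rng.normal(size=(d, d_out))

P = np.eye(n)[rng.permutation(n)]            # random permutation matrix

H = mean_agg_layer(X, A, W, W0)
H_perm = mean_agg_layer(P @ X, P @ A @ P.T, W, W0)
assert np.allclose(H_perm, P @ H)            # equivariance: f(PX, PAP^T) = P f(X, A)
assert np.allclose(H_perm.sum(axis=0), H.sum(axis=0))   # sum readout is invariant

# A "flatten then linear layer" model depends on node order (generically not invariant)
W_flat = rng.normal(size=(n * d, 1))
print((X.reshape(-1) @ W_flat).item(), ((P @ X).reshape(-1) @ W_flat).item())
```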

2.4 Representations: From Single Graph to Batched Graphs

Training GNNs on datasets of multiple small graphs (graph classification, molecular property prediction) requires efficient batching. Unlike image batches where tensors have uniform shape, graphs have variable numbers of nodes and edges.

The block-diagonal batch trick. Given graphs $G_1, G_2, \ldots, G_B$ with $n_1, n_2, \ldots, n_B$ nodes, form a single disconnected "super-graph" $G_{\text{batch}}$ with $N = \sum_i n_i$ nodes:

$$A_{\text{batch}} = \begin{pmatrix} A_1 & & \\ & A_2 & \\ & & \ddots \end{pmatrix}, \qquad X_{\text{batch}} = \begin{pmatrix} X_1 \\ X_2 \\ \vdots \end{pmatrix}$$

The adjacency matrix is block-diagonal (no edges between different graphs). Since graphs in the batch are disconnected, message passing does not cross graph boundaries. A batch index vector $\mathbf{b} \in \mathbb{Z}^N$ tracks which graph each node belongs to, enabling graph-level readout by aggregating within each block.
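Here is a minimal NumPy sketch of the block-diagonal batching trick and a per-graph sum readout driven by the batch vector $\mathbf{b}$. It uses dense matrices for readability; a real pipeline would more likely keep the batched graph as an edge list with offset node indices. The helper names are hypothetical.

```python
import numpy as np

# Sketch: helper names are hypothetical, not a library API.
def batch_graphs(graphs):
    """Merge (A_i, X_i) pairs into one block-diagonal graph plus a batch vector."""
    sizes = [A.shape[0] for A, _ in graphs]
    N = sum(sizes)
    A_batch = np.zeros((N, N))
    offset = 0
    for A, _ in graphs:
        n = A.shape[0]
        A_batch[offset:offset + n, offset:offset + n] = A   # diagonal block
        offset += n
    X_batch = np.vstack([X for _, X in graphs])
    b = np.repeat(np.arange(len(graphs)), sizes)          # graph id per node
    return A_batch, X_batch, b

def sum_readout(H, b, num_graphs):
    """Sum node embeddings within each graph (a segment sum over b)."""
    out = np.zeros((num_graphs, H.shape[1]))
    np.add.at(out, b, H)
    return out

# Graph 1: a single edge (2 nodes); graph 2: a triangle (3 nodes)
A1 = np.array([[0, 1], [1, 0]], dtype=float)
A2 = np.ones((3, 3)) - np.eye(3)
X1, X2 = np.ones((2, 2)), np.full((3, 2), 2.0)

A_b, X_b, b = batch_graphs([(A1, X1), (A2, X2)])
print(b.tolist())                    # [0, 0, 1, 1, 1]
H = A_b @ X_b                        # one propagation step never mixes graphs
print(sum_readout(H, b, 2))          # rows [2, 2] and [12, 12]
```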

Virtual node trick. Adding a "super node" connected to all real nodes in a graph is a common trick to enable long-range information propagation without increasing depth. The virtual node aggregates the entire graph's information in one hop, then broadcasts back. Used in MPNN (Gilmer et al., 2017) and Graphormer (Ying et al., 2021). It also serves as a form of global pooling: $\mathbf{h}_{\text{virtual}} = \operatorname{READOUT}(\{\mathbf{h}_v\})$.
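Here is a sketch of the virtual-node trick on an edge list, in NumPy with hypothetical helper names. The new node gets index $n$ and is wired in both directions to every real node. The sketch assumes the virtual node gets its own initial feature (for example a zero vector or a learned embedding), which is not shown. The final check confirms that two nodes three hops apart on a path end up two hops apart.

```python
import numpy as np

# Sketch: helper name is hypothetical, not a library API.
def add_virtual_node(src, dst, n):
    """Append node n, connected in both directions to every real node 0..n-1."""
    real = np.arange(n)
    virtual = np.full(n, n)
    new_src = np.concatenate([src, real, virtual])
    new_dst = np.concatenate([dst, virtual, real])
    return new_src, new_dst, n + 1

# Path graph 0 - 1 - 2 - 3: nodes 0 and 3 are 3 hops apart
src = np.array([0, 1, 1, 2, 2, 3])
dst = np.array([1, 0, 2, 1, 3, 2])
src_v, dst_v, n_v = add_virtual_node(src, dst, 4)

# Dense adjacency of the augmented graph, only to check hop distances
A = np.zeros((n_v, n_v))
A[src_v, dst_v] = 1
reach2 = np.linalg.matrix_power(A + np.eye(n_v), 2)
print(reach2[0, 3] > 0)   # True: 0 -> virtual -> 3 takes only 2 hops
```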


3. The MPNN Framework

3.1 The Gilmer et al. (2017) Formulation

The Message Passing Neural Network (MPNN) framework, introduced by Gilmer et al. (2017) for quantum chemistry property prediction, provides the canonical mathematical abstraction for GNNs. It unifies GCN, GraphSAGE, GAT, GIN, and most other spatial GNN architectures as special cases.

Algorithm: Message Passing Phase

For $l = 0, 1, \ldots, L-1$ (layers):

$$\mathbf{m}_v^{[l+1]} = \sum_{u \in \mathcal{N}(v)} M^{[l]}\!\left(\mathbf{h}_v^{[l]}, \mathbf{h}_u^{[l]}, \mathbf{e}_{uv}\right)$$

$$\mathbf{h}_v^{[l+1]} = U^{[l]}\!\left(\mathbf{h}_v^{[l]}, \mathbf{m}_v^{[l+1]}\right)$$

Readout Phase (for graph-level tasks):

$$\hat{y}_G = R\!\left(\left\{\mathbf{h}_v^{[L]} : v \in V\right\}\right)$$

Here:

  • $M^{[l]}: \mathbb{R}^{d} \times \mathbb{R}^{d} \times \mathbb{R}^{d_e} \to \mathbb{R}^{d'}$ - message function at layer $l$; takes the central node's features, a neighbor's features, and the edge features, returns a message vector
  • $U^{[l]}: \mathbb{R}^d \times \mathbb{R}^{d'} \to \mathbb{R}^{d''}$ - update function at layer $l$; combines the node's current representation with the aggregated message
  • $R: 2^{\mathbb{R}^d} \to \mathbb{R}^k$ - readout function; maps a set of node representations to a graph-level prediction

The aggregation in the message phase is written as a sum, but can be replaced by any permutation-invariant function (mean, max, attention). The sum formulation is canonical because it is the most expressive (see §7).

Initialization: $\mathbf{h}_v^{[0]} = \mathbf{x}_v$; the initial representation is the input node feature.
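The two phases can be sketched in NumPy under several assumptions, all labeled here and in the code. All names are illustrative.

  • $M$ is a small MLP applied to $[\mathbf{h}_v \| \mathbf{h}_u \| \mathbf{e}_{uv}]$, and aggregation is a sum.
  • $U$ is an MLP on $[\mathbf{h}_v \| \mathbf{m}_v]$. Gilmer et al. also used a GRU here.
  • The same weights are reused at every layer for brevity, although the formulation allows a separate $M^{[l]}$ and $U^{[l]}$ per layer.
  • $R$ is a plain sum.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sketch: all names are illustrative; M and U are assumed to be 2-layer MLPs.
def init_mlp(d_in, d_hidden, d_out):
    return (rng.normal(size=(d_in, d_hidden)) * 0.1, np.zeros(d_hidden),
            rng.normal(size=(d_hidden, d_out)) * 0.1, np.zeros(d_out))

def mlp(params, z):
    W1, b1, W2, b2 = params
    return np.maximum(0.0, z @ W1 + b1) @ W2 + b2

def mpnn_layer(h, e, src, dst, msg_params, upd_params):
    """m_v = sum_{u in N(v)} M(h_v, h_u, e_uv);  h_v' = U(h_v, m_v)."""
    # One message per directed edge u = src[k] -> v = dst[k]
    msgs = mlp(msg_params, np.concatenate([h[dst], h[src], e], axis=1))
    m = np.zeros((h.shape[0], msgs.shape[1]))
    np.add.at(m, dst, msgs)                        # sum aggregation
    return mlp(upd_params, np.concatenate([h, m], axis=1))

n, d, d_e = 3, 4, 2
src, dst = np.array([0, 1, 1, 2]), np.array([1, 0, 2, 1])
h = rng.normal(size=(n, d))                        # h^{[0]} = x
e = rng.normal(size=(len(src), d_e))               # one feature row per directed edge
msg_params = init_mlp(2 * d + d_e, 8, d)
upd_params = init_mlp(2 * d, 8, d)

for _ in range(2):                                 # L = 2 layers, weights shared here
    h = mpnn_layer(h, e, src, dst, msg_params, upd_params)
y_G = h.sum(axis=0)                                # readout R: sum over nodes
print(y_G.shape)                                   # (4,)
```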

For AI: The MPNN framework is general enough to subsume self-attention. A transformer self-attention layer can be viewed as an MPNN on the complete graph, where every token attends to every token, including itself. The message from token $j$ to token $i$ is $\alpha_{ij}\mathbf{v}_j$, with $\alpha_{ij} = \operatorname{softmax}_j(\mathbf{q}_i^\top \mathbf{k}_j / \sqrt{d_k})$. Because the softmax normalizes over all tokens $j$, the aggregation is an attention-weighted sum rather than a purely per-edge message.
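The claim can be checked numerically. This NumPy sketch computes single-head self-attention in the usual matrix form. It then computes the same thing as explicit message passing on the complete graph (self-loops included), where node $i$ sums messages $\alpha_{ij}\mathbf{v}_j$ from every node $j$. Random $Q$, $K$, $V$ stand in for learned projections.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_k = 4, 3
# Random stand-ins for the learned query/key/value projections
Q, K, V = (rng.normal(size=(n, d_k)) for _ in range(3))

# Matrix form of single-head self-attention
scores = Q @ K.T / np.sqrt(d_k)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)          # softmax over j, per row i
out_matrix = weights @ V

# Same computation as message passing on the complete graph
out_mp = np.zeros_like(V)
for i in range(n):                                      # receiving node i
    s = np.array([Q[i] @ K[j] / np.sqrt(d_k) for j in range(n)])
    alpha = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
    for j in range(n):                                  # message from every node j
        out_mp[i] += alpha[j] * V[j]

assert np.allclose(out_matrix, out_mp)
```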

3.2 Message Functions

The message function $M(\mathbf{h}_v, \mathbf{h}_u, \mathbf{e}_{uv})$ determines what information flows along each edge. Several common designs:

Edge-independent messages. Many simple GNNs (GCN, GIN) ignore edge features and use:

$$\mathbf{m}_{uv} = M(\mathbf{h}_u) = W \mathbf{h}_u$$

or

$$\mathbf{m}_{uv} = M(\mathbf{h}_u, \mathbf{h}_v) = W_1 \mathbf{h}_u + W_2 \mathbf{h}_v$$

Edge-conditioned messages. When edge features are available:

$$\mathbf{m}_{uv} = f_\theta(\mathbf{h}_u, \mathbf{e}_{uv})$$

where $f_\theta$ is an MLP. Used in MPNN for molecular graphs (Gilmer et al., 2017): edge features encode bond type, bond length, and stereochemistry.

Pair messages. In models like NNConv, the edge feature defines the transformation matrix:

$$\mathbf{m}_{uv} = \Phi(\mathbf{e}_{uv}) \, \mathbf{h}_u$$

where $\Phi: \mathbb{R}^{d_e} \to \mathbb{R}^{d' \times d}$ is an MLP that produces a weight matrix from the edge features. This is extremely expressive but quadratically expensive in $d$.
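Here is a NumPy sketch of edge-conditioned (NNConv-style) messages. It assumes $\Phi$ is a two-layer MLP whose output is reshaped into a $d' \times d$ matrix. The sizes and names are illustrative, not any library's defaults. The last print shows where the quadratic cost comes from: $\Phi$'s output layer alone has $\text{hidden} \times d' \times d$ weights.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_out, d_e, hidden = 3, 2, 4, 8

# Hypothetical edge network Phi: R^{d_e} -> R^{d_out x d}, a 2-layer MLP
W1 = rng.normal(size=(d_e, hidden)) * 0.1
W2 = rng.normal(size=(hidden, d_out * d)) * 0.1

def phi(e_uv):
    """Map one edge feature vector to a (d_out, d) weight matrix."""
    return (np.maximum(0.0, e_uv @ W1) @ W2).reshape(d_out, d)

def edge_conditioned_messages(h, e, src):
    """m_uv = Phi(e_uv) @ h_u for every directed edge k (u = src[k])."""
    return np.stack([phi(e[k]) @ h[src[k]] for k in range(len(src))])

h = rng.normal(size=(3, d))
src = np.array([0, 1, 1, 2])
e = rng.normal(size=(len(src), d_e))
msgs = edge_conditioned_messages(h, e, src)
print(msgs.shape)   # (4, 2): one d_out-dimensional message per edge
print(W2.shape)     # (8, 6): hidden * (d_out * d) weights in Phi's output layer
```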

3.3 Aggregation: Permutation-Invariant Pooling

The aggregation function must be permutation invariant over the set $\{\mathbf{m}_{uv} : u \in \mathcal{N}(v)\}$. Three canonical choices and their properties:

Sum aggregation:

$$\mathbf{m}_v = \sum_{u \in \mathcal{N}(v)} \mathbf{m}_{uv}$$

Sensitive to neighborhood size: a node with 10 neighbors gets a larger aggregate than a node with 2, even if their neighborhood compositions are identical. This turns out to be a feature, not a bug: sum aggregation is the most expressive (§7.4). Used in GIN (Xu et al., 2019).

Mean aggregation:

$$\mathbf{m}_v = \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} \mathbf{m}_{uv}$$

Normalizes by neighborhood size. Good when neighborhood size is not informative and you want to compare average neighbor characteristics across nodes of different degrees. Used in GraphSAGE (mean variant). GCN (Kipf & Welling, 2017) uses a closely related symmetric normalization, $1/\sqrt{\tilde{d}_u \tilde{d}_v}$ (§4.2).

Max aggregation:

$$(\mathbf{m}_v)_k = \max_{u \in \mathcal{N}(v)} (\mathbf{m}_{uv})_k \quad \text{for each dimension } k$$

Detects whether any neighbor has feature $k$ above a threshold. Good for detecting the presence of particular node types in the neighborhood. Used in GraphSAGE (max-pool variant).

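Expressiveness hierarchy. Xu et al. (2019) showed that, in their ability to distinguish multisets of neighbor features, sum $\succ$ mean $\succ$ max. Mean captures the distribution (proportions) of features, max captures only the underlying set, and sum captures the full multiset. For example, suppose all node features are identical. Then mean cannot distinguish a graph with two nodes each of degree 1 from a graph with four nodes each of degree 2 (a 4-cycle), and max cannot count copies of identical features. Sum can distinguish both.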
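The hierarchy can be illustrated on two pairs of neighbor multisets built from one-hot features $a$ and $b$. This is a NumPy sketch, and the helper name is illustrative. The first pair has the same feature distribution but different sizes. The second pair has the same underlying set but different proportions.

```python
import numpy as np

# Sketch: helper name is illustrative, not a library API.
def aggregate(msgs, how):
    msgs = np.asarray(msgs, dtype=float)
    if how == "sum":
        return msgs.sum(axis=0)
    if how == "mean":
        return msgs.mean(axis=0)
    if how == "max":
        return msgs.max(axis=0)
    raise ValueError(how)

a, b = [1.0, 0.0], [0.0, 1.0]           # two one-hot neighbor feature types
pairs = {
    "{a} vs {a, a}": ([a], [a, a]),              # same distribution, different size
    "{a, b} vs {a, a, b}": ([a, b], [a, a, b]),  # same set, different proportions
}
for name, (m1, m2) in pairs.items():
    for how in ("sum", "mean", "max"):
        differs = not np.allclose(aggregate(m1, how), aggregate(m2, how))
        print(f"{name:22s} {how:4s} distinguishes: {differs}")
```

Sum separates both pairs, mean separates only the second, and max separates neither.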

Attention aggregation. Instead of fixed weights, learn attention coefficients:

$$\mathbf{m}_v = \sum_{u \in \mathcal{N}(v)} \alpha_{uv} \, \mathbf{m}_{uv}, \qquad \alpha_{uv} = \operatorname{softmax}_u(e_{uv})$$

where $e_{uv}$ is a learned attention score. This is the GAT model (§6).
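Here is a NumPy sketch of attention aggregation with the softmax taken separately over each node's incoming edges (a "segment softmax"). The scores are hand-picked stand-ins for the learned $e_{uv}$; how GAT computes them is the subject of §6. The names are illustrative. Subtracting each node's maximum score before exponentiating prevents overflow and leaves the softmax unchanged.

```python
import numpy as np

# Sketch: helper names are illustrative, not a library API.
def segment_softmax(scores, dst, n):
    """Softmax of edge scores, normalized separately over each node's incoming edges."""
    max_per_node = np.full(n, -np.inf)
    np.maximum.at(max_per_node, dst, scores)
    ex = np.exp(scores - max_per_node[dst])       # per-node max for numerical stability
    denom = np.zeros(n)
    np.add.at(denom, dst, ex)
    return ex / denom[dst]

def attention_aggregate(msgs, scores, dst, n):
    """m_v = sum_u alpha_uv * m_uv, with alpha_uv a softmax over u in N(v)."""
    alpha = segment_softmax(scores, dst, n)
    out = np.zeros((n, msgs.shape[1]))
    np.add.at(out, dst, alpha[:, None] * msgs)
    return out, alpha

# Node 0 receives from 1, 2, 3; node 1 receives from 0
src = np.array([1, 2, 3, 0])
dst = np.array([0, 0, 0, 1])
msgs = np.array([[1.0], [2.0], [3.0], [5.0]])
scores = np.array([0.0, 0.0, np.log(2.0), 1.0])   # stand-in for learned e_uv
out, alpha = attention_aggregate(msgs, scores, dst, n=4)
print(alpha)       # node 0's weights 0.25, 0.25, 0.5; node 1's single weight 1.0
print(out[:, 0])   # node 0: 0.25*1 + 0.25*2 + 0.5*3 = 2.25; node 1: 5.0
```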

3.4 Update Functions

The update function $U(\mathbf{h}_v^{[l]}, \mathbf{m}_v^{[l+1]})$ combines the node's current representation with the aggregated neighborhood message.

Linear update (GCN):

$$\mathbf{h}_v^{[l+1]} = \sigma\!\left(W^{[l]} \mathbf{m}_v^{[l+1]}\right)$$

The current node representation is absorbed into the message via the self-loop in $\tilde{A} = A + I$.

Concatenation update (GraphSAGE):

$$\mathbf{h}_v^{[l+1]} = \sigma\!\left(W^{[l]} \left[\mathbf{h}_v^{[l]} \,\|\, \mathbf{m}_v^{[l+1]}\right]\right)$$

The node's current representation is concatenated with the aggregated neighbor message before the linear transformation. This preserves the node's own information explicitly.

GRU update (Gated GNN, Li et al. 2016):

$$\mathbf{h}_v^{[l+1]} = \operatorname{GRU}\!\left(\mathbf{h}_v^{[l]}, \mathbf{m}_v^{[l+1]}\right)$$

The GRU's gating mechanism adaptively controls how much of the current representation to retain vs. replace with the new message. It is more expressive than a linear update, at the cost of more parameters.

Residual update:

$$\mathbf{h}_v^{[l+1]} = \sigma\!\left(W^{[l]} \mathbf{m}_v^{[l+1]}\right) + \mathbf{h}_v^{[l]}$$

Skip connections (as in ResNet) preserve gradient flow in deep GNNs and help mitigate over-smoothing. GCNII (Chen et al., 2020) builds on this idea to train GCNs with 64 layers, combining an initial residual connection to $H^{[0]}$ with identity mapping.
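The update variants differ only in how $\mathbf{h}_v^{[l]}$ and $\mathbf{m}_v^{[l+1]}$ are combined. Here is a NumPy sketch with random stand-in inputs. The names are illustrative, and the GRU update is omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sketch: function names are illustrative, not a library API.
def relu(z):
    return np.maximum(0.0, z)

def update_linear(h, m, W):      # GCN-style: the self-loop is already inside m
    return relu(m @ W)

def update_concat(h, m, W):      # GraphSAGE-style: W acts on [h_v || m_v]
    return relu(np.concatenate([h, m], axis=1) @ W)

def update_residual(h, m, W):    # residual: sigma(W m_v) + h_v (needs d_out == d)
    return relu(m @ W) + h

n, d = 5, 4
h = rng.normal(size=(n, d))      # h^{[l]}
m = rng.normal(size=(n, d))      # stand-in for the aggregated message m^{[l+1]}
W = rng.normal(size=(d, d))
W_cat = rng.normal(size=(2 * d, d))   # concatenation doubles the input width

for f, weights in [(update_linear, W), (update_concat, W_cat), (update_residual, W)]:
    print(f.__name__, f(h, m, weights).shape)    # every output has shape (5, 4)

# With W = 0 the residual update is exactly the identity map on h
assert np.allclose(update_residual(h, m, np.zeros((d, d))), h)
```

The assertion shows the design point behind residual updates. With $W^{[l]} = 0$ the layer is exactly the identity, so a deep stack can always fall back to passing representations through unchanged. This requires the layer to preserve the feature dimension.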

3.5 Readout for Graph-Level Tasks

For graph classification and regression, the node representations $\{\mathbf{h}_v^{[L]}\}$ must be aggregated into a single graph embedding $\mathbf{h}_G \in \mathbb{R}^d$.

Global pooling:

$$\mathbf{h}_G = \operatorname{READOUT}\!\left(\left\{\mathbf{h}_v^{[L]} : v \in V\right\}\right)$$

Simple choices: $\sum_v \mathbf{h}_v$, $\frac{1}{n}\sum_v \mathbf{h}_v$, $\max_v \mathbf{h}_v$ (element-wise). Sum is most expressive; mean normalizes for graph size.

Hierarchical pooling. Rather than a single pooling step at the end, hierarchical GNNs (DiffPool, SAGPool) alternate GNN layers with coarsening steps that reduce the number of nodes. This mirrors how image CNNs alternate convolution with spatial downsampling.

Jumping Knowledge (Xu et al., 2018). Rather than using only the final layer's representations, JK-networks combine representations from all layers (for example by concatenation, element-wise max, or LSTM-attention):

$$\mathbf{h}_v^{\text{final}} = \operatorname{COMBINE}\!\left(\mathbf{h}_v^{[0]}, \mathbf{h}_v^{[1]}, \ldots, \mathbf{h}_v^{[L]}\right)$$

This allows the readout to leverage both local (shallow layers) and global (deep layers) structural information, and empirically mitigates over-smoothing.

Set2Set (Vinyals et al., 2016). A more powerful readout that runs an LSTM for a fixed number of steps $T$, each step attending over the whole node set. It produces an order-invariant graph embedding that depends on all node representations through multiple rounds of attention. It is more expressive than simple global pooling, but it costs $T$ sequential LSTM steps, each attending over all $n$ nodes ($O(Tn)$ attention work).
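Here is a NumPy sketch of the simple global readouts combined with Jumping Knowledge. It uses concatenation, which is only one of several possible COMBINE choices. The per-layer representations are random stand-ins, and the names are illustrative. Set2Set is not shown.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sketch: function names are illustrative, not a library API.
def readout(H, how):
    """Global pooling over the node axis (sum, mean, or element-wise max)."""
    return {"sum": H.sum(axis=0), "mean": H.mean(axis=0), "max": H.max(axis=0)}[how]

def jumping_knowledge_concat(layer_outputs):
    """COMBINE by concatenation: [h^{[0]} || h^{[1]} || ... || h^{[L]}] per node."""
    return np.concatenate(layer_outputs, axis=1)

# Stand-in per-layer node representations for a 6-node graph, L = 2, d = 4
layers = [rng.normal(size=(6, 4)) for _ in range(3)]   # h^{[0]}, h^{[1]}, h^{[2]}
H_jk = jumping_knowledge_concat(layers)
print(H_jk.shape)                     # (6, 12)
h_G = readout(H_jk, "sum")
print(h_G.shape)                      # (12,)

# Mean readout is sum readout divided by the node count
assert np.allclose(readout(H_jk, "mean"), readout(H_jk, "sum") / 6)
```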

3.6 MPNN Instances: Unifying GCN, GraphSAGE, GAT, GIN

The following table shows how each major GNN architecture is a special case of the MPNN framework:

MPNN UNIFICATION TABLE
========================================================================

  Model        Message M(h_v, h_u, e_uv)      Aggregation       Update U(h_v, m_v)
  ------------------------------------------------------------------------------
  GCN          h_u / sqrt(d~_u * d~_v)        Sum (+ self)      σ(W*m_v)
  GraphSAGE    W_1*h_u                        Mean/Max/Sum      σ(W*[h_v || m_v])
  GAT          α_uv * W*h_u                   Attention         σ(m_v)
               (α_uv learned from h_v, h_u)
  GIN          h_u                            Sum               MLP((1+ε)*h_v + m_v)
  MPNN-Gilmer  f_θ(h_u, e_uv)                 Sum               GRU(h_v, m_v)
  ------------------------------------------------------------------------------
  Key insight: choice of M, aggregation, U determines expressiveness
  and scalability. Sum agg. + injective MLP = most expressive (§7)

========================================================================

4. Graph Convolutional Networks (GCN)

4.1 Spectral Foundation Recap

Recall from 11-04 §7.5: the GCN layer is derived as a first-order Chebyshev polynomial approximation to spectral graph convolution. The derivation has three steps:

  • Start from the spectral filter $g_\theta(\Lambda) = \theta_0 I + \theta_1(\Lambda - I)$ applied to the normalized Laplacian $L_{\text{sym}} = I - D^{-1/2}AD^{-1/2}$, whose eigenvalues lie in $[0,2]$.
  • Constrain $\theta = \theta_0 = -\theta_1$ (a single parameter per layer). The filter becomes $\theta\,(I + D^{-1/2}AD^{-1/2})$.
  • Apply the renormalization trick $\tilde{A} = A + I$, $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$. It replaces $I + D^{-1/2}AD^{-1/2}$ with $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$, yielding the GCN propagation rule.

Full derivation: 11-04 Spectral Graph Theory §7.5.

The key takeaway is that the GCN layer is a principled spectral filter, not an ad hoc heuristic. The choice of $\tilde{A}$ (adding self-loops) and $\tilde{D}^{-1/2}(\cdot)\tilde{D}^{-1/2}$ (symmetric normalization) both arise from the spectral derivation and can now be given spatial interpretations.

4.2 The GCN Layer

The GCN propagation rule for all nodes simultaneously (matrix form):

$$H^{[l+1]} = \sigma\!\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{[l]} W^{[l]}\right)$$

where:

  • $\tilde{A} = A + I_n$ - adjacency matrix with added self-loops
  • $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$ - degree matrix of $\tilde{A}$
  • $H^{[l]} \in \mathbb{R}^{n \times d_l}$ - node feature matrix at layer $l$; $H^{[0]} = X$
  • $W^{[l]} \in \mathbb{R}^{d_l \times d_{l+1}}$ - trainable weight matrix at layer $l$
  • $\sigma$ - nonlinear activation (ReLU for hidden layers, softmax for output in classification)

Node-level form. For a single node vv:

$$\mathbf{h}_v^{[l+1]} = \sigma\!\left(W^{[l]\top} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\tilde{d}_v \tilde{d}_u}} \mathbf{h}_u^{[l]}\right)$$

where $\tilde{d}_v = \tilde{D}_{vv} = d_v + 1$ (degree including self-loop). Each neighbor $u$'s contribution is weighted by $\frac{1}{\sqrt{\tilde{d}_v \tilde{d}_u}}$, the reciprocal of the geometric mean of the (augmented) degrees. This normalization prevents high-degree nodes from dominating. It also ensures that every eigenvalue of the propagation matrix has absolute value at most 1.

The propagation matrix. Define $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$, the symmetrically normalized adjacency of $\tilde{A}$. Its eigenvalues lie in $[-1, 1]$. To see why, note that $\hat{A}$ is symmetric and similar to the row-stochastic matrix $\tilde{D}^{-1}\tilde{A}$, whose eigenvalues have absolute value at most 1. A key property:

$$\left(\hat{A}^k\right)_{ij} = \sum_{\text{walks } i = v_0, v_1, \ldots, v_k = j} \; \prod_{t=1}^{k} \frac{1}{\sqrt{\tilde{d}_{v_{t-1}} \tilde{d}_{v_t}}}$$

where the sum runs over all walks of length $k$ in $\tilde{A}$ (vertices may repeat, and self-loop steps are allowed).

After $L$ GCN layers, node $v$'s representation is a learned function of the (weighted) $L$-hop neighborhood, exactly the receptive field intuition from §1.2.
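The two properties of $\hat{A}$ can be checked on a small, arbitrary graph with NumPy. First, its eigenvalues lie in $[-1, 1]$, and the largest is exactly 1 with eigenvector $\tilde{D}^{1/2}\mathbf{1}$. Second, an entry of $\hat{A}^k$ equals a brute-force sum over all length-$k$ walks of the products of per-step weights. The helper name is illustrative.

```python
import itertools

import numpy as np

# Same normalization as the GCN layer: A_hat = D~^{-1/2} (A + I) D~^{-1/2}
A = np.zeros((4, 4))
for u, v in [(0, 1), (1, 2), (2, 0), (2, 3)]:
    A[u, v] = A[v, u] = 1
A_tilde = A + np.eye(4)
d_tilde = A_tilde.sum(axis=1)
A_hat = A_tilde / np.sqrt(np.outer(d_tilde, d_tilde))

eigvals = np.linalg.eigvalsh(A_hat)
print(eigvals.round(3))
assert eigvals.max() <= 1 + 1e-9 and eigvals.min() >= -1 - 1e-9

# Sketch helper (illustrative): (A_hat^k)_{ij} as a sum over all length-k walks
# i -> j, each contributing the product of its per-step weights A_hat[s, t]
def walk_sum(i, j, k):
    total = 0.0
    for middle in itertools.product(range(4), repeat=k - 1):
        walk = (i, *middle, j)
        total += np.prod([A_hat[s, t] for s, t in zip(walk, walk[1:])])
    return total

k = 3
Ak = np.linalg.matrix_power(A_hat, k)
assert np.isclose(walk_sum(0, 3, k), Ak[0, 3])
```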

4.3 Matrix View: Propagation Rule

The GCN can be understood as a learned diffusion process. Define $S = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ as the diffusion operator (also called the normalized propagation matrix). Then:

$$H^{[l+1]} = \sigma\left(S H^{[l]} W^{[l]}\right)$$

  • $S H^{[l]}$ - neighborhood aggregation: each node's features become a degree-normalized weighted combination of its own and its neighbors' features. This is a graph smoothing step: it reduces variation across connected nodes.
  • $(\cdot) W^{[l]}$ - feature transformation: a learnable linear map on the $d_l$-dimensional feature space.
  • $\sigma(\cdot)$ - nonlinearity: introduces expressive power.

Repeating this $L$ times gives $H^{[L]} = \sigma(S \cdot \sigma(S \cdots \sigma(S X W^{[0]}) \cdots W^{[L-2]}) W^{[L-1]})$.

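Connection to heat diffusion. The continuous heat equation on a graph is $\frac{\partial \mathbf{h}}{\partial t} = -L\mathbf{h}$, with solution $\mathbf{h}(t) = e^{-tL}\mathbf{h}(0)$. Take $L$ to be the normalized Laplacian of the self-loop graph, $\tilde{L} = I - S$. One explicit Euler step of unit size then gives $\mathbf{h} \leftarrow (I - \tilde{L})\mathbf{h} = S\mathbf{h}$, so the GCN step $SH$ is one discrete step of this diffusion. Deep GCNs approximate long-time diffusion. As we will see in §8, this makes node representations indistinguishable (over-smoothing): with symmetric normalization they converge to a common vector up to a per-node factor $\sqrt{\tilde{d}_v}$.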
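Over-smoothing can be previewed numerically. The NumPy sketch below applies $S$ repeatedly on a small connected graph, with no weights or nonlinearities; that assumption isolates the diffusion part of the GCN. After each row is rescaled by $1/\sqrt{\tilde{d}_v}$, all rows converge to the same vector, because $S^k$ approaches the projection onto its top eigenvector $\tilde{D}^{1/2}\mathbf{1}$. For this particular graph, the second-largest eigenvalue of $S$ in absolute value is $2/3$, so the deviation shrinks roughly like $(2/3)^k$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Connected 6-node graph: the cycle 0-1-2-3-4-5-0 plus the chord 0-3
n = 6
A = np.zeros((n, n))
for v in range(n):
    A[v, (v + 1) % n] = A[(v + 1) % n, v] = 1
A[0, 3] = A[3, 0] = 1
A_tilde = A + np.eye(n)
d_tilde = A_tilde.sum(axis=1)
S = A_tilde / np.sqrt(np.outer(d_tilde, d_tilde))

H = rng.normal(size=(n, 3))            # random initial features X
for k in range(1, 51):
    H = S @ H                          # propagation only: no W, no sigma
    if k in (1, 5, 50):
        # Rescale row v by 1/sqrt(d~_v); rows then approach a common vector
        R = H / np.sqrt(d_tilde)[:, None]
        spread = np.abs(R - R.mean(axis=0)).max()
        print(f"k={k:2d}  max deviation across nodes: {spread:.2e}")
```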

4.4 Semi-Supervised Node Classification

The original application of GCN (Kipf & Welling, 2017) is semi-supervised node classification. Given a graph with $n$ nodes, of which $n_L \ll n$ are labeled and $n_U = n - n_L$ are unlabeled, train the GCN to predict class labels for all nodes.

Setup:

  • Two-layer GCN: $\hat{Y} = \operatorname{softmax}\!\left(S \cdot \operatorname{ReLU}(S X W^{[0]}) \cdot W^{[1]}\right)$
  • Loss: cross-entropy over labeled nodes only: $\mathcal{L} = -\sum_{v \in V_L} \sum_c Y_{vc} \log \hat{Y}_{vc}$
  • The GCN propagates label information from labeled to unlabeled nodes through the graph structure (similar to label propagation, but learned)
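Here is a NumPy sketch of the forward pass and the labeled-only loss. It runs on a random toy graph with random labels, not a real dataset, and the names are illustrative. It averages the cross-entropy over labeled nodes rather than summing, a common convention that only rescales the loss. It omits training, which in practice would use an autodiff framework to update $W^{[0]}$ and $W^{[1]}$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Sketch: function names are illustrative, not a library API.
def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def two_layer_gcn(S, X, W0, W1):
    """Y_hat = softmax(S ReLU(S X W0) W1)."""
    return softmax(S @ np.maximum(0.0, S @ X @ W0) @ W1)

def masked_cross_entropy(Y_hat, labels, labeled_idx):
    """Cross-entropy averaged over labeled nodes only; unlabeled nodes add no loss."""
    p = Y_hat[labeled_idx, labels[labeled_idx]]
    return -np.mean(np.log(p + 1e-12))

n, d, hidden, C = 8, 5, 16, 3
A = np.triu((rng.random((n, n)) < 0.4).astype(float), 1)
A = A + A.T
A_tilde = A + np.eye(n)
d_tilde = A_tilde.sum(axis=1)
S = A_tilde / np.sqrt(np.outer(d_tilde, d_tilde))

X = rng.normal(size=(n, d))
labels = rng.integers(0, C, size=n)        # only entries in labeled_idx are "known"
labeled_idx = np.array([0, 3, 5])          # n_L = 3 of n = 8 nodes
W0 = rng.normal(size=(d, hidden)) * 0.1
W1 = rng.normal(size=(hidden, C)) * 0.1

Y_hat = two_layer_gcn(S, X, W0, W1)
print(Y_hat.shape)                          # (8, 3): a prediction for every node
print(masked_cross_entropy(Y_hat, labels, labeled_idx))
```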

Benchmark results (Cora citation network, 140 labeled / 2708 total):

| Method | Accuracy |
|---|---|
| DeepWalk + SVM | 67.2% |
| Label Propagation | 68.0% |
| Planetoid (Yang et al. 2016) | 75.7% |
| GCN (Kipf & Welling, 2017) | 81.5% |
| GAT (Velickovic et al., 2018) | 83.0% |
| GCNII (Chen et al., 2020, 64 layers) | 85.5% |

The GCN's success on this benchmark helped establish GNNs as a standard method for graph-structured learning and spurred a large wave of follow-up work.

Why it works. The graph encodes a "homophily" assumption: connected nodes tend to have the same class. In citation networks, papers citing each other tend to be in the same research area. The GCN leverages this assumption by smoothing features across edges, effectively propagating class information from labeled nodes to their unlabeled neighbors.
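One common way to quantify this assumption (not defined elsewhere in this lesson) is the edge homophily ratio: the fraction of edges that join two nodes of the same class. A minimal sketch with a hypothetical function name; values near 1 describe the setting where smoothing across edges tends to help, while low values warn that GCN-style averaging may blur class boundaries.

```python
import numpy as np

def edge_homophily(edges, labels):
    """Fraction of edges (u, v) whose endpoints share a class label."""
    edges = np.asarray(edges)
    return float(np.mean(labels[edges[:, 0]] == labels[edges[:, 1]]))

# Usage: a 5-node path whose last two nodes belong to a different class.
labels = np.array([0, 0, 0, 1, 1])
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
print(edge_homophily(edges, labels))  # 0.75
```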

4.5 Strengths and Weaknesses of GCN

Strengths:

  • Simple to implement (two matrix multiplications per layer)
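  • Efficient: $O(m \cdot d)$ per layer for the sparse aggregation on a graph with $m$ edges and $d$-dimensional features, plus $O(n \cdot d_l \cdot d_{l+1})$ for the dense feature transformation
  • Well-understood theoretically: spectral derivation, convergence analysis
  • A strong baseline for semi-supervised node classification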
  • Easily extended with residual connections, BatchNorm, dropout
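To see where the $O(m \cdot d)$ figure comes from, here is a NumPy sketch that computes $SH$ directly from an edge list (the function name and edge-list layout are assumptions of this sketch): each of the $m$ nonzero entries of $S$, self-loops included, contributes one scaled $d$-dimensional row. Real code would typically use a sparse-matrix type such as `scipy.sparse.csr_matrix` for the same product.

```python
import numpy as np

def sparse_aggregate(src, dst, weight, H):
    """Compute S @ H from an edge list: out[dst] += weight * H[src] for every edge.
    Each edge touches one d-dimensional row, so the cost is O(m * d)."""
    out = np.zeros_like(H, dtype=float)
    np.add.at(out, dst, weight[:, None] * H[src])  # unbuffered scatter-add
    return out

# Usage: build the edge list of S = D~^{-1/2}(A + I)D~^{-1/2} for a 4-node path graph.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_tilde = A + np.eye(4)
d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
dst, src = np.nonzero(A_tilde)               # entry (i, j): message from j to i
weight = d_inv_sqrt[dst] * d_inv_sqrt[src]
H = np.random.default_rng(0).normal(size=(4, 3))

S_dense = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
print(np.allclose(sparse_aggregate(src, dst, weight, H), S_dense @ H))  # True
```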

Weaknesses:

  • Transductive: the normalization $\hat{A}$ is computed for the entire graph at training time. Adding a new node requires recomputing $\hat{A}$ and retraining - the GCN cannot generalize to unseen nodes
  • Fixed symmetric aggregation: each neighbor's message is scaled by $1/\sqrt{\tilde{d}_u \tilde{d}_v}$, the inverse geometric mean of the two degrees (self-loops included), which is fixed by the graph rather than learned. Messages from high-degree hub nodes get down-weighted regardless of their relevance
  • No edge features: the standard GCN layer has no mechanism to incorporate edge attributes
  • Expressiveness limited to 1-WL: cannot distinguish non-isomorphic graphs that the WL test cannot distinguish (Section 7)
  • Over-smoothing at depth: after many layers, all node representations converge (Section 8)

5. GraphSAGE: Inductive Learning on Large Graphs

5.1 The Inductive Learning Problem

GCN is transductive: it operates on a fixed graph seen during training. The normalized adjacency $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ encodes the training graph's structure. When new nodes arrive (a new user on a social platform, a new molecule to screen), the GCN cannot produce representations without recomputing the full graph and rerunning the entire network.

This is impractical for large, dynamic graphs. Pinterest's graph has billions of pins and boards, and new pins are added continuously; retraining a GCN on the full graph every day is infeasible.

Inductive learning solves this: learn an aggregation function that can be applied to the neighborhood of any node, including nodes unseen during training. If the function is parameterized correctly, it generalizes: apply it to the neighborhood of a new node and get a useful embedding without retraining.

GraphSAGE (Hamilton, Ying & Leskovec, 2017) - SAmple and aggreGatE - is the canonical inductive GNN framework.

5.2 GraphSAGE Framework

Key idea: instead of using the entire neighborhood of a node (which can be millions of neighbors for hub nodes), sample a fixed-size subset $\mathcal{S}_v^{[l]} \subseteq \mathcal{N}(v)$ at each layer. Then aggregate only over the sampled set. The parameters of the aggregation function are shared across all nodes and all graphs.

Algorithm (2-layer GraphSAGE):

For each node vv in the target set (e.g., a mini-batch of nodes):

  1. Sample neighbors: $\mathcal{S}_v^{[1]} \sim \mathcal{U}(\mathcal{N}(v), S_1)$ - sample $S_1$ neighbors uniformly
  2. For each node $u \in \mathcal{S}_v^{[1]} \cup \{v\}$, sample its neighbors: $\mathcal{S}_u^{[2]} \sim \mathcal{U}(\mathcal{N}(u), S_2)$
  3. Compute depth-2 embeddings bottom-up:
    • $\mathbf{h}_u^{[1]} = \sigma\!\left(W^{[1]} \cdot \left[\mathbf{x}_u \,\|\, \operatorname{AGGREGATE}\!\left(\left\{\mathbf{x}_w : w \in \mathcal{S}_u^{[2]}\right\}\right)\right]\right)$ for all $u \in \mathcal{S}_v^{[1]} \cup \{v\}$
    • $\mathbf{h}_v^{[2]} = \sigma\!\left(W^{[2]} \cdot \left[\mathbf{h}_v^{[1]} \,\|\, \operatorname{AGGREGATE}\!\left(\left\{\mathbf{h}_u^{[1]} : u \in \mathcal{S}_v^{[1]}\right\}\right)\right]\right)$
  4. Normalize: $\mathbf{h}_v^{[2]} \leftarrow \mathbf{h}_v^{[2]} / \lVert\mathbf{h}_v^{[2]}\rVert_2$
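A per-node NumPy sketch of this procedure, assuming the mean aggregator from Section 5.3 (self representation concatenated with the neighbor mean), ReLU as $\sigma$, an adjacency-list dict with no isolated nodes, and a fresh size-$S_2$ sample of the target's own neighbors for its depth-1 representation. Sampling with replacement only when a node has fewer than $S$ neighbors is a choice made for this sketch, all helper names are hypothetical, and a practical implementation would batch these computations over a minibatch of targets.

```python
import numpy as np

def sample_neighbors(adj_list, node, size, rng):
    """Uniformly sample `size` neighbors (with replacement only if the degree is too small)."""
    nbrs = adj_list[node]
    return rng.choice(nbrs, size=size, replace=len(nbrs) < size)

def sage_mean_layer(h_self, h_nbrs, W):
    """sigma(W [h_self || mean(h_nbrs)]) with ReLU as sigma."""
    combined = np.concatenate([h_self, h_nbrs.mean(axis=0)])
    return np.maximum(W @ combined, 0.0)

def graphsage_embed(v, adj_list, X, W1, W2, S1, S2, rng):
    """Two-layer GraphSAGE embedding of node v from sampled neighborhoods."""
    def depth1(u):
        # Depth-1 representation of u from a sample of S2 of its own neighbors.
        second_hop = sample_neighbors(adj_list, u, S2, rng)
        return sage_mean_layer(X[u], X[second_hop], W1)

    first_hop = sample_neighbors(adj_list, v, S1, rng)
    h1_nbrs = np.stack([depth1(u) for u in first_hop])
    h2 = sage_mean_layer(depth1(v), h1_nbrs, W2)
    norm = np.linalg.norm(h2)
    return h2 / norm if norm > 0 else h2  # step 4: L2 normalization

# Usage: a 6-node ring with one chord, 4 input features, 8 hidden and 3 output units.
rng = np.random.default_rng(0)
adj_list = {0: [1, 5, 3], 1: [0, 2], 2: [1, 3], 3: [2, 4, 0], 4: [3, 5], 5: [4, 0]}
X = rng.normal(size=(6, 4))
W1 = rng.normal(size=(8, 2 * 4))  # acts on [x_u || mean of neighbor x]
W2 = rng.normal(size=(3, 2 * 8))  # acts on [h_v || mean of neighbor h]
z = graphsage_embed(0, adj_list, X, W1, W2, S1=3, S2=2, rng=rng)

# Inductive inference: attach an unseen node 6 and reuse the same W1, W2.
adj_list[6] = [0, 3]
adj_list[0].append(6)
adj_list[3].append(6)
X = np.vstack([X, rng.normal(size=(1, 4))])
z_new = graphsage_embed(6, adj_list, X, W1, W2, S1=3, S2=2, rng=rng)
print(z.shape, z_new.shape)  # (3,) (3,)
```

The new node's embedding is computed with the trained weights unchanged, which is what makes the approach inductive.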

Inductive inference: for a new node $v'$ with known features and edges (to training-time nodes), apply the same aggregation functions with the learned $W^{[1]}, W^{[2]}$ - no retraining needed.

Sampling depth. An $L$-layer GraphSAGE with neighborhood sizes $(S_1, S_2, \ldots, S_L)$ samples at most $\prod_{l=1}^{k} S_l$ nodes at hop $k$ for each target node, so the deepest hop, with $\prod_{l=1}^{L} S_l$ nodes, dominates the cost. With $(S_1, S_2) = (25, 10)$, at most 250 nodes are sampled at the second hop (plus 25 at the first), regardless of graph size. This makes training on massive graphs feasible.
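The per-hop bounds are cumulative products of the sample sizes; a tiny pure-Python helper (hypothetical name) makes the arithmetic explicit. The counts are sampled slots, so duplicates can make the number of distinct nodes smaller, and they exclude any extra sample drawn for the target's own lower-layer representation.

```python
from itertools import accumulate
from operator import mul

def max_sampled_per_hop(sample_sizes):
    """Upper bound on sampled nodes at hops 1..L for one target: S_1, S_1*S_2, ..."""
    return list(accumulate(sample_sizes, mul))

hops = max_sampled_per_hop([25, 10])
print(hops, sum(hops))  # [25, 250] 275
```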

5.3 Aggregation Strategies

GraphSAGE defines three aggregation variants with different expressiveness-efficiency tradeoffs:

Mean aggregator:

$$\mathbf{m}_v^{[l]} = \frac{1}{|\mathcal{S}_v^{[l]}|} \sum_{u \in \mathcal{S}_v^{[l]}} \mathbf{h}_u^{[l-1]}$$

$$\mathbf{h}_v^{[l]} = \sigma\!\left(W^{[l]} \cdot \left[\mathbf{h}_v^{[l-1]} \,\|\, \mathbf{m}_v^{[l]}\right]\right)$$

Similar to GCN but with uniform weights (no degree normalization). Concatenation rather than addition preserves the distinction between self and neighborhood.

LSTM aggregator:

$$\mathbf{m}_v^{[l]} = \operatorname{LSTM}\!\left(\left[\mathbf{h}_{\pi(u)}^{[l-1]}\right]_{u \in \mathcal{S}_v^{[l]}}\right)$$

Apply an LSTM to the sampled neighbors in a random order $\pi$. The LSTM can model complex interactions between neighbors, but because it consumes its inputs sequentially it is not permutation invariant; feeding random permutations during training adapts it to unordered neighbor sets without making it invariant. In practice, the LSTM aggregator was among the most accurate in the original GraphSAGE benchmarks, at a higher computational cost.

Max-pool aggregator:

$$\mathbf{m}_v^{[l]} = \max_{u \in \mathcal{S}_v^{[l]}} \sigma\!\left(W_{\text{pool}} \mathbf{h}_u^{[l-1]} + \mathbf{b}\right)$$

Apply a learnable transformation to each neighbor's representation, then take the element-wise max. Detects the presence of specific feature patterns in the neighborhood. Often the best choice for tasks where the presence of a rare feature type is informative.

Choosing an aggregator. Empirically, the LSTM and max-pool aggregators were the most accurate in the original GraphSAGE benchmarks; max-pool is permutation invariant and faster to train, while the LSTM is not invariant. The mean aggregator is the simplest permutation-invariant option and works well as an inductive baseline. Max-pool is best suited to detecting rare feature patterns in a neighborhood. Practitioners typically tune the choice on validation data.
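The invariance claims can be checked directly. The NumPy sketch below, with hypothetical random weights and ReLU in the pooling step, feeds the same neighbor set in two orders. To stay short it substitutes a plain tanh RNN for the LSTM (our simplification), which is order-dependent for the same reason: it reads its inputs one at a time.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k = 4, 5
H_nbrs = rng.normal(size=(k, d))  # k sampled neighbor representations
W_pool, b = rng.normal(size=(d, d)), rng.normal(size=d)
W_in, W_rec = 0.3 * rng.normal(size=(d, d)), 0.3 * rng.normal(size=(d, d))

def mean_agg(H):
    return H.mean(axis=0)

def maxpool_agg(H):
    # Per-neighbor transform + ReLU, then element-wise max over neighbors.
    return np.maximum(H @ W_pool.T + b, 0.0).max(axis=0)

def rnn_agg(H):
    # Simplified stand-in for the LSTM aggregator: a tanh RNN read in the given order.
    s = np.zeros(d)
    for h in H:
        s = np.tanh(W_in @ h + W_rec @ s)
    return s

reversed_order = H_nbrs[::-1]
for name, agg in [("mean", mean_agg), ("max-pool", maxpool_agg), ("rnn", rnn_agg)]:
    print(name, np.allclose(agg(H_nbrs), agg(reversed_order)))
# Expected: mean True, max-pool True, rnn False
```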

5.4 Unsupervised Training

GraphSAGE can be trained without labels using a link prediction objective:

$$\mathcal{L}(u, v) = -\log \sigma\!\left(\mathbf{h}_u^\top \mathbf{h}_v\right) - Q \cdot \mathbb{E}_{v_n \sim P_n(V)}\!\left[\log\!\left(1 - \sigma\!\left(\mathbf{h}_u^\top \mathbf{h}_{v_n}\right)\right)\right]$$

where $(u, v)$ is a positive pair (nodes that co-occur on a short random walk, such as direct neighbors), $v_n$ is a negative sample drawn from $P_n(V)$, $Q$ is the number of negative samples, $P_n(V)$ is the negative sampling distribution (uniform, or degree-based as in word2vec-style negative sampling), and $\sigma$ is the sigmoid function.
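A NumPy sketch of this loss for a single positive pair, assuming the embeddings are already computed and the $Q$ negatives already drawn (function names are hypothetical). The term $Q \cdot \mathbb{E}[\cdot]$ is estimated by summing over the drawn negatives, and `np.logaddexp` keeps the log-sigmoid numerically stable.

```python
import numpy as np

def log_sigmoid(x):
    # Stable log(sigma(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)

def unsupervised_loss(h_u, h_v, h_negs):
    """Loss for one positive pair (u, v) and Q negatives stacked as rows of h_negs."""
    positive = -log_sigmoid(h_u @ h_v)
    negative = -log_sigmoid(-(h_negs @ h_u)).sum()  # log(1 - sigma(s)) = log(sigma(-s))
    return float(positive + negative)

h_u = np.array([1.0, 0.0])
h_negs = np.array([[-1.0, 0.0], [0.0, 1.0]])  # Q = 2 negatives
aligned = unsupervised_loss(h_u, np.array([1.0, 0.0]), h_negs)
opposed = unsupervised_loss(h_u, np.array([-1.0, 0.0]), h_negs)
print(aligned < opposed)  # True: similar positive embeddings lower the loss
```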

Intuition: nodes that co-occur in short random walks should have similar embeddings; nodes that are far apart should have dissimilar embeddings. This is the DeepWalk objective extended to learned GNN embeddings.
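Positive pairs of this kind can be generated from short random walks. A pure-Python sketch, assuming uniform walks on an adjacency-list dict with no isolated nodes; the parameter names `walk_length`, `window`, and `walks_per_node` are hypothetical rather than taken from the GraphSAGE code.

```python
import random

def random_walk_pairs(adj_list, walk_length, window, walks_per_node, seed=0):
    """(u, v) pairs of distinct nodes within `window` steps of each other on uniform random walks."""
    rnd = random.Random(seed)
    pairs = []
    for start in adj_list:
        for _ in range(walks_per_node):
            walk = [start]
            while len(walk) < walk_length:
                walk.append(rnd.choice(adj_list[walk[-1]]))
            for i, u in enumerate(walk):
                for v in walk[i + 1 : i + 1 + window]:
                    if v != u:
                        pairs.append((u, v))
    return pairs

# Usage: two disconnected components, {0, 1, 2} and {3, 4}.
adj_list = {0: [1], 1: [0, 2], 2: [1], 3: [4], 4: [3]}
pairs = random_walk_pairs(adj_list, walk_length=4, window=2, walks_per_node=2)
component = {0: "a", 1: "a", 2: "a", 3: "b", 4: "b"}
print(all(component[u] == component[v] for u, v in pairs))  # True
```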

For AI: PinSage (Pinterest's recommendation system) builds on GraphSAGE to generate embeddings for pins (images with descriptions) in a bipartite pin-board graph, trained with a ranking objective over pairs of related pins rather than the purely unsupervised loss above. These embeddings power Pinterest's "more like this" recommendations.

5.5 Scaling to Billion-Node Graphs

PinSage (Ying et al., 2018) adapted GraphSAGE to Pinterest's graph with 3 billion nodes and 18 billion edges - the largest graph neural network deployment at the time.

Key engineering innovations:

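  1. Importance-based sampling: instead of uniform random sampling, simulate short random walks from the target node and keep the $T$ nodes with the highest normalized visit counts as its neighborhood. This focuses computation on structurally important neighbors and fixes the neighborhood size; the normalized counts also weight the neighbors during aggregation ("importance pooling").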
  2. Producer-consumer minibatch pipeline: CPU workers extract neighborhoods and fetch features for upcoming minibatches while GPUs run the forward and backward passes on the current one. This decouples graph traversal (CPU-bound) from gradient computation (GPU-bound).
  3. Hard negative mining: random negatives (random pins) are too easy to separate from the query. Hard negatives are items that are related to the query but less relevant than the positive; PinSage selected them by their Personalized PageRank scores with respect to the query. They force the model to learn finer-grained distinctions.
  4. Curriculum learning: train first with easy negatives, then add hard negatives progressively.
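A simplified sketch of the importance-based neighborhood described in item 1: run short random walks from the target, keep the $T$ most-visited nodes, and use their L1-normalized visit counts as weights in a weighted mean. It assumes uniform walks, no self-loops, and hypothetical parameter names; PinSage's actual walk settings and $T$ are not given in this lesson, and the weighted mean stands in for the full aggregation step.

```python
import random
from collections import Counter

import numpy as np

def importance_neighborhood(adj_list, v, num_walks, walk_length, top_t, seed=0):
    """Top-T nodes by visit count on short random walks from v, with L1-normalized weights."""
    rnd = random.Random(seed)
    visits = Counter()
    for _ in range(num_walks):
        node = v
        for _ in range(walk_length):
            node = rnd.choice(adj_list[node])
            if node != v:
                visits[node] += 1
    top = visits.most_common(top_t)
    total = sum(count for _, count in top)
    return [(u, count / total) for u, count in top]

def importance_pool(X, weighted_nbrs):
    """Weighted mean of neighbor features, using the visit-count weights."""
    return np.sum([w * X[u] for u, w in weighted_nbrs], axis=0)

# Usage: a small star-like graph; one-hot features make the pooled weights easy to read.
adj_list = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0], 4: [1]}
X = np.eye(5)
nbrs = importance_neighborhood(adj_list, 0, num_walks=200, walk_length=3, top_t=3)
print(nbrs)
print(importance_pool(X, nbrs))
```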

Deployment results: 150% improvement in head-to-head pin recommendation quality over previous collaborative filtering system, with recommendations updated daily across 3+ billion items.

