Graph Neural Networks
I avoided graph neural networks for longer than I'd like to admit. Every time I saw a paper with "message passing" and adjacency matrices, I'd skim the abstract, nod wisely, and move on to something that fit neatly into a tensor. Images are grids. Text is a sequence. But then I kept running into data that was neither (social connections, molecular structures, recommendation systems), and the discomfort of not understanding how to learn from that kind of structure grew too great. So I finally took the deep dive, and this post is the result.
Graph Neural Networks are a family of neural architectures designed to operate on graph-structured data — nodes connected by edges, with features living on both. The core ideas trace back to the mid-2000s, but the modern era began around 2017 with Kipf & Welling's GCN paper. Since then, GNNs have become the go-to tool for molecular property prediction, social network analysis, recommendation engines, and knowledge graph reasoning.
Before we start, a heads-up. We're going to be building graphs from scratch, working through matrix multiplications, and touching on attention mechanisms. You don't need to know any of it beforehand. We'll add the concepts we need one at a time, with explanation.
This isn't a short journey, but I hope you'll be glad you came.
What is a graph, really
The adjacency matrix
Node features and edge features
Neighborhood gossip — why structure matters
The message passing framework
GCN — the weighted average that started it all
Rest stop
GAT — paying attention to the right neighbors
GraphSAGE — sampling your way to scale
Three kinds of questions you can ask a graph
Graph pooling — squeezing a graph into a vector
Over-smoothing — the blur problem
What GNNs can't see — the Weisfeiler-Lehman wall
Where GNNs live in the real world
Tools of the trade — PyG and DGL
Wrap-up
Resources and credits
What is a graph, really
Let's start with something concrete. Imagine three friends — Alice, Bob, and Carol. Alice and Bob know each other. Bob and Carol know each other. Alice and Carol don't. That's a graph. Three nodes (the people) and two edges (the friendships).
A graph is nothing more than a collection of things and the connections between them. The things are nodes (sometimes called vertices). The connections are edges (sometimes called links). Unlike an image, which is a rigid grid of pixels, or a sentence, which is a strict left-to-right sequence, a graph has no inherent ordering. There's no "first" node. There's no spatial layout. There are only connections.
This is what makes graphs both powerful and awkward. Powerful because so much of the world is relational — atoms bonded in a molecule, users following each other on a platform, entities linked in a knowledge base. Awkward because most of our deep learning machinery assumes data arrives in neat, ordered arrays.
The adjacency matrix
We need a way to hand our little friendship graph to a computer. The standard approach is an adjacency matrix — a square grid where rows and columns both represent nodes, and a 1 in position (i, j) means node i is connected to node j.
For our three friends, labeling Alice = 0, Bob = 1, Carol = 2:
         Alice  Bob  Carol
Alice  [   0     1     0  ]
Bob    [   1     0     1  ]
Carol  [   0     1     0  ]
This is an undirected graph — if Alice knows Bob, Bob knows Alice — so the matrix is symmetric. A 1 at row 0, column 1 (Alice→Bob) is mirrored at row 1, column 0 (Bob→Alice). If friendships were one-way — like Twitter follows — we'd have a directed graph and the matrix might not be symmetric.
I'll be honest: the first time I saw adjacency matrices, I thought they were wasteful. A 3×3 grid to store two friendships? But this format turns out to be perfect for the linear algebra operations that make GNNs work. Multiplying this matrix by a vector of node features gives, for every node at once, the sum of its neighbors' values. That observation is the entire foundation of what comes next.
Node features and edge features
A graph with only connectivity would be like a spreadsheet with only row labels and no data. In practice, each node carries a feature vector — a list of numbers describing it. For our friends, maybe it's their movie ratings:
# Each person rates three movies: Action, Comedy, Drama
Alice = [5, 1, 3] # loves action
Bob = [4, 4, 2] # balanced tastes
Carol = [1, 5, 4] # comedy and drama fan
We stack these into a node feature matrix X, where each row is one node's features. For our example, X is a 3×3 matrix — three nodes, each with three features.
Edges can carry features too. In a molecule, a bond between two atoms might have a feature indicating whether it's a single bond, double bond, or aromatic bond. In a social network, an edge might carry the number of messages exchanged. We won't lean heavily on edge features in this post, but knowing they exist is important — many real-world GNN applications use them.
Neighborhood gossip — why structure matters
Here's where things get interesting. Let's say we want to predict what Carol would think of a new action movie. If we only looked at Carol's features — her ratings of [1, 5, 4] — we'd guess she probably won't like it. Low action score.
But Carol is connected to Bob, and Bob likes action movies. If our friends influence each other's tastes (and in the real world, they do), then Carol's connection to Bob is a signal. Maybe she's more open to action than her own ratings suggest.
This is the fundamental insight behind GNNs: a node's identity is shaped not only by its own features, but by who it's connected to. It's like neighborhood gossip — you learn about someone by listening to their neighbors. And their neighbors' neighbors. The deeper you go, the more context you gather.
Think of it this way. If I told you someone's movie preferences, you'd know a little. If I also told you their three closest friends' preferences, you'd know quite a bit more. If I added their friends' friends, you'd have a rich picture of the social context that shapes their taste. That's exactly what a GNN does, layer by layer.
The question is: how do we do this mathematically?
The message passing framework
Nearly every modern GNN is a variation of the same three-step loop. The loop runs once per layer, and it goes like this.
First, each node prepares a message. This is typically the node's current feature vector, possibly transformed by a weight matrix. In our movie example, Alice would prepare her ratings [5, 1, 3] as her message.
Second, each node aggregates the messages from its neighbors. "Aggregate" means "combine into a single summary." The aggregation function might be a sum, a mean, or a max — but it must be permutation-invariant. That's a crucial detail. Because graphs have no inherent node ordering, the result can't change if we shuffle the order of the neighbors. Sum doesn't care about order. Mean doesn't care about order. A function that reads neighbors left-to-right would care, and that would break everything.
Third, each node updates its own representation by combining its current features with the aggregated neighbor message, usually through a learnable weight matrix and a nonlinearity like ReLU.
Let's trace through this concretely. Bob is connected to both Alice and Carol. In the message step, Alice sends [5, 1, 3] and Carol sends [1, 5, 4]. Bob aggregates by averaging: [(5+1)/2, (1+5)/2, (3+4)/2] = [3, 3, 3.5]. Then Bob combines this with his own features [4, 4, 2] through a weight matrix and a nonlinearity. The exact combination depends on the architecture, but the shape of the dance is always the same: message, aggregate, update.
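Here is that single step for Bob written out as a minimal NumPy sketch. The update rule (concatenate Bob's own features with the neighbor mean, multiply by a weight matrix, apply ReLU) is one common choice, not the exact rule of any particular architecture, and the weight matrix `W` is random, standing in for learned parameters.

```python
import numpy as np

# Node features: ratings for Action, Comedy, Drama
features = {
    "Alice": np.array([5.0, 1.0, 3.0]),
    "Bob": np.array([4.0, 4.0, 2.0]),
    "Carol": np.array([1.0, 5.0, 4.0]),
}
neighbors = {"Alice": ["Bob"], "Bob": ["Alice", "Carol"], "Carol": ["Bob"]}

def aggregate_mean(node, order=None):
    # Step 1 (message): each neighbor sends its current feature vector.
    # Step 2 (aggregate): average the messages, a permutation-invariant summary.
    nbrs = order if order is not None else neighbors[node]
    messages = [features[n] for n in nbrs]
    return np.mean(messages, axis=0)

agg = aggregate_mean("Bob")
print(agg)  # [3.  3.  3.5]

# Shuffling the neighbor order gives exactly the same result
assert np.allclose(agg, aggregate_mean("Bob", order=["Carol", "Alice"]))

# Step 3 (update): combine Bob's own features with the aggregate.
# W is random here, a stand-in for a learned weight matrix.
rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))  # (3 own + 3 aggregated) features -> 4 hidden units
h_bob = np.maximum(0.0, np.concatenate([features["Bob"], agg]) @ W)  # ReLU
print(h_bob.shape)  # (4,)
```

Swapping `np.mean` for `np.sum` or `np.max` changes the aggregator without touching the rest of the loop, which is exactly the kind of knob the architectures below turn.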
That's the whole framework. If you understand this loop, you understand the skeleton of nearly every message-passing GNN architecture. The differences between GCN, GAT, and GraphSAGE are differences in how they execute these three steps: how they weight messages, how they aggregate, how they update. The skeleton is identical.
There's a beautiful connection to matrix multiplication here. If we take our adjacency matrix A and multiply it by our node feature matrix X, the result is a new matrix where each row contains the sum of that node's neighbors' features. That single matrix multiplication, A × X, performs the message and aggregate steps for every node at once, with sum as the aggregator.
import numpy as np
# A is the adjacency matrix, X is the node feature matrix
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]])
X = np.array([[5, 1, 3],   # Alice
              [4, 4, 2],   # Bob
              [1, 5, 4]])  # Carol
# A @ X gathers neighbor features for every node simultaneously
print(A @ X)
# Row 0 (Alice): got Bob's features [4, 4, 2]
# Row 1 (Bob): got Alice's [5, 1, 3] + Carol's [1, 5, 4] = [6, 6, 7]
# Row 2 (Carol): got Bob's features [4, 4, 2]
One matrix multiplication, and every node has gathered its neighbors' information. Stack a learnable weight matrix W after that, add a nonlinearity, and you've got yourself a GNN layer.
GCN — the weighted average that started it all
In 2017, Thomas Kipf and Max Welling published a paper that turned GNNs from a theoretical curiosity into a practical tool. Their insight was to take a complex spectral theory of graph convolutions — involving eigendecompositions of the graph Laplacian and Chebyshev polynomial approximations — and simplify it down to one elegant layer rule.
The backstory is worth a detour. Mathematicians had figured out how to define "convolution" on graphs by working in the frequency domain — transforming node signals using the graph's eigenvectors, applying a filter, and transforming back. Sounds expensive because it is. It required computing the eigendecomposition of the graph Laplacian, which is impractical for any reasonably sized graph. Defferrard et al. (2016) made it cheaper by approximating the spectral filter with Chebyshev polynomials up to order K. Kipf and Welling looked at that and asked: what if K = 1? What if we use the crudest possible approximation?
What fell out was this layer rule: take each node's features, average them with its neighbors' features (including itself), multiply by a learnable weight matrix, and apply a nonlinearity. That's the entire Graph Convolutional Network.
In formula form:
$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\,\tilde{A}\,\tilde{D}^{-1/2}\,H^{(l)}\,W^{(l)}\right)$$
Let me unpack every symbol. H(l) is the matrix of node features at layer l; each row is one node's representation. W(l) is the learnable weight matrix for that layer. σ is a nonlinearity, typically ReLU. Ã is A + I, the adjacency matrix with self-loops added, so each node also considers its own features during aggregation. D̃ is the degree matrix of Ã: a diagonal matrix where each entry D̃_ii counts how many connections node i has (including the self-loop). The D̃^(−½) terms on either side perform symmetric normalization.
Why does the normalization matter? Without it, each node sums its neighbors' features, so a node with 500 neighbors would end up with aggregated features roughly 250 times larger than a node with 2 neighbors. That blows up the numerical scale and makes training unstable. The symmetric normalization, which divides the contribution of each edge (i, j) by √(deg_i × deg_j), the square root of both endpoints' degrees, keeps everything on a comparable scale. I'll be honest: I'm still developing my intuition for why the symmetric version works better than simpler alternatives like dividing by degree alone. The mathematical justification traces back to the spectral theory, but in practice it produces more stable gradients and better performance.
Let's walk through GCN on our three-friend graph. We add self-loops, so Ã becomes:
Ã = [[1, 1, 0],   # Alice: connected to self + Bob (degree 2)
     [1, 1, 1],   # Bob: connected to self + Alice + Carol (degree 3)
     [0, 1, 1]]   # Carol: connected to self + Bob (degree 2)
After symmetric normalization (each entry divided by √(deg_i × deg_j)), the normalized matrix gives Bob's features slightly less weight when aggregated into Alice's representation — because Bob is more "popular" (higher degree), his signal gets diluted. This is by design. Without it, hub nodes would dominate the representations of everyone they touch.
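To see the actual numbers, here is a small NumPy sketch that builds D̃^(−½) Ã D̃^(−½) for the three-friend graph and then applies one GCN layer. The weight matrix is random, standing in for learned parameters, and the dense matrices are only for readability; real libraries use sparse operations.

```python
import numpy as np

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
X = np.array([[5, 1, 3],    # Alice
              [4, 4, 2],    # Bob
              [1, 5, 4]], dtype=float)  # Carol

A_tilde = A + np.eye(3)                     # add self-loops
deg = A_tilde.sum(axis=1)                   # degrees including self-loops: [2, 3, 2]
D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))    # D̃^(-1/2)
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt   # entry (i, j) = 1 / sqrt(deg_i * deg_j)

print(np.round(A_hat, 3))
# Alice's row: 0.5 for herself, 1/sqrt(6) ≈ 0.408 for Bob, 0 for Carol

# One GCN layer: ReLU(A_hat @ X @ W), with a random stand-in for the learned W
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 2))
H1 = np.maximum(0.0, A_hat @ X @ W)
print(H1.shape)  # (3, 2)
```

Alice's own features get weight 0.5 while Bob's get about 0.408, which is the "popular neighbor gets diluted" effect described above.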
The power of GCN is its simplicity. One matrix multiply to gather neighbors, one matrix multiply to transform features, one nonlinearity. The weakness is equally clear: the weight each neighbor receives is fixed by the graph structure (the degrees, through the normalization) rather than learned from the data. For many problems, that's fine. For problems where some connections matter far more than others, it's not.
Rest stop
Congratulations on making it this far. You can stop if you want.
You now have a working mental model of how GNNs function: data lives on a graph, nodes pass messages to their neighbors, those messages are aggregated (in a way that doesn't care about order), and the result is combined with the node's own features to create updated representations. GCN is the simplest version of this — average your neighbors, transform, apply a nonlinearity. Stack a couple of layers and each node "sees" two hops out into the graph.
That mental model is genuinely useful. It covers perhaps 60% of what you'd need for a conversation about GNNs in a design review or an interview.
But it doesn't tell the complete story. GCN treats every neighbor equally, which is a real limitation. It also requires the entire graph at training time, which breaks down at scale. And there's a fundamental trap waiting if you naively stack too many layers. If the discomfort of not knowing about those things is nagging at you, read on.
GAT — paying attention to the right neighbors
GCN averages all neighbors equally. But think about it in our friend analogy: if you're trying to predict what Carol will like, does it matter whether her friend is a movie critic or someone who watches one movie a year? Of course it does. Some neighbors are more informative than others, and the network should learn which ones to listen to.
This is the motivation behind Graph Attention Networks, introduced by Veličković et al. in 2018. The idea is borrowed directly from the attention mechanism in Transformers — but with a critical difference. In a Transformer, every token attends to every other token (global attention). In a GAT, each node attends only to its immediate neighbors (local attention, dictated by the graph structure). This makes GAT sparse and efficient.
Here's how it works, step by step, for a single node. Say we want to compute Bob's new representation. Bob has two neighbors: Alice and Carol.
First, we project everyone's features through a shared weight matrix W. Bob's features become W·h_Bob, Alice's become W·h_Alice, Carol's become W·h_Carol.
Next, we compute an attention score for each neighbor pair. For the (Bob, Alice) pair, we concatenate W·h_Bob and W·h_Alice into one long vector, take the dot product with a learnable vector a, and apply LeakyReLU. The result, e_(Bob,Alice), is a raw score indicating how much Alice's features matter to Bob.
We do the same for (Bob, Carol) to get e_(Bob,Carol). Then we normalize the scores with a softmax across Bob's neighbors, so they sum to one. These normalized scores, α_(Bob,Alice) and α_(Bob,Carol), are the attention weights. Bob's updated representation is the weighted sum α_(Bob,Alice) · W·h_Alice + α_(Bob,Carol) · W·h_Carol, passed through a nonlinearity. (In the original paper, each node's neighborhood also includes the node itself, so in practice Bob attends to his own features too.)
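Here is the single-head computation for Bob as a NumPy sketch. `W` and `a` are random stand-ins for learned parameters, the LeakyReLU slope of 0.2 matches the GAT paper, and, to mirror the walkthrough, Bob attends only to Alice and Carol (no self-loop) and the final nonlinearity is left out.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def softmax(x):
    e = np.exp(x - x.max())       # subtract the max for numerical stability
    return e / e.sum()

h = {
    "Alice": np.array([5.0, 1.0, 3.0]),
    "Bob": np.array([4.0, 4.0, 2.0]),
    "Carol": np.array([1.0, 5.0, 4.0]),
}
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4)) * 0.1   # shared projection, 3 -> 4 dims (stand-in for learned)
a = rng.normal(size=8) * 0.1        # attention vector over the concatenated pair

Wh = {name: vec @ W for name, vec in h.items()}

target, nbrs = "Bob", ["Alice", "Carol"]
# Raw scores: e_(Bob,j) = LeakyReLU(a . [W h_Bob || W h_j])
e = np.array([leaky_relu(a @ np.concatenate([Wh[target], Wh[j]])) for j in nbrs])
alpha = softmax(e)                  # attention weights over Bob's neighbors, summing to 1
h_bob_new = sum(alpha[k] * Wh[j] for k, j in enumerate(nbrs))
print(dict(zip(nbrs, alpha.round(3))), h_bob_new.shape)
```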
If you've seen Transformer attention, this feels familiar but different. There's no separate query/key/value projection. Instead, GAT uses a single shared weight matrix and a separate attention vector. It's a leaner mechanism, partly because graph neighborhoods are typically small (tens of neighbors, not thousands of tokens).
One practical note: GAT uses multi-head attention, the same trick that stabilizes Transformers. Multiple independent attention heads each compute their own weights, and the results are concatenated (or averaged in the final layer). This helps the model capture different kinds of relationships simultaneously.
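There's a subtle but important limitation in the original GAT that was discovered later. The score is LeakyReLU(aᵀ[W·h_i ‖ W·h_j]), and the dot product with a splits into one term that depends only on the query node i plus one that depends only on the neighbor j. Because LeakyReLU is monotonic, the ranking of candidate neighbors by attention score comes out the same no matter which node is doing the attending; Brody et al. (2021) call this static attention. Their GATv2 fixes it by moving the attention vector outside the nonlinearity: apply W to the concatenated pair, then LeakyReLU, and only then take the dot product with a. That makes the ranking depend on the query node (dynamic attention). For a new project, GATv2 is usually the better default.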
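The difference is easiest to see side by side. In this NumPy sketch (random parameters, five random nodes, every node treated as a candidate neighbor of every query), the original GAT scoring produces one single ranking shared by all queries, while the GATv2 scoring generally does not.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

rng = np.random.default_rng(1)
n, d, d_out = 5, 3, 4
H = rng.normal(size=(n, d))                 # 5 nodes, 3 features each

# Original GAT: e_ij = LeakyReLU(a . [W h_i || W h_j])
W_gat = rng.normal(size=(d, d_out))
a_gat = rng.normal(size=2 * d_out)
def gat_score(i, j):
    return float(leaky_relu(a_gat @ np.concatenate([H[i] @ W_gat, H[j] @ W_gat])))

# GATv2: e_ij = a . LeakyReLU(W [h_i || h_j])
W_v2 = rng.normal(size=(2 * d, d_out))
a_v2 = rng.normal(size=d_out)
def gatv2_score(i, j):
    return float(a_v2 @ leaky_relu(np.concatenate([H[i], H[j]]) @ W_v2))

def ranking(score, query):
    # Candidate nodes ordered from highest to lowest attention score for this query
    return tuple(sorted(range(n), key=lambda j: -score(query, j)))

gat_rankings = {ranking(gat_score, q) for q in range(n)}
gatv2_rankings = {ranking(gatv2_score, q) for q in range(n)}
print(len(gat_rankings))    # 1: every query ranks the candidates identically
print(len(gatv2_rankings))  # typically more than 1: the ranking depends on the query
```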
GraphSAGE — sampling your way to scale
Both GCN and GAT share a fundamental assumption: the entire graph is available during training. Every node participates in message passing, even the unlabeled ones. For a citation network with a few thousand nodes, that's fine. For Pinterest's graph with billions of pins and users, it's a non-starter.
Hamilton, Ying, and Leskovec (2017) tackled this with GraphSAGE — SAmple and aggreGatE. The name captures the two core ideas.
The first idea is inductive learning. Instead of learning a fixed embedding for each node (which would require seeing every node during training and couldn't handle new nodes), GraphSAGE learns an aggregation function. When a new user signs up on the platform, we feed their neighborhood through the same learned function and get an embedding immediately. No retraining needed.
The second idea is neighborhood sampling. Instead of aggregating all neighbors — which for a hub node could mean millions of connections — GraphSAGE randomly samples a fixed number at each layer. Say 25 neighbors at the first hop and 10 at the second hop. This caps the computation per node and makes mini-batch training possible. The sampling introduces variance (you don't always see the same neighbors), but in practice this acts like a form of regularization and works well.
Think back to our gossip analogy. GCN and GAT are like asking every single person in the neighborhood what they think. GraphSAGE is like asking a random handful of people at the coffee shop. You don't hear from everyone, but if you ask enough times (enough training iterations), you get the gist.
GraphSAGE offers several aggregation functions: mean (the default, and surprisingly hard to beat), LSTM (they shuffle the neighbor order each time since LSTMs care about order — a clever hack), and max-pool (apply an MLP to each neighbor, then take the element-wise maximum). The mean aggregator is what most practitioners use.
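Here is one GraphSAGE layer with the mean aggregator as a NumPy sketch. The toy graph (one hub with 50 leaf neighbors), the sample size of 10, and the random weight matrix are assumptions for illustration. The update follows the paper's general recipe as I understand it: concatenate the node's own vector with the aggregated neighbor vector, apply a weight matrix and a nonlinearity, then L2-normalize. When a node has fewer neighbors than the sample size, this sketch samples with replacement; that detail is my choice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy graph: node 0 is a hub connected to nodes 1..50
adj = {0: list(range(1, 51)), **{i: [0] for i in range(1, 51)}}
X = rng.normal(size=(51, 8))          # 8 input features per node
W = rng.normal(size=(16, 4)) * 0.1    # (8 own + 8 aggregated) -> 4; stand-in for learned weights

def sample_neighbors(v, k):
    nbrs = adj[v]
    # Fixed-size uniform sample; with replacement only if the node has fewer than k neighbors
    return rng.choice(nbrs, size=k, replace=len(nbrs) < k)

def sage_mean_layer(v, k=10):
    sampled = sample_neighbors(v, k)
    h_neigh = X[sampled].mean(axis=0)                          # mean aggregator
    h = np.maximum(0.0, np.concatenate([X[v], h_neigh]) @ W)   # concat, transform, ReLU
    return h / (np.linalg.norm(h) + 1e-12)                     # L2-normalize

print(sage_mean_layer(0).shape)  # (4,): the cost depends on k=10, not on the hub's 50 neighbors
```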
The practical impact is enormous. GraphSAGE-style sampling is what made GNNs deployable at companies like Pinterest, Uber, and Twitter. If you ever need to train a GNN on a graph that doesn't fit in GPU memory, this is where you start.
Three kinds of questions you can ask a graph
We've been focused on how GNNs compute node representations, but what do we actually do with those representations? There are three fundamentally different types of tasks, and they differ in what entity gets a prediction.
Node-level tasks. The most common setup. Given a partially-labeled graph, predict labels for the unlabeled nodes. This is semi-supervised learning — message passing runs across the entire graph (including unlabeled nodes, which still participate in neighbor aggregation), but the loss function only looks at the nodes that have labels. Think: classifying academic papers by topic, detecting fraudulent accounts in a transaction network, predicting protein functions in a biological interaction graph.
Edge-level tasks. The question here is: which edges are missing? Or which edges will form next? Run the GNN to get node embeddings, then score candidate node pairs — the simplest decoder is a dot product between two node embeddings. A high score means "these two nodes should probably be connected." This powers "people you may know" features on social platforms and knowledge graph completion (predicting missing facts like "who directed this movie?"). One gotcha that bites beginners: if you include the target edge in the graph during message passing, the model can trivially detect it. Always remove the edge you're trying to predict.
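A NumPy sketch of the bookkeeping, assuming a made-up six-node graph: hold out the target edge before building the adjacency matrix used for message passing, then score node pairs with a dot product. The encoder is one untrained averaging step over random features, a placeholder for a real GNN; a real pipeline would typically also train against negative (unconnected) pairs, which I leave out here.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes = 6
# Hypothetical undirected edge list
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (0, 2)]
target = (0, 2)                               # the edge we want the model to predict

# Remove the target edge so message passing can't "see" the answer
train_edges = [e for e in edges if e != target]

def adjacency(edge_list, n):
    A = np.zeros((n, n))
    for u, v in edge_list:
        A[u, v] = A[v, u] = 1.0
    return A

A = adjacency(train_edges, num_nodes)
X = rng.normal(size=(num_nodes, 4))
# Placeholder encoder: average each node with its neighbors (untrained, illustration only)
A_self = A + np.eye(num_nodes)
Z = (A_self / A_self.sum(axis=1, keepdims=True)) @ X

def score(u, v):
    # Dot-product decoder: a higher score means "more likely connected"
    return float(Z[u] @ Z[v])

print(A[0, 2] == 0.0, score(*target))
```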
Graph-level tasks. One prediction for an entire graph. Is this molecule toxic? Is this program's control flow graph malicious? This requires an extra step beyond message passing: pooling.
Graph pooling — squeezing a graph into a vector
Node-level tasks are straightforward — each node already has an embedding after message passing. But for graph-level tasks, we need a single vector that represents the whole graph. Pooling is how we get there.
The simplest approaches are global mean pooling (average all node embeddings) and global sum pooling (add them up). No learnable parameters, fast, and surprisingly effective for many tasks. Sum pooling has a theoretical advantage: it can distinguish graphs with different numbers of nodes that have the same average features. Mean pooling can't.
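The sum-versus-mean difference is easy to check with a NumPy sketch: two graphs whose nodes all carry the same embedding, one with three nodes and one with six.

```python
import numpy as np

# Two hypothetical graphs with identical node embeddings but different sizes
small_graph = np.array([[1.0, 2.0]] * 3)   # 3 nodes
large_graph = np.array([[1.0, 2.0]] * 6)   # 6 nodes

def mean_pool(node_embeddings):
    return node_embeddings.mean(axis=0)

def sum_pool(node_embeddings):
    return node_embeddings.sum(axis=0)

print(mean_pool(small_graph), mean_pool(large_graph))  # identical: both [1, 2]
print(sum_pool(small_graph), sum_pool(large_graph))    # different: [3, 6] vs [6, 12]
```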
When you need something more expressive, there are learned pooling methods. DiffPool (Ying et al., 2018) learns a soft cluster assignment matrix — it groups nodes into supernodes at each pooling step, creating a hierarchy. The coarsened graph goes through more GNN layers. It's powerful but memory-hungry. SAGPool uses attention scores to select the top-k most important nodes and drops the rest, creating a sparser graph for the next layer.
For most practical purposes, global mean or sum pooling works well. The learned methods shine on benchmarks with highly structured graphs where hierarchy matters — certain molecular and biological datasets — but they add complexity and computational cost.
Over-smoothing — the blur problem
Here's the trap I promised earlier. Each GNN layer lets every node absorb information from its immediate neighbors. Two layers: friends-of-friends. Three layers: three hops out. The natural instinct is to stack more layers, like we do with deep CNNs or Transformers. Deeper network, richer representations, right?
Wrong. With GNNs, going deeper makes representations worse after a certain point. This is called over-smoothing, and the analogy that helped it click for me is blurring a photograph.
Every GNN layer is like applying a blur filter to an image. One pass gently smooths out noise while keeping the important features visible. Two passes and you start losing some detail. Five passes and everything is a washed-out smudge. Ten passes and every pixel converges to the same average gray.
The mathematics bear this out, at least in the simplified linear case. Ignoring the weight matrices and nonlinearities, repeatedly multiplying by the normalized adjacency matrix drives every feature column toward the matrix's dominant eigenvector. For GCN's symmetric normalization, that eigenvector's entries depend only on each node's degree, so after enough layers nodes differ by nothing more than a degree-dependent scale. In a small-world graph (like most social networks), five hops can reach nearly every node, meaning five layers of message passing give every node access to almost the entire graph. At that point, all nodes have effectively the same information, and the representations collapse.
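Here is that collapse on our three-friend graph, as a NumPy sketch of the simplified linear case (no weights, no nonlinearity). Dividing each row by √deg removes the degree-dependent scale, so what remains measures how different the nodes still are.

```python
import numpy as np

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
X = np.array([[5, 1, 3],    # Alice
              [4, 4, 2],    # Bob
              [1, 5, 4]], dtype=float)  # Carol

A_tilde = A + np.eye(3)
deg = A_tilde.sum(axis=1)
A_hat = A_tilde / np.sqrt(np.outer(deg, deg))   # D̃^(-1/2) Ã D̃^(-1/2)

H = X.copy()
for layer in range(1, 31):
    H = A_hat @ H               # propagation only: no weight matrix, no ReLU
    if layer in (1, 2, 5, 30):
        # Divide out the sqrt(degree) factor; identical rows mean the nodes look the same
        rescaled = H / np.sqrt(deg)[:, None]
        spread = (rescaled.max(axis=0) - rescaled.min(axis=0)).max()
        print(f"after {layer:2d} layers, largest gap between nodes: {spread:.6f}")
```

For this graph the second-largest eigenvalue magnitude of the normalized matrix is 0.5, so the gap roughly halves with every layer.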
In practice, 2–3 layers is the sweet spot for most GNNs. What I find most interesting is that, beyond the over-smoothing argument above, there's no complete consensus on why such shallow networks capture enough of the graph structure to work well on most tasks. The standard explanation is that many real-world graphs have small diameters, so two hops already reach most of the relevant context.
Mitigations exist. Residual connections (add the input of each layer to the output, like ResNet) help preserve each node's individual identity. DropEdge randomly removes edges during training, slowing the mixing. JumpingKnowledge networks combine the representations from all layers (by concatenation, max-pooling, or an LSTM-based attention), letting the model choose which depth is most useful; in the attention variant that choice can differ per node. PairNorm explicitly pushes node representations apart during normalization.
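Two of those mitigations fit in a few lines of NumPy: a residual GCN layer and DropEdge. This is a sketch under stated assumptions (a hypothetical six-node ring, random weights, a drop rate of 0.3 picked arbitrarily), not a tuned recipe. In a real training loop you would draw a fresh DropEdge sample every epoch and use the full graph at evaluation time.

```python
import numpy as np

rng = np.random.default_rng(0)

def drop_edge(edges, p):
    # DropEdge: independently drop each edge with probability p
    keep = rng.random(len(edges)) >= p
    return [e for e, k in zip(edges, keep) if k]

def normalized_adjacency(edges, n):
    A = np.eye(n)                                 # start from self-loops
    for u, v in edges:
        A[u, v] = A[v, u] = 1.0
    deg = A.sum(axis=1)
    return A / np.sqrt(np.outer(deg, deg))        # symmetric GCN normalization

def residual_gcn_layer(H, A_hat, W):
    # Residual connection: add the input back so each node keeps some of its own identity
    return H + np.maximum(0.0, A_hat @ H @ W)

n, d, num_layers = 6, 4, 8
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]   # hypothetical 6-node ring
H = rng.normal(size=(n, d))
Ws = [rng.normal(size=(d, d)) * 0.1 for _ in range(num_layers)]  # stand-ins for learned weights

A_hat = normalized_adjacency(drop_edge(edges, p=0.3), n)   # one DropEdge sample for this epoch
for W in Ws:
    H = residual_gcn_layer(H, A_hat, W)
print(H.shape)  # (6, 4)
```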
None of these fully solve the problem. If you genuinely need long-range dependencies — information that has to travel ten or twenty hops — the emerging approach is Graph Transformers, which allow global attention between all nodes, sidestepping the layer-by-layer propagation bottleneck. But that's a different chapter.
What GNNs can't see — the Weisfeiler-Lehman wall
There's a deeper limitation that took me by surprise. It's not about depth — it's about what message-passing GNNs are fundamentally capable of distinguishing.
In 1968, mathematicians Boris Weisfeiler and Andrei Lehman developed a test to check if two graphs are structurally identical (isomorphic). The test, now called the 1-WL test (or color refinement), works like this: give every node an initial "color" based on its features. Then, iteratively update each node's color based on the multiset of its neighbors' colors. If, after some rounds, the two graphs have different distributions of colors, they're definitely not isomorphic. If the distributions match, they might be isomorphic — the test can't always tell.
Here's the punchline: Xu et al. (2019) proved that standard message-passing GNNs are at most as powerful as the 1-WL test. The iterative "gather neighbors, update" loop of a GNN has the same structure as color refinement, with a learned function standing in for the recoloring step. Any pair of graphs that fools the 1-WL test will also fool any message-passing GNN.
What kinds of graphs fool it? Highly symmetric ones. Two different six-node graphs where every node has exactly the same number of neighbors can be indistinguishable to 1-WL (and therefore to GNNs), even though a human can see they're structured differently. For most practical applications — molecules, social networks, recommendations — this ceiling rarely matters. But for certain molecular structures where subtle symmetry differences affect chemical properties, it's a real concern.
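Here is color refinement in plain Python as a minimal sketch, assuming unlabeled graphs given as adjacency dicts (every node starts with the same color). It refines both graphs with one shared palette so their colors can be compared, and it reproduces the classic failure: a six-node ring versus two separate triangles, where every node has exactly two neighbors.

```python
from collections import Counter

def wl_distinguishes(g1, g2, rounds=3):
    """Return True if 1-WL color refinement proves g1 and g2 are NOT isomorphic.

    Graphs are dicts mapping node -> list of neighbors. Every node starts with
    the same color, and both graphs share one palette so colors are comparable.
    """
    colors = [{v: 0 for v in g} for g in (g1, g2)]
    for _ in range(rounds):
        # Signature: own color plus the sorted multiset of neighbor colors
        sigs = [{v: (c[v], tuple(sorted(c[u] for u in g[v]))) for v in g}
                for g, c in zip((g1, g2), colors)]
        all_sigs = sorted(set(sigs[0].values()) | set(sigs[1].values()))
        palette = {s: i for i, s in enumerate(all_sigs)}
        colors = [{v: palette[s[v]] for v in s} for s in sigs]
        if Counter(colors[0].values()) != Counter(colors[1].values()):
            return True   # different color histograms: definitely not isomorphic
    return False          # inconclusive: the graphs might be isomorphic

def cycle(n, offset=0):
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}

hexagon = cycle(6)                                    # one 6-node ring
two_triangles = {**cycle(3), **cycle(3, offset=3)}    # two separate 3-node rings
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

print(wl_distinguishes(hexagon, two_triangles))  # False: every node has 2 neighbors
print(wl_distinguishes(hexagon, path))           # True: the path's endpoints have 1 neighbor
```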
The same paper introduced GIN (Graph Isomorphism Network), a GNN architecture specifically designed to be maximally powerful — matching the 1-WL test exactly. The key design choice is using sum aggregation (not mean or max) combined with an injective MLP update function. GCN and GraphSAGE with mean aggregation are strictly less powerful because mean and max can map different multisets to the same output.
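A quick NumPy check of the aggregation claim, assuming one-hot node "colors" as features: sum separates a neighborhood of {red, blue} from {red, red, blue, blue}, while mean and max return identical outputs. The `gin_update` function underneath sketches the GIN update rule from the paper, MLP((1 + ε)·h_v + Σ h_u), with a two-layer MLP whose weights are random stand-ins.

```python
import numpy as np

RED, BLUE = np.array([1.0, 0.0]), np.array([0.0, 1.0])   # one-hot node colors
neighborhood_1 = np.stack([RED, BLUE])                    # {red, blue}
neighborhood_2 = np.stack([RED, RED, BLUE, BLUE])         # {red, red, blue, blue}

results = {}
for name, agg in [("sum", np.sum), ("mean", np.mean), ("max", np.max)]:
    out_1, out_2 = agg(neighborhood_1, axis=0), agg(neighborhood_2, axis=0)
    results[name] = not np.allclose(out_1, out_2)
    print(f"{name:4s} distinguishes them: {results[name]}")
# sum  distinguishes them: True
# mean distinguishes them: False
# max  distinguishes them: False

def gin_update(h_v, neighbor_h, W1, W2, eps=0.0):
    # GIN: MLP((1 + eps) * own features + SUM of neighbor features); eps may be learned or fixed
    z = (1.0 + eps) * h_v + neighbor_h.sum(axis=0)
    return np.maximum(0.0, z @ W1) @ W2                   # two-layer MLP with ReLU

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(2, 8)), rng.normal(size=(8, 2))  # stand-ins for learned weights
print(gin_update(RED, neighborhood_2, W1, W2).shape)        # (2,)
```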
Getting beyond the 1-WL barrier requires going beyond standard message passing — higher-order GNNs that consider groups of nodes together, or architectures that incorporate subgraph structures. These exist but are computationally expensive. For most practitioners, knowing the ceiling exists and reaching for GIN when expressiveness matters is enough.
Where GNNs live in the real world
The theory is nice, but where do GNNs actually earn their keep?
Molecules. This is where GNNs arguably have the most natural fit. An atom is a node, a bond is an edge, and the task is predicting molecular properties — toxicity, solubility, binding affinity, reaction outcomes. The graph structure isn't a convenient abstraction here; it's literally what the molecule is. Drug discovery pipelines at pharmaceutical companies and startups use GNNs for virtual screening — narrowing down millions of candidate compounds to a handful worth synthesizing and testing in a lab. GNNs have also been used for retrosynthesis — given a target molecule, figuring out what reactions would produce it.
Social networks. Fraud detection is a major application. Fraudulent accounts tend to cluster together (they follow each other, interact with the same targets), and GNNs are excellent at detecting these community patterns. Pinterest's PinSage system (built on GraphSAGE) generates pin embeddings from a graph with billions of nodes, linking pins to the boards they are saved to. Influence prediction, figuring out which users are most likely to spread information, is another natural fit.
Recommendation systems. At scale, recommendations are a graph problem. Users and items are nodes. Interactions (purchases, clicks, ratings) are edges. GNNs propagate preference signals through this bipartite graph. LightGCN, a simplified GCN variant designed specifically for recommendations, showed that stripping away feature transformations and nonlinearities and keeping only the neighborhood aggregation actually improves recommendation quality. Sometimes simpler is better.
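Here is a NumPy sketch of that stripped-down propagation on a made-up 3-user, 4-item interaction matrix, following my reading of the LightGCN paper: symmetric normalization over the user-item bipartite graph without self-loops, no weight matrices or nonlinearities, and a final embedding that averages all layers. Training (the paper optimizes the initial embeddings with a BPR ranking loss) is left out.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical interactions: 3 users x 4 items (1 = clicked / purchased)
R = np.array([[1, 1, 0, 0],
              [0, 1, 1, 0],
              [0, 0, 1, 1]], dtype=float)
n_users, n_items = R.shape

# Bipartite adjacency over users + items, symmetrically normalized (no self-loops)
A = np.block([[np.zeros((n_users, n_users)), R],
              [R.T, np.zeros((n_items, n_items))]])
deg = A.sum(axis=1)
A_hat = A / np.sqrt(np.outer(deg, deg))

E0 = rng.normal(size=(n_users + n_items, 8)) * 0.1   # the only trainable parameters
layers = [E0]
for _ in range(3):
    layers.append(A_hat @ layers[-1])   # no weight matrix, no nonlinearity
E = np.mean(layers, axis=0)             # average the embeddings from every layer

users, items = E[:n_users], E[n_users:]
scores = users @ items.T                # dot-product preference scores, shape (3, 4)
print(scores.round(3))
```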
Knowledge graphs. A knowledge graph stores facts as (entity, relation, entity) triples — (Berlin, capital_of, Germany). The graph is massive and incomplete. GNNs help with link prediction on these graphs: given existing facts, predict missing ones. R-GCN (Relational GCN) was one of the first architectures to handle multiple relation types. Knowledge graph completion powers question answering systems, search engines, and personal assistants.
Beyond these four, GNNs show up in traffic prediction (road networks), combinatorial optimization (solving graph-structured problems like TSP), program analysis (code as a graph of functions and calls), and physics simulation (particles as nodes, interactions as edges). The common thread: whenever the relationships between entities carry as much signal as the entities themselves, GNNs are worth considering.
Tools of the trade — PyG and DGL
Two frameworks dominate GNN development: PyTorch Geometric (PyG) and the Deep Graph Library (DGL).
PyG feels like home if you're already a PyTorch user. The API mirrors PyTorch conventions — you define layers, write a forward method, run a training loop. Graphs are stored as an edge_index tensor (a 2×num_edges array listing source-target pairs) plus a node feature matrix x. Everything is sparse by default. Swapping GCN for GAT or GraphSAGE is often a one-line change.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
# Cora: 2708 papers, 10556 citation edges, 7 subject classes
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]
class TwoLayerGCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
    def forward(self, x, edge_index):
        # First message passing layer + nonlinearity
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        # Second message passing layer (raw logits for classification)
        return self.conv2(x, edge_index)
model = TwoLayerGCN(dataset.num_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Loss only on labeled nodes: semi-supervised setup
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
The training loop looks exactly like any other PyTorch classification task. The only GNN-specific detail is that forward takes both node features and edge connectivity, and the loss is computed only on the nodes that have labels (the train_mask). The unlabeled nodes still participate in message passing — their features flow through the graph and influence the labeled nodes' representations — but they don't contribute to the loss.
For graphs too large for full-batch training on a GPU, PyG provides NeighborLoader, which implements GraphSAGE-style sampling:
from torch_geometric.loader import NeighborLoader
# Sample 15 first-hop, 10 second-hop neighbors per target node
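loader = NeighborLoader(
    data, num_neighbors=[15, 10], batch_size=128, shuffle=True,
    input_nodes=data.train_mask,  # seed batches only from labeled training nodes
)
for batch in loader:
    model.train()
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # The first batch.batch_size rows are the seed nodes; the rest are sampled neighbors
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()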
For batching multiple small graphs (common in molecular tasks), PyG stacks them into one large disconnected graph and tracks which node belongs to which graph via a batch vector. This lets you use global pooling — global_mean_pool(x, batch) — to get one vector per graph.
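To demystify the batch vector, here is a NumPy sketch of the computation that global_mean_pool performs: average the node embeddings that share a batch index. It is a re-implementation for illustration, not PyG's actual code, and the two tiny graphs are made up.

```python
import numpy as np

# Two small graphs stacked into one disconnected graph:
# nodes 0-2 belong to graph 0, nodes 3-4 belong to graph 1
x = np.array([[1.0, 0.0],
              [3.0, 2.0],
              [2.0, 4.0],
              [10.0, 10.0],
              [0.0, 2.0]])
batch = np.array([0, 0, 0, 1, 1])

def mean_pool_by_batch(x, batch):
    num_graphs = batch.max() + 1
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)                    # scatter-add each node into its graph's row
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts

pooled = mean_pool_by_batch(x, batch)
print(pooled)  # graph 0: mean of rows 0-2 = [2, 2]; graph 1: mean of rows 3-4 = [5, 6]
```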
DGL (backed by AWS) takes a different approach. Its message-passing API is more explicit — you define message functions and reduce functions that operate on edges and nodes. This is more verbose but gives finer control, which matters for heterogeneous graphs (multiple node and edge types) and for distributed training across multiple machines. If you're building a production system at scale — think billions of edges, multiple GPUs, real-time updates — DGL is often the better choice. For research, prototyping, and standard benchmarks, PyG is where most practitioners start.
Wrap-up
If you're still with me, thank you. I hope it was worth it.
We started with three friends and their movie ratings — a graph so tiny you could draw it on a napkin. We turned it into an adjacency matrix and a feature matrix, and watched what happens when you multiply them together. From that single observation — matrix multiplication gathers neighbor information — we built up the entire message passing framework. GCN emerged as the simplest instantiation: average your neighbors, transform, apply ReLU. GAT added the ability to pay attention to the right neighbors. GraphSAGE made it all work at scale by sampling. We saw the three task types (node, edge, graph), the over-smoothing trap that limits depth, and the Weisfeiler-Lehman ceiling that limits expressiveness.
My hope is that the next time you encounter a dataset where the relationships between entities carry as much signal as the entities themselves — a social graph, a molecule, a knowledge base — instead of awkwardly flattening it into a feature table and losing the structure, you'll reach for a GNN, having a pretty good mental model of what's happening under the hood when those messages start flowing.
Resources and credits
Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks" (2017) — the O.G. paper that made GNNs practical. Remarkably readable for an academic paper.
Hamilton, Ying & Leskovec, "Inductive Representation Learning on Large Graphs" (2017) — the GraphSAGE paper. The sampling idea that made billion-node GNNs possible.
Xu et al., "How Powerful are Graph Neural Networks?" (2019) — the paper that connected GNNs to the Weisfeiler-Lehman test and introduced GIN. Changed how the field thinks about expressiveness.
Veličković et al., "Graph Attention Networks" (2018) — attention on graphs. The GATv2 follow-up by Brody et al. (2021) is the version you want to use in practice.
Stanford CS224W (Jure Leskovec) — wildly helpful course materials, freely available. The lecture videos are some of the clearest explanations of graph ML I've encountered.
PyTorch Geometric documentation and tutorials — the fastest way to go from "I understand the concept" to "I have a working model." The Colab notebooks are particularly useful for getting started without any setup.