import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")
1. Message Passing Framework
Graph Representation
\(G = (V, E)\) with node features \(X \in \mathbb{R}^{n \times d}\)
Adjacency matrix \(A \in \{0,1\}^{n \times n}\)
Message Passing
Aggregate: Combine neighbor features
Update: Combine with own features
Reference Materials:
graph_cnn.pdf - Graph CNN
2. Graph Convolutional Network (GCN)
Spectral Motivation
Laplacian: \(L = D - A\)
Normalized: \(\tilde{L} = I - D^{-1/2}AD^{-1/2}\)
GCN Layer
\(H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)\)
where \(\tilde{A} = A + I\) (add self-loops), \(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\)
Advanced Message Passing Theory
1. Theoretical Foundations
Permutation Invariance Property:
Graph neural networks must respect node ordering: relabeling node indices shouldn't change the output structure.
For node embeddings (equivariance): \(f(PX, PAP^T) = Pf(X, A)\) where \(P\) is a permutation matrix.
For graph-level predictions (invariance): \(f(PX, PAP^T) = f(X, A)\) (permutation-equivariant node embeddings → a permutation-invariant graph readout)
Proof: Message passing with symmetric aggregation (sum, mean, max) preserves this property:
Aggregation: \(\text{AGG}(\{h_u : u \in \mathcal{N}(v)\})\) is order-independent
Update: Applied identically to all nodes
Therefore: \(h_{Pv}^{(l+1)} = \text{UPDATE}(h_{Pv}^{(l)}, \text{AGG}(\{h_{Pu}^{(l)}\}))\) for any permutation \(P\)
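This property is easy to verify numerically. The sketch below (layer, graph, and permutation are all illustrative, not from the notebook) checks \(f(PX, PAP^T) = Pf(X, A)\) for one round of mean aggregation:

```python
import torch

def mean_agg_layer(X, A):
    """One round of mean aggregation: each node averages its neighbors."""
    deg = A.sum(1, keepdim=True).clamp(min=1)  # avoid division by zero for isolated nodes
    return (A @ X) / deg

torch.manual_seed(0)
n = 6
X = torch.randn(n, 4)
A = (torch.rand(n, n) < 0.4).float()
A = ((A + A.T) > 0).float()      # symmetric, binary adjacency
A.fill_diagonal_(0)

# Random permutation matrix P
perm = torch.randperm(n)
P = torch.eye(n)[perm]

# Equivariance: f(PX, P A P^T) should equal P f(X, A)
lhs = mean_agg_layer(P @ X, P @ A @ P.T)
rhs = P @ mean_agg_layer(X, A)
print(torch.allclose(lhs, rhs, atol=1e-6))  # True
```

The check passes exactly because mean aggregation only sees the multiset of neighbor features, never their order.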
Weisfeiler-Lehman Isomorphism Test:
The expressive power of message passing GNNs is bounded by the 1-WL test:
Initialize: \(h_v^{(0)} = x_v\) (node features)
Iterate: \(h_v^{(l+1)} = \text{HASH}(h_v^{(l)}, \{\{h_u^{(l)} : u \in \mathcal{N}(v)\}\})\)
Two graphs are 1-WL equivalent if they produce the same multiset of node labels
Limitations: message-passing GNNs cannot distinguish certain graph structures (e.g., non-isomorphic regular graphs of the same degree).
Higher-order GNNs: the k-WL test operates on tuples of nodes → more expressive but computationally expensive.
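The 1-WL iteration above fits in a few lines. As an illustration of the limitation just mentioned (a sketch with uniform initial colours), a 6-cycle and two disjoint triangles, both 2-regular and non-isomorphic, end up with identical colour multisets:

```python
def wl_refine(adj, rounds=3):
    """1-WL colour refinement; adj maps node -> list of neighbours."""
    colors = {v: 0 for v in adj}  # uniform initial colours (no node features)
    for _ in range(rounds):
        # New colour = (own colour, sorted multiset of neighbour colours), re-indexed
        signatures = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
                      for v in adj}
        palette = {sig: i for i, sig in enumerate(sorted(set(signatures.values())))}
        colors = {v: palette[signatures[v]] for v in adj}
    return sorted(colors.values())  # colour multiset: a graph invariant

# Two non-isomorphic 2-regular graphs: a 6-cycle vs. two disjoint triangles
cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
             3: [4, 5], 4: [3, 5], 5: [3, 4]}
print(wl_refine(cycle6) == wl_refine(triangles))  # True: 1-WL cannot tell them apart
```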
2. Spectral Graph Theory Foundations
Graph Laplacian:
Combinatorial Laplacian: \(L = D - A\)
Properties:
Symmetric positive semi-definite: \(x^T L x = \frac{1}{2}\sum_{i,j} A_{ij}(x_i - x_j)^2 \geq 0\)
Eigenvalues: \(0 = \lambda_1 \leq \lambda_2 \leq ... \leq \lambda_n \leq 2d_{\max}\)
Smallest eigenvalue \(\lambda_1 = 0\) with eigenvector \(\mathbf{1}\) (constant)
Number of zero eigenvalues = number of connected components
Normalized Laplacian: \(\mathcal{L} = I - D^{-1/2} A D^{-1/2}\)
Eigenvalues: \(0 \leq \lambda_i \leq 2\) (better numerical properties)
Graph Fourier Transform:
For signal \(x \in \mathbb{R}^n\) on graph nodes:
Forward: \(\hat{x} = U^T x\) where \(U\) contains eigenvectors of \(L\)
Inverse: \(x = U\hat{x}\)
Graph convolution in spectral domain: \(g *_G x = U\big((U^T g) \odot (U^T x)\big)\)
where \(\odot\) is element-wise product.
Problem: Computing eigenvectors is \(O(n^3)\), impractical for large graphs.
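As a concrete toy illustration of the transform and its cost, the sketch below computes the Fourier basis of a 4-node path graph by full eigendecomposition, which is exactly the \(O(n^3)\) step that the approximations in the next section avoid:

```python
import numpy as np

# Tiny path graph on 4 nodes
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
D = np.diag(A.sum(1))
L = D - A

# Eigendecomposition gives the graph Fourier basis U (O(n^3) in general)
eigvals, U = np.linalg.eigh(L)

x = np.array([1.0, 2.0, 3.0, 4.0])   # a signal on the nodes
x_hat = U.T @ x                       # forward graph Fourier transform
x_rec = U @ x_hat                     # inverse transform
print(np.allclose(x, x_rec))          # True: U is orthonormal
print(abs(eigvals[0]) < 1e-8)         # True: lambda_1 = 0 (constant eigenvector)
```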
3. Spectral Convolution Approximations
ChebNet (Defferrard et al., 2016):
Approximate spectral filters with a truncated Chebyshev expansion:
\(g_\theta(\Lambda) \approx \sum_{k=0}^{K-1} \theta_k T_k(\tilde{\Lambda})\)
where:
\(\Lambda\): diagonal matrix of eigenvalues
\(T_k\): Chebyshev polynomial of order \(k\)
\(\tilde{\Lambda} = \frac{2}{\lambda_{\max}}\Lambda - I\) (rescaled to \([-1, 1]\))
Recursive computation: \(T_k(x) = 2xT_{k-1}(x) - T_{k-2}(x)\) with \(T_0(x)=1, T_1(x)=x\)
This avoids eigendecomposition! Apply the polynomials to the Laplacian directly:
\(g_\theta * x \approx \sum_{k=0}^{K-1} \theta_k T_k(\tilde{L})\, x\)
where \(\tilde{L} = \frac{2}{\lambda_{\max}}L - I\)
Complexity: \(O(K|E|)\) where \(K\) is filter order, \(|E|\) is number of edges.
Localization: \(K\)-order filter uses \(K\)-hop neighborhood (spatially localized).
GCN as First-Order Approximation:
Kipf & Welling (2017) showed GCN is a special case with \(K=1, \lambda_{\max}=2\):
\(g_\theta * x \approx \theta_0 x + \theta_1 (L - I)x = \theta_0 x - \theta_1 D^{-1/2} A D^{-1/2} x\)
Simplification: Set \(\theta = \theta_0 = -\theta_1\) to reduce parameters:
\(g_\theta * x \approx \theta\,(I + D^{-1/2} A D^{-1/2})\, x\)
Renormalization trick (add self-loops):
\(I + D^{-1/2} A D^{-1/2} \;\rightarrow\; \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}, \qquad \tilde{A} = A + I\)
Why it works:
\(I + D^{-1/2}AD^{-1/2}\) has eigenvalues in \([0, 2]\); the renormalized \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) has eigenvalues in \((-1, 1]\)
Prevents numerical instability from repeated multiplication
Self-loops preserve a node's own features
4. GraphSAGE: Inductive Learning
Problem with GCN: Transductive (requires full graph during training).
GraphSAGE Solution (Hamilton et al., 2017):
Sample and aggregate from fixed-size neighborhoods:
\(h_v^{(l+1)} = \sigma\left(W^{(l)} \cdot \big[h_v^{(l)} \,\|\, \text{AGG}(\{h_u^{(l)} : u \in \mathcal{S}(\mathcal{N}(v))\})\big]\right)\)
where \(\mathcal{S}(\mathcal{N}(v))\) is a sampled subset of neighbors.
Aggregation Functions:
Mean: \(\text{AGG} = \frac{1}{|\mathcal{S}|}\sum_{u \in \mathcal{S}} h_u^{(l)}\)
Simple, symmetric
GCN is special case with full neighborhood
LSTM: \(\text{AGG} = \text{LSTM}(\{h_u^{(l)}\})\)
More expressive
Requires random permutation (not truly permutation invariant)
Pooling: \(\text{AGG} = \max(\{\sigma(W_{\text{pool}}h_u^{(l)} + b) : u \in \mathcal{S}\})\)
Element-wise max after MLP
Captures diverse set of features
Neighbor Sampling:
For \(L\)-layer model, uniformly sample \(S_l\) neighbors at each layer:
Computational complexity: \(O(S_1 \cdot S_2 \cdots S_L)\) per node
Fixed computational budget (vs. \(O(d^L)\) for full neighborhood)
Training:
Unsupervised: Graph-based loss encourages nearby nodes to have similar embeddings:
\(L = -\log\big(\sigma(z_u^\top z_v)\big) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)}\big[\log\sigma(-z_u^\top z_{v_n})\big]\)
where \(v\) is a node co-occurring with \(u\) on a random walk and \(v_n\) is a negative sample.
Supervised: Standard cross-entropy for node classification.
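The unsupervised loss above can be sketched directly (the embeddings `z_u`, `z_v`, `z_neg` below are illustrative random vectors, not notebook variables):

```python
import torch
import torch.nn.functional as F

def sage_unsup_loss(z_u, z_v, z_neg):
    """Unsupervised GraphSAGE loss for one positive pair (u, v)
    and Q negative samples (rows of z_neg)."""
    pos = -F.logsigmoid((z_u * z_v).sum())       # pull co-occurring nodes together
    neg = -F.logsigmoid(-(z_neg @ z_u)).mean()   # push negatives apart (Monte-Carlo expectation)
    Q = z_neg.size(0)
    return pos + Q * neg

torch.manual_seed(0)
z_u, z_v = torch.randn(8), torch.randn(8)
z_neg = torch.randn(5, 8)
print(sage_unsup_loss(z_u, z_v, z_neg).item())   # a positive scalar
```

Both terms are non-negative (each is a negative log-sigmoid), so the loss is minimized by making positive pairs similar and negative pairs dissimilar.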
5. Graph Attention Networks (GAT)
Motivation: Not all neighbors are equally important.
Attention Mechanism (Veličković et al., 2018):
Compute attention coefficients:
\(e_{ij} = a(Wh_i, Wh_j) = \text{LeakyReLU}(\mathbf{a}^T [Wh_i \| Wh_j])\)
where \(\mathbf{a} \in \mathbb{R}^{2F'}\) is a learnable attention vector.
Normalization:
\(\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}\)
Aggregation:
\(h_i^{(l+1)} = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j^{(l)}\Big)\)
Multi-Head Attention:
Concatenate \(K\) attention heads:
\(h_i^{(l+1)} = \big\Vert_{k=1}^{K}\, \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j^{(l)}\Big)\)
For the final layer, average instead:
\(h_i^{(L)} = \sigma\Big(\frac{1}{K}\sum_{k=1}^{K}\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j^{(L-1)}\Big)\)
Properties:
Computationally efficient: \(O(|V|FF' + |E|F')\) per layer
Applicable to different graph structures (no normalization by degree needed)
Implicitly assigns different importance weights to different neighbors
Can be applied inductively to unseen graphs
Comparison to Self-Attention in Transformers:
GAT: Attention over graph neighbors (sparse)
Transformer: Attention over all positions (dense)
GAT complexity: \(O(|E|)\) vs. Transformer: \(O(n^2)\)
6. Graph Pooling Strategies
Problem: Generate graph-level representation from node embeddings.
1. Global Pooling:
Simple aggregation over all nodes:
Mean: \(h_G = \frac{1}{|V|}\sum_{v \in V} h_v\)
Max: \(h_G = \max_{v \in V} h_v\) (element-wise)
Sum: \(h_G = \sum_{v \in V} h_v\)
Permutation invariant but loses structural information.
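The three global pooling operators are one-liners in PyTorch. A quick sketch (illustrative embeddings) that also confirms permutation invariance:

```python
import torch

torch.manual_seed(0)
H = torch.randn(10, 16)        # node embeddings for a 10-node graph

h_mean = H.mean(dim=0)         # mean pooling
h_max = H.max(dim=0).values    # element-wise max pooling
h_sum = H.sum(dim=0)           # sum pooling (used by GIN: injective on multisets)

# All three are invariant to node ordering:
perm = torch.randperm(10)
print(torch.allclose(h_sum, H[perm].sum(dim=0), atol=1e-5))  # True
print(torch.equal(h_max, H[perm].max(dim=0).values))          # True
```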
2. Set2Set (Vinyals et al., 2015):
Use LSTM with attention to aggregate:
Read graph multiple times, final state is graph embedding.
3. DiffPool (Ying et al., 2018):
Learnable hierarchical pooling:
Cluster assignment matrix \(S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}\): row \(i\) is node \(i\)'s soft assignment over the \(n_{l+1}\) clusters.
Coarsened graph:
\(X^{(l+1)} = S^{(l)\top} Z^{(l)}, \qquad A^{(l+1)} = S^{(l)\top} A^{(l)} S^{(l)}\)
where \(Z^{(l)} = \text{GNN}_{embed}^{(l)}(A^{(l)}, X^{(l)})\) are node embeddings.
Auxiliary losses:
Link prediction: \(L_{LP} = \|A - SS^T\|_F\) (preserve graph structure)
Entropy: \(L_E = -\frac{1}{n}\sum_{i,j} S_{ij}\log(S_{ij})\) (encourage confident assignments)
4. TopK Pooling (Gao & Ji, 2019):
Score-based selection:
\(y = X\mathbf{p}/\|\mathbf{p}\|, \qquad i = \text{top}_k(y), \qquad X' = \big(X \odot \tanh(y)\big)_i\)
Keep the top \(k\) nodes under the learned projection \(\mathbf{p}\), and gate the kept features by their scores.
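A sketch of the selection step (the projection `p` would be learned; here it is random, and the induced subgraph adjacency `A[idx][:, idx]` is omitted for brevity):

```python
import torch

def topk_pool(X, p, k):
    """TopK pooling sketch: score nodes by projection onto p,
    keep the top-k, and gate kept features by their tanh-squashed score."""
    y = (X @ p) / p.norm()               # per-node scores
    idx = torch.topk(y, k).indices       # indices of the k highest-scoring nodes
    return X[idx] * torch.tanh(y[idx]).unsqueeze(-1), idx

torch.manual_seed(0)
X = torch.randn(10, 8)
p = torch.randn(8)                       # stand-in for the learned projection vector
X_pooled, idx = topk_pool(X, p, k=4)
print(X_pooled.shape)                    # torch.Size([4, 8])
```

The tanh gating keeps the scoring vector `p` on the gradient path, which is what makes the (otherwise discrete) selection trainable.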
7. Training Techniques and Considerations
Over-smoothing Problem:
With many layers, node embeddings become indistinguishable: \(h_v^{(L)} \approx h_u^{(L)}\) for all \(u, v\) as \(L \to \infty\).
Causes:
Repeated averaging of neighbor features
Powers of the normalized adjacency converge toward its dominant eigenvector (a low-pass filter)
Solutions:
Residual connections: \(h^{(l+1)} = h^{(l)} + \text{GNN}^{(l)}(h^{(l)}, A)\)
Jumping knowledge: Concatenate representations from all layers
PairNorm: Normalize to maintain distance between nodes
Depth-adaptive: Learn to stop propagation dynamically
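The collapse is easy to observe by repeatedly applying the normalized adjacency (a sketch on a random graph; mean pairwise distance is just one convenient way to measure smoothing):

```python
import torch

torch.manual_seed(0)
n = 20
A = (torch.rand(n, n) < 0.3).float()
A = ((A + A.T) > 0).float()      # symmetric, binary adjacency
A.fill_diagonal_(0)

# GCN-style symmetrically normalized adjacency with self-loops
A_hat = A + torch.eye(n)
d_inv_sqrt = A_hat.sum(1).pow(-0.5)
A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

H = torch.randn(n, 8)
spreads = {}
for k in [1, 2, 8, 32]:
    Hk = torch.matrix_power(A_norm, k) @ H
    spreads[k] = torch.cdist(Hk, Hk).mean().item()  # mean pairwise distance
    print(f"k={k:2d}  mean pairwise distance = {spreads[k]:.4f}")
print(spreads[32] < spreads[1])  # True: embeddings collapse toward each other
```

This is pure propagation with no learned weights or non-linearities, which isolates the smoothing effect that residual connections and the other remedies above counteract.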
Graph Normalization:
BatchNorm: Normalize across batch; doesn't work well for graphs
GraphNorm: Normalize each graph separately: \(h_i = \frac{h_i - \mu_G}{\sqrt{\sigma_G^2 + \epsilon}}\)
Layer normalization: Normalize each node's features
Mini-batch Training:
Neighbor sampling (GraphSAGE):
Sample fixed number of neighbors per layer
Trade-off: accuracy vs. computational cost
Cluster-GCN:
Partition graph into clusters
Sample cluster as mini-batch (preserve local structure)
8. Comparison Table

| Method | Aggregation | Complexity | Inductive | Attention | Year |
|---|---|---|---|---|---|
| GCN | Mean (spectral) | \(O(|E|F)\) | ✗ | ✗ | 2017 |
| GraphSAGE | Mean/Max/LSTM | \(O(S^L F)\) | ✓ | ✗ | 2017 |
| GAT | Weighted sum | \(O(|E|F)\) | ✓ | ✓ | 2018 |
| GIN | Sum | \(O(|E|F)\) | ✓ | ✗ | 2019 |
| ChebNet | Chebyshev poly | \(O(K|E|F)\) | ✗ | ✗ | 2016 |
When to use:
GCN: Small-medium graphs, transductive setting, baseline
GraphSAGE: Large graphs, inductive learning, scalability
GAT: Heterogeneous graphs, need interpretability
GIN: Maximum expressiveness (1-WL equivalent)
Spectral methods: Regular graph structure, frequency analysis
9. Applications and Benchmarks
Node Classification:
Cora, CiteSeer, PubMed (citation networks)
Reddit (social network)
PPI (protein-protein interaction)
Graph Classification:
MUTAG, PROTEINS (molecular properties)
COLLAB, IMDB (social networks)
NCI1 (chemical compounds)
Link Prediction:
Knowledge graphs (FB15k, WN18)
Recommendation systems
Graph Generation:
Molecule design (ZINC, QM9)
Social network synthesis
Common challenges:
Label scarcity (semi-supervised learning)
Heterogeneous graphs (multiple node/edge types)
Dynamic graphs (temporal evolution)
Scaling to billions of nodes
class GCNLayer(nn.Module):
"""Graph Convolutional Layer."""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, X, A):
# Add self-loops
A_hat = A + torch.eye(A.size(0), device=A.device)
# Degree matrix
D_hat = torch.diag(A_hat.sum(1))
D_hat_inv_sqrt = torch.pow(D_hat, -0.5)
D_hat_inv_sqrt[torch.isinf(D_hat_inv_sqrt)] = 0
# Normalized adjacency
A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt
# Propagate
return self.linear(A_norm @ X)
class GCN(nn.Module):
"""2-layer GCN."""
def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
super().__init__()
self.gc1 = GCNLayer(in_dim, hidden_dim)
self.gc2 = GCNLayer(hidden_dim, out_dim)
self.dropout = dropout
def forward(self, X, A):
h = F.relu(self.gc1(X, A))
h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc2(h, A)
# Test
n_nodes = 10
in_dim = 5
hidden_dim = 16
out_dim = 3
X = torch.randn(n_nodes, in_dim)
A = torch.randint(0, 2, (n_nodes, n_nodes)).float()
A = ((A + A.T) > 0).float()  # Symmetrize, keep entries binary
model = GCN(in_dim, hidden_dim, out_dim)
out = model(X, A)
print(f"Output shape: {out.shape}")
3. Graph Attention Network (GAT)
Attention Mechanism
Compute attention coefficients:
\(e_{ij} = \text{LeakyReLU}(\mathbf{a}^T [Wh_i \| Wh_j]), \qquad \alpha_{ij} = \text{softmax}_j(e_{ij})\)
Update
\(h_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\Big)\)
Multi-head: Average or concatenate multiple attention heads.
class GATLayer(nn.Module):
"""Graph Attention Layer."""
def __init__(self, in_features, out_features, n_heads=1, dropout=0.6, alpha=0.2):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.n_heads = n_heads
self.dropout = dropout
self.W = nn.Parameter(torch.zeros(n_heads, in_features, out_features))
self.a = nn.Parameter(torch.zeros(n_heads, 2 * out_features, 1))
self.leakyrelu = nn.LeakyReLU(alpha)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.W)
nn.init.xavier_uniform_(self.a)
def forward(self, X, A):
N = X.size(0)
# Linear transformation for each head
h = torch.matmul(X.unsqueeze(0), self.W) # (n_heads, N, out_features)
# Attention mechanism
h_i = h.unsqueeze(2).repeat(1, 1, N, 1) # (n_heads, N, N, out_features)
h_j = h.unsqueeze(1).repeat(1, N, 1, 1)
# Concatenate
h_cat = torch.cat([h_i, h_j], dim=-1) # (n_heads, N, N, 2*out_features)
# Compute attention
e = self.leakyrelu(torch.matmul(h_cat, self.a).squeeze(-1)) # (n_heads, N, N)
# Mask non-neighbors
zero_vec = -9e15 * torch.ones_like(e)
attention = torch.where(A.unsqueeze(0) > 0, e, zero_vec)
# Normalize
attention = F.softmax(attention, dim=-1)
attention = F.dropout(attention, self.dropout, training=self.training)
# Aggregate
h_prime = torch.matmul(attention, h) # (n_heads, N, out_features)
# Average heads
return h_prime.mean(dim=0)
class GAT(nn.Module):
"""2-layer GAT."""
def __init__(self, in_dim, hidden_dim, out_dim, n_heads=4, dropout=0.6):
super().__init__()
self.gat1 = GATLayer(in_dim, hidden_dim, n_heads, dropout)
self.gat2 = GATLayer(hidden_dim, out_dim, 1, dropout)
def forward(self, X, A):
h = F.elu(self.gat1(X, A))
return self.gat2(h, A)
# Test
gat = GAT(in_dim, hidden_dim, out_dim, n_heads=4)
out = gat(X, A)
print(f"GAT output: {out.shape}")
"""
Advanced GNN Implementations: Spectral Methods, GraphSAGE, and Graph Pooling
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import eigh
from collections import defaultdict
# ============================================================================
# 1. Spectral Graph Convolution (ChebNet)
# ============================================================================
class ChebConv(nn.Module):
"""
Chebyshev Spectral Graph Convolution.
Approximates spectral filters using Chebyshev polynomials.
"""
def __init__(self, in_features, out_features, K=3):
super().__init__()
self.K = K
self.linear = nn.Linear(in_features * K, out_features)
def chebyshev_basis(self, L, K):
"""
Compute Chebyshev polynomial basis T_k(L).
Recursive: T_k(x) = 2x*T_{k-1}(x) - T_{k-2}(x)
with T_0(x) = I, T_1(x) = x
"""
N = L.shape[0]
# Rescale Laplacian to [-1, 1]
lambda_max = 2.0 # Approximation for normalized Laplacian
L_scaled = (2.0 / lambda_max) * L - torch.eye(N, device=L.device)
# Initialize
Tx = [torch.eye(N, device=L.device), L_scaled]
# Compute T_k recursively
for k in range(2, K):
Tx.append(2 * L_scaled @ Tx[-1] - Tx[-2])
return Tx[:K]
def forward(self, X, L):
"""
Args:
X: Node features (N, in_features)
L: Normalized Laplacian (N, N)
"""
# Compute Chebyshev basis
Tx = self.chebyshev_basis(L, self.K)
# Apply each basis to features
out = []
for k in range(self.K):
out.append(Tx[k] @ X)
# Concatenate and linear transform
out = torch.cat(out, dim=-1)
return self.linear(out)
def compute_normalized_laplacian(A):
"""
Compute normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
"""
# Degree matrix
D = torch.diag(A.sum(1))
D_inv_sqrt = torch.pow(D, -0.5)
D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0
# Normalized Laplacian
I = torch.eye(A.size(0), device=A.device)
L = I - D_inv_sqrt @ A @ D_inv_sqrt
return L
# Test ChebConv
print("=" * 60)
print("1. Spectral Graph Convolution (ChebNet)")
print("=" * 60)
n_nodes = 20
in_dim = 5
out_dim = 8
K = 3
X = torch.randn(n_nodes, in_dim)
A = torch.randint(0, 2, (n_nodes, n_nodes)).float()
A = ((A + A.T) > 0).float()  # Symmetrize, keep entries binary
A.fill_diagonal_(0)  # No self-loops
L = compute_normalized_laplacian(A)
cheb_conv = ChebConv(in_dim, out_dim, K=K)
out = cheb_conv(X, L)
print(f"Input shape: {X.shape}")
print(f"Laplacian eigenvalues (first 5): {torch.linalg.eigvalsh(L)[:5].numpy()}")
print(f"Output shape: {out.shape}")
print(f"Chebyshev order K: {K}")
# ============================================================================
# 2. GraphSAGE with Multiple Aggregators
# ============================================================================
class GraphSAGELayer(nn.Module):
"""
GraphSAGE Layer with different aggregation functions.
"""
def __init__(self, in_features, out_features, aggregator='mean'):
super().__init__()
self.aggregator = aggregator
if aggregator == 'mean':
self.linear = nn.Linear(in_features * 2, out_features)
elif aggregator == 'pool':
self.pool_linear = nn.Linear(in_features, in_features)
self.linear = nn.Linear(in_features * 2, out_features)
elif aggregator == 'lstm':
self.lstm = nn.LSTM(in_features, in_features, batch_first=True)
self.linear = nn.Linear(in_features * 2, out_features)
def aggregate_mean(self, X, neighbors_list):
"""Mean aggregation"""
agg = []
for neighbors in neighbors_list:
if len(neighbors) > 0:
agg.append(X[neighbors].mean(0))
else:
agg.append(torch.zeros(X.size(1), device=X.device))
return torch.stack(agg)
def aggregate_pool(self, X, neighbors_list):
"""Max pooling aggregation after MLP"""
agg = []
for neighbors in neighbors_list:
if len(neighbors) > 0:
neighbor_features = X[neighbors]
pooled = F.relu(self.pool_linear(neighbor_features))
agg.append(pooled.max(0)[0])
else:
agg.append(torch.zeros(X.size(1), device=X.device))
return torch.stack(agg)
def aggregate_lstm(self, X, neighbors_list):
"""LSTM aggregation (requires random permutation)"""
agg = []
for neighbors in neighbors_list:
if len(neighbors) > 0:
neighbor_features = X[neighbors].unsqueeze(0)
_, (h_n, _) = self.lstm(neighbor_features)
agg.append(h_n[0, 0])  # final hidden state, shape (in_features,)
else:
agg.append(torch.zeros(X.size(1), device=X.device))
return torch.stack(agg)
def forward(self, X, A, sample_size=None):
"""
Args:
X: Node features (N, in_features)
A: Adjacency matrix (N, N)
sample_size: Number of neighbors to sample (None = all)
"""
N = X.size(0)
# Build neighbor lists with optional sampling
neighbors_list = []
for i in range(N):
neighbors = A[i].nonzero(as_tuple=True)[0].tolist()
if sample_size and len(neighbors) > sample_size:
neighbors = np.random.choice(neighbors, sample_size, replace=False).tolist()
neighbors_list.append(neighbors)
# Aggregate
if self.aggregator == 'mean':
h_agg = self.aggregate_mean(X, neighbors_list)
elif self.aggregator == 'pool':
h_agg = self.aggregate_pool(X, neighbors_list)
elif self.aggregator == 'lstm':
h_agg = self.aggregate_lstm(X, neighbors_list)
# Concatenate with self features and transform
h = torch.cat([X, h_agg], dim=1)
return F.relu(self.linear(h))
# Test GraphSAGE aggregators
print("\n" + "=" * 60)
print("2. GraphSAGE Aggregators")
print("=" * 60)
aggregators = ['mean', 'pool', 'lstm']
for agg in aggregators:
sage_layer = GraphSAGELayer(in_dim, out_dim, aggregator=agg)
out = sage_layer(X, A, sample_size=5)
print(f"{agg.upper()} aggregator output shape: {out.shape}")
# ============================================================================
# 3. Graph Pooling: DiffPool
# ============================================================================
class DiffPoolLayer(nn.Module):
"""
Differentiable Graph Pooling (DiffPool).
Learns soft cluster assignments.
"""
def __init__(self, in_features, num_clusters, hidden_dim=16):
super().__init__()
# GNN for node embeddings
self.embed_gnn = GCNLayer(in_features, hidden_dim)
# GNN for cluster assignments
self.pool_gnn = GCNLayer(in_features, num_clusters)
def forward(self, X, A):
"""
Args:
X: Node features (N, in_features)
A: Adjacency matrix (N, N)
Returns:
X_pooled: Cluster features (num_clusters, hidden_dim)
A_pooled: Cluster adjacency (num_clusters, num_clusters)
link_loss: Auxiliary loss for preserving structure
entropy_loss: Auxiliary loss for confident assignments
"""
# Compute embeddings
Z = self.embed_gnn(X, A) # (N, hidden_dim)
# Compute soft assignments
S = F.softmax(self.pool_gnn(X, A), dim=1) # (N, num_clusters)
# Coarsen graph
X_pooled = S.T @ Z # (num_clusters, hidden_dim)
A_pooled = S.T @ A @ S # (num_clusters, num_clusters)
# Auxiliary losses
# Link prediction loss: ||A - SS^T||_F
link_loss = torch.norm(A - S @ S.T, p='fro')
# Entropy loss: low entropy = confident assignments, so minimize entropy directly
entropy = -(S * torch.log(S + 1e-10)).sum(1).mean()
entropy_loss = entropy
return X_pooled, A_pooled, link_loss, entropy_loss
class GCNLayer(nn.Module):
"""Basic GCN layer (reused from earlier)"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, X, A):
A_hat = A + torch.eye(A.size(0), device=A.device)
D_hat = torch.diag(A_hat.sum(1))
D_hat_inv_sqrt = torch.pow(D_hat, -0.5)
D_hat_inv_sqrt[torch.isinf(D_hat_inv_sqrt)] = 0
A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt
return self.linear(A_norm @ X)
# Test DiffPool
print("\n" + "=" * 60)
print("3. DiffPool Graph Pooling")
print("=" * 60)
num_clusters = 5
diffpool = DiffPoolLayer(in_dim, num_clusters, hidden_dim=8)
X_pooled, A_pooled, link_loss, entropy_loss = diffpool(X, A)
print(f"Original graph: {X.shape[0]} nodes")
print(f"Pooled graph: {X_pooled.shape[0]} clusters")
print(f"Pooled features shape: {X_pooled.shape}")
print(f"Pooled adjacency shape: {A_pooled.shape}")
print(f"Link prediction loss: {link_loss.item():.4f}")
print(f"Entropy loss: {entropy_loss.item():.4f}")
# ============================================================================
# 4. Visualization: Spectral Properties
# ============================================================================
print("\n" + "=" * 60)
print("4. Spectral Analysis Visualization")
print("=" * 60)
# Compute eigenvalues and eigenvectors
L_np = L.numpy()
eigenvalues, eigenvectors = eigh(L_np)
# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Plot 1: Eigenvalue spectrum
axes[0, 0].plot(eigenvalues, 'o-', markersize=6)
axes[0, 0].axhline(y=0, color='r', linestyle='--', alpha=0.3)
axes[0, 0].set_xlabel('Index')
axes[0, 0].set_ylabel('Eigenvalue')
axes[0, 0].set_title('Normalized Laplacian Eigenvalues')
axes[0, 0].grid(True, alpha=0.3)
# Plot 2: First few eigenvectors
for i in range(min(4, len(eigenvalues))):
axes[0, 1].plot(eigenvectors[:, i], label=f'λ={eigenvalues[i]:.3f}', alpha=0.7)
axes[0, 1].set_xlabel('Node index')
axes[0, 1].set_ylabel('Eigenvector value')
axes[0, 1].set_title('First 4 Eigenvectors')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
# Plot 3: Adjacency matrix heatmap
im1 = axes[1, 0].imshow(A.numpy(), cmap='Blues', aspect='auto')
axes[1, 0].set_title('Adjacency Matrix')
axes[1, 0].set_xlabel('Node')
axes[1, 0].set_ylabel('Node')
plt.colorbar(im1, ax=axes[1, 0])
# Plot 4: Normalized Laplacian heatmap
im2 = axes[1, 1].imshow(L.numpy(), cmap='RdBu_r', aspect='auto', vmin=-0.5, vmax=1.5)
axes[1, 1].set_title('Normalized Laplacian')
axes[1, 1].set_xlabel('Node')
axes[1, 1].set_ylabel('Node')
plt.colorbar(im2, ax=axes[1, 1])
plt.tight_layout()
plt.savefig('gnn_spectral_analysis.png', dpi=150, bbox_inches='tight')
print("Saved: gnn_spectral_analysis.png")
plt.show()
print(f"\nSpectral gap (λ2 - λ1): {eigenvalues[1] - eigenvalues[0]:.4f}")
print(f"Number of zero eigenvalues: {(eigenvalues < 1e-10).sum()}")
print("(Zero eigenvalues = number of connected components)")
# ============================================================================
# 5. Comparison: GCN vs GraphSAGE vs ChebNet
# ============================================================================
print("\n" + "=" * 60)
print("5. GNN Methods Comparison")
print("=" * 60)
# Create models
gcn = GCNLayer(in_dim, out_dim)
sage = GraphSAGELayer(in_dim, out_dim, aggregator='mean')
cheb = ChebConv(in_dim, out_dim, K=3)
# Forward pass
out_gcn = gcn(X, A)
out_sage = sage(X, A, sample_size=None) # Full neighborhood
out_cheb = cheb(X, L)
# Compare outputs
print(f"\nGCN output: mean={out_gcn.mean().item():.4f}, std={out_gcn.std().item():.4f}")
print(f"GraphSAGE output: mean={out_sage.mean().item():.4f}, std={out_sage.std().item():.4f}")
print(f"ChebNet output: mean={out_cheb.mean().item():.4f}, std={out_cheb.std().item():.4f}")
# Parameter count
gcn_params = sum(p.numel() for p in gcn.parameters())
sage_params = sum(p.numel() for p in sage.parameters())
cheb_params = sum(p.numel() for p in cheb.parameters())
print(f"\nParameter count:")
print(f" GCN: {gcn_params}")
print(f" GraphSAGE (mean): {sage_params}")
print(f" ChebNet (K={K}): {cheb_params}")
# Visualize output distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (name, output) in zip(axes, [('GCN', out_gcn), ('GraphSAGE', out_sage), ('ChebNet', out_cheb)]):
ax.hist(output.detach().numpy().flatten(), bins=30, alpha=0.7, edgecolor='black')
ax.set_xlabel('Activation value')
ax.set_ylabel('Frequency')
ax.set_title(f'{name} Output Distribution')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gnn_comparison.png', dpi=150, bbox_inches='tight')
print("\nSaved: gnn_comparison.png")
plt.show()
Training on Karate Club
Zachary's Karate Club is a classic graph dataset with 34 nodes (club members) and edges representing social interactions. After a dispute, the club split into two factions, making this a natural node classification benchmark. The classic semi-supervised setup labels only a handful of nodes and lets message passing propagate label information through the graph structure; for simplicity, the demo below trains on all labels. This is directly analogous to how GNNs are used in real applications like social network analysis, fraud detection, and molecular property prediction.
# Load Karate Club graph
G = nx.karate_club_graph()
n = len(G.nodes())
# Adjacency
A = torch.tensor(nx.to_numpy_array(G), dtype=torch.float32)
# Features: one-hot node IDs
X = torch.eye(n)
# Labels: two communities
labels = torch.tensor([G.nodes[i]['club'] == 'Mr. Hi' for i in G.nodes()], dtype=torch.long)
# Train GCN
model = GCN(n, 16, 2, dropout=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(X, A)
loss = F.cross_entropy(out, labels)
loss.backward()
optimizer.step()
if (epoch + 1) % 50 == 0:
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
# Evaluate
model.eval()
with torch.no_grad():
pred = model(X, A).argmax(dim=1)
acc = (pred == labels).float().mean()
print(f"\nAccuracy: {acc:.3f}")
# Visualize
pos = nx.spring_layout(G, seed=42)
colors = ['red' if labels[i] == 0 else 'blue' for i in range(n)]
pred_colors = ['red' if pred[i] == 0 else 'blue' for i in range(n)]
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# True labels
nx.draw(G, pos, node_color=colors, with_labels=True, ax=axes[0], node_size=500)
axes[0].set_title('True Communities', fontsize=12)
# Predictions
nx.draw(G, pos, node_color=pred_colors, with_labels=True, ax=axes[1], node_size=500)
axes[1].set_title('GCN Predictions', fontsize=12)
plt.tight_layout()
plt.show()
Summary
Message Passing:
Nodes iteratively aggregate neighbor features and update their own representations.
GCN:
Spectral convolution on graphs.
GAT:
Attention weights neighbors adaptively.
Applications:
Node classification
Link prediction
Graph classification
Molecular property prediction
Social network analysis
Variants:
GraphSAGE (sampling)
GIN (graph isomorphism)
PNA (Principal Neighbourhood Aggregation)
Next Steps:
14_efficient_transformers.ipynb - Attention mechanisms
Explore graph pooling
Apply to molecular graphs
Advanced Graph Neural Networks Theory
1. Graph Representation Learning Foundations
1.1 Graph Definition
Graph: G = (V, E) where:
V = {v₁, v₂, …, vₙ} (nodes/vertices)
E ⊆ V × V (edges)
Types:
Undirected: (u, v) ∈ E ⟺ (v, u) ∈ E
Directed: Edges have direction
Weighted: Edge weights w: E → ℝ
Attributed: Node features X ∈ ℝ^(N×D), edge features E_feat ∈ ℝ^(|E|×D_e)
Adjacency matrix: A ∈ {0,1}^(N×N)
A[i,j] = 1 if (v_i, v_j) ∈ E, else 0
Degree matrix: D ∈ ℝ^(N×N) (diagonal)
D[i,i] = Σⱼ A[i,j] (degree of node i)
1.2 The Message Passing Framework
Core idea: Node representations via neighborhood aggregation.
General message passing:
h_v^(k+1) = UPDATE^(k)(h_v^(k), AGGREGATE^(k)({h_u^(k) : u ∈ N(v)}))
where:
h_v^(k): Node v representation at layer k
N(v): Neighbors of v
AGGREGATE: Permutation-invariant function (sum, mean, max)
UPDATE: Combines old state with aggregated messages
Key property: Permutation invariance
AGGREGATE(π(S)) = AGGREGATE(S) for any permutation π
1.3 Why Graphs are Different from Grids
CNNs assume:
Fixed neighborhood size (e.g., 3×3)
Spatial locality and grid structure
Translational equivariance
Graphs:
Variable neighborhood sizes (d(v) ≠ constant)
No spatial coordinates
Permutation invariance required
Consequence: Can't directly apply convolutions. Need graph-specific operations.
2. Graph Convolutional Networks (GCN)
2.1 Spectral Motivation
Graph Laplacian:
L = D - A (unnormalized)
L_norm = I - D^(-1/2) A D^(-1/2) (normalized)
Eigendecomposition:
L = U Λ U^T
where U are eigenvectors (graph Fourier basis), Λ eigenvalues.
Graph Fourier transform:
x̂ = U^T x (transform to frequency domain)
Spectral convolution:
x ∗ y = U((U^T x) ⊙ (U^T y))
Polynomial filter (Chebyshev): Avoid expensive eigendecomposition:
g_θ ∗ x ≈ Σₖ θₖ Tₖ(L̃) x
where Tₖ is the Chebyshev polynomial of order k, L̃ = 2L/λ_max - I.
2.2 GCN Layer (Kipf & Welling, 2017)
Simplification: First-order approximation (K=1):
H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))
where:
Ã = A + I (add self-loops)
D̃[i,i] = Σⱼ Ã[i,j]
W^(l): Learnable weight matrix
σ: Activation function
Intuition:
Self-loops: Include own features
D̃^(-1/2) Ã D̃^(-1/2): Symmetric normalization (average neighbors)
W: Learnable transformation
σ: Non-linearity
Per-node update:
h_v^(l+1) = σ(Σ_{u∈N(v)∪{v}} (1/√(d_u d_v)) h_u^(l) W^(l))
2.3 GCN Properties
Computational complexity:
Matrix multiplication: O(|E| · D · F) where D, F are feature dimensions
Sparse adjacency ⇒ efficient for sparse graphs
Receptive field:
K-layer GCN sees K-hop neighborhood
Node v at layer K influenced by all nodes ≤ K hops away
Over-smoothing:
Problem: As K → ∞, all node features converge to the same value (within a connected component)
Proof sketch: Repeated averaging makes all nodes similar
Solution: Shallow networks (2-3 layers), skip connections, normalization
2.4 Mathematical Analysis
Propagation as low-pass filter: GCN smooths features across edges:
h^(l+1) = Ã_norm h^(l) W
where Ã_norm has eigenvalues in (-1, 1].
Repeated application:
h^(K) = (Ã_norm)^K h^(0) W_combined
As K increases, (Ã_norm)^K emphasizes low-frequency components ⇒ over-smoothing.
3. Graph Attention Networks (GAT)
3.1 Attention Mechanism for Graphs
Motivation: Not all neighbors equally important. Learn attention weights.
Attention coefficient (unnormalized):
e_ij = LeakyReLU(a^T [W h_i || W h_j])
where:
W: Learnable transformation
a: Learnable attention vector
|| : Concatenation
Normalized attention (softmax over neighbors):
α_ij = softmax_j(e_ij) = exp(e_ij) / Σ_{k∈N(i)} exp(e_ik)
Aggregation:
h_i^(l+1) = σ(Σ_{j∈N(i)} α_ij W^(l) h_j^(l))
3.2 Multi-Head Attention
K independent attention heads:
h_i^(l+1) = ||_{k=1}^K σ(Σ_{j∈N(i)} α_ij^k W^k h_j^(l))
Last layer (average instead of concat):
h_i^(L) = σ((1/K) Σ_{k=1}^K Σ_{j∈N(i)} α_ij^k W^k h_j^(L-1))
Advantages:
Stabilizes learning
Multiple attention patterns (different heads attend to different aspects)
3.3 Computational Complexity
Per-edge attention computation:
Attention score: O(F) where F is feature dimension
Total: O(|E| · F) for all edges
Comparison with GCN:
GCN: Fixed normalization 1/√(d_i d_j)
GAT: Learned attention α_ij
GAT more flexible but higher computational cost
3.4 Theoretical Properties
Inductive learning: GAT can generalize to unseen nodes/graphs (unlike GCN, which needs the full graph).
Attention weights interpretability:
Visualize α_ij to understand which neighbors are important
Useful for explainability
Expressiveness: Like GCN, GAT's expressive power is bounded by the 1-WL test; attention makes aggregation more flexible but does not exceed the 1-WL limit.
4. GraphSAGE (SAmple and aggreGatE)ΒΆ
4.1 MotivationΒΆ
Problem: GCN requires full graph in memory (batch training on large graphs infeasible).
Solution: Sample fixed-size neighborhood for each node.
4.2 AlgorithmΒΆ
Sampling: For each node v, sample S neighbors from N(v) uniformly.
Aggregation functions:
Mean aggregator:
h_v^(k) = Ο(W Β· MEAN({h_u^(k-1), βu β N(v)}) + B h_v^(k-1))
LSTM aggregator:
h_v^(k) = Ο(W Β· LSTM({h_u^(k-1), βu β Ο(N(v))}) + B h_v^(k-1))
where Ο is random permutation (LSTM not permutation-invariant, so randomize).
Max-pooling aggregator:
h_v^(k) = Ο(W Β· MAX({ReLU(W_pool h_u^(k-1) + b), βu β N(v)}) + B h_v^(k-1))
Normalization: L2-normalize embeddings at each layer:
h_v^(k) β h_v^(k) / ||h_v^(k)||_2
4.3 Mini-Batch Training
Idea: For a batch of nodes, only compute embeddings for their sampled neighborhoods.
Multi-layer sampling:
Layer K: Sample S_K neighbors per node
Layer K-1: Sample S_{K-1} neighbors per (already sampled) node
…
Recursive neighborhood sampling
Computational complexity:
K layers, S sampled neighbors per layer
Each node in the batch has up to S^K nodes in its computation graph
Total: O(B · S^K · F) where B is the batch size
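The recursive sampling above can be sketched in plain Python. This is a minimal sketch under assumed conventions: the graph is a dict of adjacency lists, and `sample_neighborhood`/`fanouts` are names we introduce here, not GraphSAGE API.

```python
import random

def sample_neighborhood(adj, batch_nodes, fanouts, seed=0):
    """Recursive GraphSAGE-style neighborhood sampling (a sketch).

    adj: dict mapping node -> list of neighbors (toy format, an assumption).
    fanouts: list [S_1, ..., S_K], one fan-out per layer.
    Returns the set of all nodes needed to embed batch_nodes.
    """
    rng = random.Random(seed)
    frontier, needed = set(batch_nodes), set(batch_nodes)
    for s in fanouts:                       # one sampling round per layer
        nxt = set()
        for v in frontier:
            nbrs = adj.get(v, [])
            nxt.update(rng.sample(nbrs, min(s, len(nbrs))))
        needed |= nxt
        frontier = nxt                      # only newly sampled nodes expand further
    return needed

# Toy chain graph 0-1-2-3-4: a 2-layer model on node 2 touches every node.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
print(sorted(sample_neighborhood(adj, [2], fanouts=[2, 2])))  # [0, 1, 2, 3, 4]
```

Note how the sampled set grows roughly like S^K with depth, which is exactly the O(B · S^K · F) cost quoted above.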
4.4 Loss Functions
Supervised (node classification):
L = -Σ_v y_v log(softmax(h_v^(K)))
Unsupervised (graph structure):
L = -log(σ(h_u^T h_v)) - Q · E_{v_n~P_n(v)} log(σ(-h_u^T h_{v_n}))
where:
(u, v) are co-occurring nodes (e.g., on a random walk)
v_n are negative samples drawn from distribution P_n
Q: Number of negative samples
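The unsupervised objective above can be written directly in PyTorch. A minimal sketch, assuming embeddings are already computed and batched; the function name and tensor shapes are our conventions, not from the paper's code.

```python
import torch
import torch.nn.functional as F

def unsupervised_sage_loss(h_u, h_v, h_neg):
    """Negative-sampling objective above (a sketch; shapes are assumptions).

    h_u, h_v: [B, D] embeddings of co-occurring node pairs.
    h_neg:    [B, Q, D] embeddings of Q negative samples per pair.
    """
    pos = F.logsigmoid((h_u * h_v).sum(-1))                      # log σ(h_u^T h_v)
    neg = F.logsigmoid(-torch.einsum('bd,bqd->bq', h_u, h_neg))  # log σ(-h_u^T h_vn)
    q = h_neg.size(1)
    return -(pos + q * neg.mean(-1)).mean()

torch.manual_seed(0)
loss = unsupervised_sage_loss(torch.randn(8, 16), torch.randn(8, 16),
                              torch.randn(8, 5, 16))
print(float(loss))
```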
5. Message Passing Neural Networks (MPNN)
5.1 Unified Framework
General message passing (Gilmer et al., 2017):
Message phase:
m_v^(k+1) = Σ_{u∈N(v)} M_k(h_v^(k), h_u^(k), e_{uv})
Update phase:
h_v^(k+1) = U_k(h_v^(k), m_v^(k+1))
where:
M_k: Message function (can depend on edge features e_{uv})
U_k: Update function (e.g., a GRU)
Readout (graph-level):
y = R({h_v^(K) : v ∈ G})
5.2 Specific Instances
GCN as MPNN:
M(h_v, h_u, e_{uv}) = (1/√(d_u d_v)) h_u W
U(h_v, m_v) = σ(m_v)
GAT as MPNN:
M(h_v, h_u, e_{uv}) = α_{vu} W h_u
U(h_v, m_v) = σ(m_v)
MPNN with edge features:
M(h_v, h_u, e_{uv}) = NN(concat(h_v, h_u, e_{uv}))
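The edge-feature instance above can be sketched as one message-passing step: a small MLP computes M from concat(h_v, h_u, e_uv), messages are sum-aggregated per target node, and a GRU cell plays U_k. A sketch under assumed dimensions; `EdgeMPNNLayer` is a name we introduce, not a library class.

```python
import torch
import torch.nn as nn

class EdgeMPNNLayer(nn.Module):
    """One MPNN step with edge features, M = NN(concat(h_v, h_u, e_uv)) and a
    GRU update, as in the framework above (a sketch)."""

    def __init__(self, node_dim, edge_dim, msg_dim):
        super().__init__()
        self.msg = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, msg_dim), nn.ReLU(),
            nn.Linear(msg_dim, msg_dim))
        self.update = nn.GRUCell(msg_dim, node_dim)  # plays the role of U_k

    def forward(self, h, edge_index, edge_attr):
        src, dst = edge_index                       # messages flow src -> dst
        m = self.msg(torch.cat([h[dst], h[src], edge_attr], dim=-1))
        agg = torch.zeros(h.size(0), m.size(1), device=h.device)
        agg.index_add_(0, dst, m)                   # sum-aggregate per target node
        return self.update(agg, h)                  # h_v^(k+1) = U(h_v^(k), m_v)

h = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 4)
out = EdgeMPNNLayer(8, 4, 16)(h, edge_index, edge_attr)
print(out.shape)  # torch.Size([5, 8])
```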
5.3 Expressiveness via WL Test
Weisfeiler-Lehman (WL) test: A graph isomorphism heuristic.
1-WL algorithm:
Initialize: Label each node with its features
Iterate: Update label = hash(old_label, multiset of neighbor labels)
Repeat until convergence
Theorem (Xu et al., 2019 - GIN): Standard MPNNs are at most as powerful as the 1-WL test.
Limitations:
Cannot distinguish certain non-isomorphic graphs
Example: a 6-cycle vs. two disjoint triangles (both are 2-regular, so 1-WL assigns every node the same label)
Going beyond WL:
Higher-order GNNs (k-WL)
Subgraph GNNs
Graph transformers with positional encodings
6. Graph Isomorphism Network (GIN)
6.1 Maximally Expressive GNN
Theorem: GIN can distinguish any graphs distinguishable by the 1-WL test.
GIN update rule:
h_v^(k+1) = MLP^(k)((1 + ε^(k)) h_v^(k) + Σ_{u∈N(v)} h_u^(k))
where:
ε: Learnable scalar (or fixed)
MLP: Multi-layer perceptron with at least one hidden layer
Why this specific form?
Sum aggregation (provably the most expressive aggregator under WL)
The (1+ε) self-weight distinguishes a node's own features from its neighbor multiset
The MLP provides expressiveness (can model injective functions)
6.2 Theoretical Justification
Injective aggregation: To distinguish different multisets, aggregation must be injective.
Theorem: Sum aggregation (composed with a suitable injective feature map) is injective over multisets drawn from a countable domain.
Other aggregators:
Mean: Not injective (e.g., with one-hot features, {a, b} and {a, a, b, b} have the same mean)
Max: Not injective (e.g., {1,3} and {3,3} have the same max)
Sum: Injective ✓
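The collisions above are easy to verify numerically. A minimal sketch with one-hot features (the toy label set and helper names are ours): sum preserves each label's multiplicity, while mean and max collapse different multisets to the same vector.

```python
import numpy as np

# With one-hot node features, sum counts each label's multiplicity,
# while mean loses the multiset size and max loses multiplicities entirely.
onehot = {'a': np.array([1.0, 0.0]), 'b': np.array([0.0, 1.0])}

def aggregate(multiset, how):
    feats = np.stack([onehot[x] for x in multiset])
    return {'sum': feats.sum(0), 'mean': feats.mean(0), 'max': feats.max(0)}[how]

A, B = ['a', 'b'], ['a', 'a', 'b', 'b']   # different multisets, same proportions
print(np.allclose(aggregate(A, 'mean'), aggregate(B, 'mean')))  # True: mean collides
print(np.allclose(aggregate(A, 'max'), aggregate(B, 'max')))    # True: max collides
print(np.allclose(aggregate(A, 'sum'), aggregate(B, 'sum')))    # False: sum separates
```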
6.3 Practical Considerations
MLP architecture:
At least 2 layers (to ensure universality)
Batch normalization between layers
Dropout for regularization
ε initialization:
Learnable: ε = 0 initially, learned during training
Fixed: ε = 0 or a small positive value
7. Advanced GNN Architectures
7.1 PNA (Principal Neighbourhood Aggregation)
Motivation: Different aggregators capture different information.
Multi-aggregator:
h_v = MLP(||_{agg∈{mean,max,min,std}} agg({h_u : u∈N(v)}))
Scalers (degree-aware):
h_v = MLP(||_{agg,scaler} scaler(agg(...)))
where scalers ∈ {amplification, attenuation, identity}:
Amplification: d_v · agg (emphasize high-degree nodes)
Attenuation: (1/d_v) · agg (normalize by degree)
7.2 DeeperGCN
Problem: Deep GNNs suffer from over-smoothing and vanishing gradients.
Solutions:
Skip connections (residual):
h_v^(l+1) = h_v^(l) + GCN_layer(h_v^(l))
Normalization (layer norm, batch norm):
h_v^(l+1) = LayerNorm(GCN_layer(h_v^(l)))
Pre-activation (normalize and activate before the convolution):
h_v^(l+1) = h_v^(l) + GCN(σ(LayerNorm(h_v^(l))))
Result: 28+ layer GNNs can be trained with good performance.
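The pre-activation residual recipe above can be sketched compactly. This is a minimal sketch, assuming a dense symmetrically normalized adjacency for readability (real implementations use sparse message passing); `ResGCNBlock` is our name, not the DeeperGCN code.

```python
import torch
import torch.nn as nn

class ResGCNBlock(nn.Module):
    """Pre-activation residual block in the DeeperGCN style (a sketch):
    h^(l+1) = h^(l) + Conv(ReLU(LayerNorm(h^(l))))."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.lin = nn.Linear(dim, dim)

    def forward(self, h, adj_norm):
        z = torch.relu(self.norm(h))   # normalize and activate first (pre-activation)
        z = adj_norm @ self.lin(z)     # one dense graph-convolution step
        return h + z                   # residual connection mitigates over-smoothing

# Symmetrically normalized adjacency of a 4-node cycle with self-loops.
A = torch.tensor([[0., 1., 0., 1.], [1., 0., 1., 0.],
                  [0., 1., 0., 1.], [1., 0., 1., 0.]])
A_hat = A + torch.eye(4)
d = A_hat.sum(1)
adj_norm = A_hat / torch.sqrt(d.outer(d))

block, deep = ResGCNBlock(16), torch.randn(4, 16)
for _ in range(28):                    # 28 stacked applications stay numerically stable
    deep = block(deep, adj_norm)
print(deep.shape)  # torch.Size([4, 16])
```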
7.3 Graph Transformer
Self-attention on graphs:
Challenges:
No canonical positional information (unlike sequences)
Quadratic complexity in the number of nodes
Solutions:
Laplacian positional encoding:
PE_v = [u_1(v), u_2(v), ..., u_k(v)]
where u_i are eigenvectors of the graph Laplacian.
Sparse attention (only on edges):
Attention only between (u,v) if (u,v) ∈ E
Relative positional encoding: Encode shortest-path distance between nodes.
Graphormer (NeurIPS 2021):
Centrality encoding (degree)
Spatial encoding (shortest path)
Edge encoding in the attention bias
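The Laplacian positional encoding mentioned above comes straight from an eigendecomposition. A minimal sketch (the helper name `laplacian_pe` is ours); note that eigenvector signs are arbitrary, which practical models handle by random sign flipping during training.

```python
import numpy as np

def laplacian_pe(A, k):
    """First k non-trivial eigenvectors of the normalized Laplacian, used as
    positional encodings (a sketch)."""
    d = A.sum(1)
    d_inv_sqrt = np.where(d > 0, d ** -0.5, 0.0)
    L = np.eye(len(A)) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    _, eigvecs = np.linalg.eigh(L)     # eigenvalues in ascending order
    return eigvecs[:, 1:k + 1]         # drop the trivial first eigenvector

# Path graph 0-1-2-3: each node gets a 2-dimensional positional encoding.
A = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
pe = laplacian_pe(A, k=2)
print(pe.shape)  # (4, 2)
```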
8. Pooling and Hierarchical GNNs
8.1 Graph-Level Representations
Need: Map variable-size graphs to fixed-size vectors.
Simple pooling:
h_G = Σ_v h_v (sum pooling)
h_G = mean_v h_v (mean pooling)
h_G = max_v h_v (max pooling)
8.2 Differentiable Pooling (DiffPool)
Idea: Learn a soft cluster assignment.
Cluster assignment matrix: S ∈ ℝ^(N×K)
S[i,k] = probability that node i is assigned to cluster k
Coarsened features:
X^(new) = S^T X (K × D)
A^(new) = S^T A S (K × K)
Loss:
L_pool = ||A - SS^T||_F (link prediction)
+ entropy(S) (entropy regularization)
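The DiffPool coarsening step above is a handful of matrix products. A minimal sketch (in the full method `S_logits` is produced by a dedicated assignment GNN; here we pass arbitrary logits to show the shapes and losses):

```python
import torch
import torch.nn.functional as F

def diffpool(X, A, S_logits):
    """One DiffPool coarsening step as defined above (a sketch)."""
    S = F.softmax(S_logits, dim=-1)            # [N, K] soft cluster assignment
    X_new = S.t() @ X                          # [K, D] coarsened features
    A_new = S.t() @ A @ S                      # [K, K] coarsened adjacency
    link_loss = torch.norm(A - S @ S.t())      # auxiliary link-prediction loss
    ent_loss = (-S * torch.log(S + 1e-9)).sum(-1).mean()  # entropy regularizer
    return X_new, A_new, link_loss + ent_loss

torch.manual_seed(0)
X = torch.randn(6, 8)
A = (torch.rand(6, 6) > 0.5).float()
A = ((A + A.t()) > 0).float()                  # symmetrize the toy adjacency
X_new, A_new, aux = diffpool(X, A, torch.randn(6, 2))
print(X_new.shape, A_new.shape)  # torch.Size([2, 8]) torch.Size([2, 2])
```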
8.3 Top-K Pooling
Select the top-k nodes based on learned scores:
Score function:
y = X p / ||p|| (project features to a scalar)
Select top-k:
idx = top_k(y)
X^(new) = X[idx]
A^(new) = A[idx, idx]
Gating:
X^(new) = X[idx] ⊙ σ(y[idx])
8.4 Self-Attention Pooling (SAGPool)
Attention-based node selection:
Z = GNN(X, A)
y = σ(Z p / ||p||)
idx = top_k(y)
X^(new) = X[idx] ⊙ y[idx]
A^(new) = A[idx, idx]
9. Heterogeneous and Dynamic Graphs
9.1 Heterogeneous Graphs
Definition: Multiple node/edge types.
V = V₁ ∪ V₂ ∪ … ∪ Vₘ (node types)
E = E₁ ∪ E₂ ∪ … ∪ Eₙ (edge types)
Example (citation network):
Nodes: Papers, Authors, Venues
Edges: Writes, Publishes, Cites
Relational GCN (R-GCN):
h_v^(l+1) = σ(Σ_{r∈R} Σ_{u∈N_r(v)} (1/|N_r(v)|) W_r^(l) h_u^(l) + W_0^(l) h_v^(l))
where:
r: Relation type
N_r(v): Neighbors of v via relation r
W_r: Relation-specific weight matrix
Heterogeneous GAT (HAN):
Node-level attention (within a meta-path)
Semantic-level attention (across meta-paths)
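The R-GCN update above can be sketched with one weight tensor indexed by edge type. A minimal sketch (`RGCNLayer` is our name); for brevity we normalize messages per target node, whereas the exact formula normalizes per (node, relation) pair.

```python
import torch
import torch.nn as nn

class RGCNLayer(nn.Module):
    """Relational GCN layer (a sketch): one weight matrix per relation type."""

    def __init__(self, in_dim, out_dim, num_relations):
        super().__init__()
        self.rel_weight = nn.Parameter(0.1 * torch.randn(num_relations, in_dim, out_dim))
        self.self_weight = nn.Linear(in_dim, out_dim, bias=False)  # W_0 h_v term

    def forward(self, h, edge_index, edge_type):
        src, dst = edge_index
        # Apply the relation-specific weight of each edge to its source features.
        msg = torch.einsum('ei,eio->eo', h[src], self.rel_weight[edge_type])
        agg = torch.zeros(h.size(0), msg.size(1)).index_add_(0, dst, msg)
        deg = torch.zeros(h.size(0)).index_add_(
            0, dst, torch.ones(dst.size(0))).clamp(min=1)
        return torch.relu(self.self_weight(h) + agg / deg[:, None])

h = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_type = torch.tensor([0, 1, 0, 1])        # two relation types
out = RGCNLayer(8, 16, num_relations=2)(h, edge_index, edge_type)
print(out.shape)  # torch.Size([5, 16])
```

In practice R-GCN also uses basis or block-diagonal decomposition of W_r to keep the parameter count manageable when |R| is large.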
9.2 Dynamic Graphs
Temporal graphs: Edges/nodes change over time.
Discrete-time: Sequence of graph snapshots G₁, G₂, …, G_T
Continuous-time: Edge stream (u_i, v_i, t_i)
Temporal GNN approaches:
Snapshot-based: Apply a GNN to each snapshot independently
RNN-based: GNN + LSTM/GRU across time
Temporal random walk: Sample time-respecting walks
Temporal attention: Attend to past events
TGAT (Temporal Graph Attention):
h_v(t) = Attention({h_u(t_u) : (u,v,t_u) ∈ E, t_u < t})
10. Training Techniques
10.1 Normalization
Batch normalization (problematic for graphs):
Variable graph sizes → inconsistent statistics
Small graphs → noisy estimates
Better alternatives:
Layer normalization:
h_v ← (h_v - μ_v) / σ_v
Normalize each node independently.
Graph normalization:
h_G ← (h_G - μ_G) / σ_G
Normalize across the graph.
Pair normalization: Normalize across node pairs (for message passing).
10.2 Regularization
Dropout:
Node dropout: Drop nodes
Edge dropout: Drop edges (augmentation + regularization)
Message dropout: Drop messages
Graph augmentation:
Add/remove random edges
Mask node features
Subgraph sampling
10.3 Loss Functions
Node classification:
L = -Σ_{v∈V_train} y_v log ŷ_v
Link prediction:
L = -Σ_{(u,v)∈E} log σ(h_u^T h_v) - Σ_{(u,v)∉E} log(1 - σ(h_u^T h_v))
Graph classification:
L = -Σ_G y_G log ŷ_G
Contrastive (self-supervised):
L = -log(exp(sim(h_i, h_j)/τ) / Σ_k exp(sim(h_i, h_k)/τ))
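The contrastive loss above is the InfoNCE objective and reduces to a cross-entropy over similarity rows. A minimal sketch, assuming rows i of two views are positive pairs and all other in-batch rows are negatives (`info_nce` is our name):

```python
import torch
import torch.nn.functional as F

def info_nce(h1, h2, tau=0.5):
    """Contrastive loss above (a sketch): row i of h1 and h2 are two views of
    the same node; every other row in the batch serves as a negative."""
    z1, z2 = F.normalize(h1, dim=-1), F.normalize(h2, dim=-1)
    sim = z1 @ z2.t() / tau                  # [N, N] cosine similarity / temperature
    targets = torch.arange(h1.size(0))       # positives sit on the diagonal
    return F.cross_entropy(sim, targets)     # softmax over each row = the formula

torch.manual_seed(0)
loss = info_nce(torch.randn(16, 32), torch.randn(16, 32))
print(float(loss))
```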
11. Complexity and Scalability
11.1 Computational Complexity
GCN forward pass:
Dense: O(N² D F) (dense matrix multiplication)
Sparse: O(E D F) (leverage sparsity)
Memory:
Store adjacency: O(E) (sparse) or O(N²) (dense)
Node features: O(N D)
Activations: O(N F) per layer
11.2 Scalability Challenges
Full-batch training:
Requires the entire graph in memory
Infeasible for graphs with millions of nodes
Solutions:
Mini-batching (GraphSAGE):
Sample neighborhoods
Complexity: O(B S^L D F)
Cluster-GCN:
Partition graph into clusters
Train on subgraphs
Reduces memory and enables parallelism
Graph sampling:
Layer-wise sampling
Importance sampling
Variance reduction techniques
11.3 Large-Scale GNN Systems
Distributed training:
Partition nodes across machines
Communication overhead for neighbor features
Libraries:
DGL (Deep Graph Library): Flexible message passing
PyG (PyTorch Geometric): Rich model zoo
GraphStorm (AWS): Distributed training
Graph-Learn (Alibaba): Industrial scale
12. Applications
12.1 Node Classification
Task: Predict node labels given graph structure and partial labels.
Examples:
Social networks: User attributes
Citation networks: Paper topics
Molecules: Atom properties
Evaluation:
Semi-supervised: Train on few labeled, test on rest
Metrics: Accuracy, F1-score, Micro/Macro-F1
12.2 Link Prediction
Task: Predict missing or future edges.
Methods:
Score function: s(u,v) = h_u^T h_v (dot product)
Decoder: MLP(concat(h_u, h_v))
Distance: ||h_u - h_v||
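The three scoring functions above differ only in how a pair of embeddings is reduced to a scalar. A minimal sketch on a batch of candidate edges (`MLPDecoder` and the helper names are ours):

```python
import torch
import torch.nn as nn

def dot_score(h_u, h_v):
    return (h_u * h_v).sum(-1)               # s(u,v) = h_u^T h_v

def dist_score(h_u, h_v):
    return -torch.norm(h_u - h_v, dim=-1)    # closer embeddings score higher

class MLPDecoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, h_u, h_v):
        return self.net(torch.cat([h_u, h_v], dim=-1)).squeeze(-1)

h_u, h_v = torch.randn(10, 16), torch.randn(10, 16)
scores = (dot_score(h_u, h_v), dist_score(h_u, h_v), MLPDecoder(16)(h_u, h_v))
print([s.shape for s in scores])  # three score tensors of shape [10]
```

The dot product is cheapest and works well when embeddings are trained for it; the MLP decoder is more expressive (it can model asymmetric relations) at extra cost.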
Applications:
Recommendation: User-item edges
Knowledge graphs: Entity-relation-entity triples
Drug discovery: Protein-protein interactions
12.3 Graph Classification
Task: Classify entire graphs.
Examples:
Molecule property prediction (toxicity, solubility)
Social network analysis (community detection)
Program verification (code graphs)
Approach:
Graph → GNN → Pooling → Readout → Classifier
12.4 Temporal Prediction
Traffic forecasting:
Nodes: Road segments
Edges: Connectivity
Features: Traffic speed/volume
Predict: Future traffic
Financial networks:
Nodes: Assets
Edges: Correlations
Predict: Price movements
13. Benchmarks and Datasets
13.1 Standard Benchmarks
Node classification:
Cora, Citeseer, Pubmed (citation networks)
Reddit, Yelp (social networks)
ogbn-arxiv, ogbn-products (OGB)
Graph classification:
MUTAG, PTC, PROTEINS (bioinformatics)
COLLAB, IMDB (social networks)
ogbg-molhiv, ogbg-molpcba (molecules)
Link prediction:
WN18RR, FB15k-237 (knowledge graphs)
ogbl-ppa, ogbl-collab (OGB)
13.2 Open Graph Benchmark (OGB)
Large-scale, diverse, realistic:
Millions of nodes/edges
Realistic splits (temporal, scaffold)
Standardized evaluation
Example tasks:
ogbn-arxiv: 169K papers, 40-way classification
ogbg-molhiv: 41K molecules, binary classification
ogbl-citation2: 2.9M papers, link prediction
14. Recent Advances (2020-2024)
14.1 Equivariant GNNs
SE(3)-equivariant networks:
Preserve 3D rotations and translations
Critical for molecular modeling
E(n)-GNN, SchNet, DimeNet++
14.2 Graph Transformers
Graphormer (NeurIPS 2021):
Won OGB-LSC molecular prediction
Centrality + spatial encoding
SAN (Spectral Attention Network):
Laplacian positional encoding
Full attention with efficiency tricks
14.3 Self-Supervised Learning
Contrastive methods:
GraphCL: Augmentation-based contrastive
BGRL: Bootstrap graph latents
SimGRACE: Perturb encoder
Generative:
GPT-GNN: Generative pre-training
GraphMAE: Masked autoencoder
14.4 Graph Foundation Models
Motivation: Pre-train on many graphs, fine-tune on target.
Approaches:
Graph-level pre-training (molecule datasets)
Prompt-based fine-tuning
Transfer across domains
15. Limitations and Future Directions
15.1 Known Limitations
Expressiveness:
Limited by the 1-WL test
Cannot count triangles or cycles
Over-smoothing:
Deep GNNs make all node representations similar
Limits practical depth to 2-3 layers
Scalability:
Full-batch infeasible for billion-node graphs
Distributed training communication overhead
Heterophily:
GNNs assume homophily (similar nodes connect)
Fail when dissimilar nodes connect
15.2 Open Problems
Theoretical:
Better understanding of GNN expressiveness
Generalization bounds
When do GNNs work vs. fail?
Architectural:
GNNs beyond 1-WL
Long-range dependencies (beyond k-hop)
Handling heterophily
Scalability:
Billion-node, trillion-edge graphs
Streaming graphs
Efficient inference
15.3 Future Directions
Foundation models:
Universal graph encoders
Zero-shot on new graphs
Neurosymbolic:
Combine GNNs with logic
Interpretable reasoning
Geometric deep learning:
Unified theory (symmetry, invariance)
Beyond graphs (manifolds, meshes)
16. Key Takeaways
Message passing is core: All GNNs aggregate neighbor information iteratively.
Expressiveness vs. efficiency tradeoff:
GIN: Maximally expressive (within WL limit) but expensive
GCN: Less expressive but fast and simple
Depth is tricky:
Over-smoothing limits to 2-3 layers
Solutions: Skip connections, normalization, attention
Scalability crucial:
Sampling (GraphSAGE)
Clustering (Cluster-GCN)
Distributed systems
Inductive vs. transductive:
GCN: Transductive (needs full graph)
GraphSAGE, GAT: Inductive (generalize to new nodes)
Application-specific design:
Molecules: E(3)-equivariance
Social: Heterogeneous graphs
Temporal: Dynamic GNNs
Pre-training helps: Like NLP/vision, pre-training improves few-shot performance.
17. References
Foundational:
Kipf & Welling (2017): "Semi-Supervised Classification with GCNs" (ICLR)
Veličković et al. (2018): "Graph Attention Networks" (ICLR)
Hamilton et al. (2017): "Inductive Representation Learning on Large Graphs" (NeurIPS)
Gilmer et al. (2017): "Neural Message Passing for Quantum Chemistry" (ICML)
Xu et al. (2019): "How Powerful are GNNs?" (ICLR)
Advanced architectures:
Corso et al. (2020): "Principal Neighbourhood Aggregation" (NeurIPS)
Li et al. (2020): "DeeperGCN" (arXiv)
Ying et al. (2021): "Graphormer" (NeurIPS)
Theory:
Morris et al. (2019): "Weisfeiler and Leman Go Neural" (AAAI)
Bronstein et al. (2021): "Geometric Deep Learning" (Book/Tutorial)
Surveys:
Wu et al. (2021): "A Comprehensive Survey on GNNs"
Zhou et al. (2020): "Graph Neural Networks: A Review of Methods and Applications"
"""
Complete Graph Neural Network Implementations
==============================================
Includes: GCN, GAT, GraphSAGE, GIN, message passing framework, pooling methods.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import math
# ============================================================================
# 1. Graph Convolutional Network (GCN) Layer
# ============================================================================
class GCNConv(MessagePassing):
"""
Graph Convolutional Network layer (Kipf & Welling, 2017).
Implements: H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))
where Ã = A + I (add self-loops), D̃ is the degree matrix of Ã.
Args:
in_channels: Input feature dimension
out_channels: Output feature dimension
bias: Whether to add bias
normalize: Whether to apply symmetric normalization
"""
def __init__(self, in_channels, out_channels, bias=True, normalize=True):
super(GCNConv, self).__init__(aggr='add') # Sum aggregation
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
# Learnable weight matrix
self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
"""Glorot initialization."""
nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x, edge_index):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, out_channels]
"""
# Add self-loops
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Transform features
x = torch.matmul(x, self.weight)
# Compute normalization
if self.normalize:
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
else:
norm = None
# Propagate (message passing)
out = self.propagate(edge_index, x=x, norm=norm)
if self.bias is not None:
out += self.bias
return out
def message(self, x_j, norm):
"""Construct messages: 1/β(d_i d_j) * h_j."""
return norm.view(-1, 1) * x_j if norm is not None else x_j
class GCN(nn.Module):
"""
Multi-layer GCN for node classification.
Args:
num_features: Input feature dimension
hidden_dim: Hidden layer dimension
num_classes: Number of output classes
num_layers: Number of GCN layers
dropout: Dropout rate
"""
def __init__(self, num_features, hidden_dim, num_classes, num_layers=2, dropout=0.5):
super(GCN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
self.convs.append(GCNConv(num_features, hidden_dim))
for _ in range(num_layers - 2):
self.convs.append(GCNConv(hidden_dim, hidden_dim))
self.convs.append(GCNConv(hidden_dim, num_classes))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=1)
# ============================================================================
# 2. Graph Attention Network (GAT) Layer
# ============================================================================
class GATConv(MessagePassing):
"""
Graph Attention Network layer (Veličković et al., 2018).
Implements: h_i = σ(Σ_{j∈N(i)} α_ij W h_j)
where α_ij = softmax_j(LeakyReLU(a^T [W h_i || W h_j]))
Args:
in_channels: Input feature dimension
out_channels: Output feature dimension
heads: Number of attention heads
concat: Whether to concatenate or average multi-head outputs
dropout: Attention dropout rate
bias: Whether to add bias
"""
def __init__(self, in_channels, out_channels, heads=1, concat=True,
dropout=0.0, bias=True):
super(GATConv, self).__init__(aggr='add')
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.dropout = dropout
# Learnable parameters
self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
self.att = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels))
if bias and concat:
self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
nn.init.xavier_uniform_(self.att)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x, edge_index):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, heads * out_channels] or [N, out_channels]
"""
# Transform features
x = torch.matmul(x, self.weight).view(-1, self.heads, self.out_channels)
# Propagate
out = self.propagate(edge_index, x=x)
if self.concat:
out = out.view(-1, self.heads * self.out_channels)
else:
out = out.mean(dim=1)
if self.bias is not None:
out += self.bias
return out
def message(self, x_i, x_j, edge_index_i):
"""
Compute attention coefficients and messages.
Args:
x_i: Target node features [E, heads, out_channels]
x_j: Source node features [E, heads, out_channels]
edge_index_i: Target node indices [E]
"""
# Concatenate features
alpha = torch.cat([x_i, x_j], dim=-1) # [E, heads, 2*out_channels]
# Compute attention scores
alpha = (alpha * self.att).sum(dim=-1) # [E, heads]
alpha = F.leaky_relu(alpha, negative_slope=0.2)
# Softmax over neighbors
alpha = softmax(alpha, edge_index_i)
# Attention dropout
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# Weight messages by attention
return x_j * alpha.unsqueeze(-1)
def softmax(src, index):
"""Numerically stable softmax over nodes."""
src_max = src.max()
out = (src - src_max).exp()
# Sum per node
out_sum = torch.zeros(int(index.max()) + 1, out.size(1), device=out.device)
out_sum.scatter_add_(0, index.unsqueeze(-1).expand_as(out), out)
return out / (out_sum[index] + 1e-16)
class GAT(nn.Module):
"""
Multi-layer GAT for node classification.
Args:
num_features: Input feature dimension
hidden_dim: Hidden layer dimension
num_classes: Number of output classes
num_layers: Number of GAT layers
heads: Number of attention heads
dropout: Dropout rate
"""
def __init__(self, num_features, hidden_dim, num_classes, num_layers=2,
heads=8, dropout=0.6):
super(GAT, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
# First layer
self.convs.append(GATConv(num_features, hidden_dim, heads=heads, concat=True))
# Hidden layers
for _ in range(num_layers - 2):
self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True))
# Output layer (average instead of concat)
self.convs.append(GATConv(hidden_dim * heads, num_classes, heads=1, concat=False))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=1)
# ============================================================================
# 3. GraphSAGE Layer
# ============================================================================
class SAGEConv(MessagePassing):
"""
GraphSAGE layer (Hamilton et al., 2017).
Implements: h_v = σ(W · AGG({h_u : u ∈ N(v)}) + W_root · h_v)
Args:
in_channels: Input feature dimension
out_channels: Output feature dimension
aggr: Aggregation method ('mean', 'max', 'add')
normalize: Whether to L2-normalize output
bias: Whether to add bias
"""
def __init__(self, in_channels, out_channels, aggr='mean', normalize=True, bias=True):
super(SAGEConv, self).__init__(aggr=aggr)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
# Learnable weights
self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
self.root_weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
nn.init.xavier_uniform_(self.root_weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x, edge_index):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, out_channels]
"""
# Aggregate neighbors
out = self.propagate(edge_index, x=x)
out = torch.matmul(out, self.weight)
# Add root node features
out += torch.matmul(x, self.root_weight)
if self.bias is not None:
out += self.bias
if self.normalize:
out = F.normalize(out, p=2, dim=-1)
return out
def message(self, x_j):
"""Messages are just neighbor features."""
return x_j
class GraphSAGE(nn.Module):
"""
Multi-layer GraphSAGE for node classification.
Args:
num_features: Input feature dimension
hidden_dim: Hidden layer dimension
num_classes: Number of output classes
num_layers: Number of SAGE layers
dropout: Dropout rate
aggr: Aggregation method
"""
def __init__(self, num_features, hidden_dim, num_classes, num_layers=2,
dropout=0.5, aggr='mean'):
super(GraphSAGE, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(num_features, hidden_dim, aggr=aggr))
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_dim, hidden_dim, aggr=aggr))
self.convs.append(SAGEConv(hidden_dim, num_classes, aggr=aggr, normalize=False))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return F.log_softmax(x, dim=1)
# ============================================================================
# 4. Graph Isomorphism Network (GIN) Layer
# ============================================================================
class GINConv(MessagePassing):
"""
Graph Isomorphism Network layer (Xu et al., 2019).
Implements: h_v = MLP((1 + ε) h_v + Σ_{u∈N(v)} h_u)
Args:
mlp: MLP (torch.nn.Sequential)
eps: Initial epsilon value
train_eps: Whether epsilon is learnable
"""
def __init__(self, mlp, eps=0.0, train_eps=True):
super(GINConv, self).__init__(aggr='add') # Sum aggregation
self.mlp = mlp # named `mlp` (not `nn`) so it does not shadow the torch.nn module
self.initial_eps = eps
if train_eps:
self.eps = nn.Parameter(torch.Tensor([eps]))
else:
self.register_buffer('eps', torch.Tensor([eps]))
def forward(self, x, edge_index):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
Returns:
Updated node features [N, out_channels]
"""
# Aggregate neighbors
out = self.propagate(edge_index, x=x)
# Add self-features scaled by (1 + ε)
out = out + (1 + self.eps) * x
# Apply MLP
return self.mlp(out)
def message(self, x_j):
return x_j
class GIN(nn.Module):
"""
Multi-layer GIN for graph classification.
Args:
num_features: Input feature dimension
hidden_dim: Hidden layer dimension
num_classes: Number of output classes
num_layers: Number of GIN layers
dropout: Dropout rate
"""
def __init__(self, num_features, hidden_dim, num_classes, num_layers=5, dropout=0.5):
super(GIN, self).__init__()
self.num_layers = num_layers
self.dropout = dropout
self.convs = nn.ModuleList()
self.batch_norms = nn.ModuleList()
# First layer
mlp = nn.Sequential(
nn.Linear(num_features, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.convs.append(GINConv(mlp, train_eps=True))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# Hidden layers
for _ in range(num_layers - 1):
mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.convs.append(GINConv(mlp, train_eps=True))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# Graph-level readout
self.fc1 = nn.Linear(hidden_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, num_classes)
def forward(self, x, edge_index, batch):
"""
Args:
x: Node features [N, num_features]
edge_index: Edge indices [2, E]
batch: Batch vector [N] (which graph each node belongs to)
Returns:
Graph-level predictions [batch_size, num_classes]
"""
# Node-level forward
for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
x = conv(x, edge_index)
x = bn(x)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# Graph-level readout (sum pooling)
x = global_add_pool(x, batch)
# MLP for classification
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def global_add_pool(x, batch):
"""Sum pooling over nodes in each graph."""
size = int(batch.max().item()) + 1
out = torch.zeros(size, x.size(1), device=x.device)
return out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
# ============================================================================
# 5. Pooling Layers
# ============================================================================
class TopKPooling(nn.Module):
"""
Top-K pooling layer.
Selects top-k nodes based on learned projection scores.
Args:
in_channels: Input feature dimension
ratio: Pooling ratio (fraction of nodes to keep)
min_score: Minimum score threshold
"""
def __init__(self, in_channels, ratio=0.5, min_score=None):
super(TopKPooling, self).__init__()
self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.score_layer = nn.Linear(in_channels, 1)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.score_layer.weight)
nn.init.zeros_(self.score_layer.bias)
def forward(self, x, edge_index, batch=None):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
batch: Batch vector [N]
Returns:
x: Pooled features
edge_index: Pooled edges
batch: Pooled batch
perm: Selected node indices
"""
# Compute scores
score = self.score_layer(x).squeeze()
# Select top-k nodes
if batch is None:
perm = topk(score, self.ratio, self.min_score)
else:
perm = topk_per_graph(score, batch, self.ratio, self.min_score)
# Filter nodes
x = x[perm] * torch.sigmoid(score[perm]).view(-1, 1)
# Filter edges
edge_index, _ = filter_adj(edge_index, perm, num_nodes=score.size(0))
# Update batch
if batch is not None:
batch = batch[perm]
return x, edge_index, batch, perm
def topk(x, ratio, min_score=None):
"""Select top-k elements."""
num_nodes = x.size(0)
k = max(1, int(ratio * num_nodes))
_, perm = torch.topk(x, k)
return perm
def topk_per_graph(x, batch, ratio, min_score=None):
"""Select top-k elements per graph."""
perm = []
for i in range(int(batch.max()) + 1):
mask = batch == i
indices = mask.nonzero(as_tuple=False).view(-1)
scores = x[mask]
k = max(1, int(ratio * scores.size(0)))
_, local_perm = torch.topk(scores, k)
perm.append(indices[local_perm])
return torch.cat(perm, dim=0)
def filter_adj(edge_index, perm, num_nodes):
"""Filter adjacency to only include selected nodes."""
mask = edge_index.new_full((num_nodes,), -1)
mask[perm] = torch.arange(perm.size(0), device=perm.device)
row, col = edge_index
row, col = mask[row], mask[col]
mask_edges = (row >= 0) & (col >= 0)
return torch.stack([row[mask_edges], col[mask_edges]], dim=0), mask_edges
class SAGPooling(nn.Module):
"""
Self-Attention Graph Pooling (SAGPool).
Combines GNN with top-k pooling for attention-based selection.
Args:
in_channels: Input feature dimension
ratio: Pooling ratio
gnn: GNN layer for computing attention
"""
def __init__(self, in_channels, ratio=0.5, gnn=None):
super(SAGPooling, self).__init__()
self.in_channels = in_channels
self.ratio = ratio
if gnn is None:
self.gnn = GCNConv(in_channels, 1)
else:
self.gnn = gnn
def forward(self, x, edge_index, batch=None):
"""
Args:
x: Node features [N, in_channels]
edge_index: Edge indices [2, E]
batch: Batch vector [N]
Returns:
Pooled x, edge_index, batch, perm
"""
# Compute attention scores via GNN
score = self.gnn(x, edge_index).squeeze()
# Select top-k
if batch is None:
perm = topk(score, self.ratio)
else:
perm = topk_per_graph(score, batch, self.ratio)
# Gate features by attention
x = x[perm] * torch.tanh(score[perm]).view(-1, 1)
# Filter edges
edge_index, _ = filter_adj(edge_index, perm, num_nodes=score.size(0))
if batch is not None:
batch = batch[perm]
return x, edge_index, batch, perm
# ============================================================================
# 6. Demonstration: Compare GNN Architectures
# ============================================================================
def demo_gnn_comparison():
"""Compare GCN, GAT, GraphSAGE, GIN on synthetic graph."""
print("="*70)
print("GNN Architecture Comparison")
print("="*70)
# Create synthetic graph
num_nodes = 100
num_features = 32
num_classes = 7
# Random features
x = torch.randn(num_nodes, num_features)
# Random edges (Erdős–Rényi graph)
edge_prob = 0.1
edge_index = []
for i in range(num_nodes):
for j in range(i+1, num_nodes):
if torch.rand(1).item() < edge_prob:
edge_index.append([i, j])
edge_index.append([j, i])
edge_index = torch.tensor(edge_index).t().contiguous()
num_edges = edge_index.size(1)
print(f"Graph: {num_nodes} nodes, {num_edges} edges")
print(f"Features: {num_features}D, Classes: {num_classes}")
print()
# Initialize models
models = {
'GCN': GCN(num_features, 64, num_classes, num_layers=2),
'GAT': GAT(num_features, 8, num_classes, num_layers=2, heads=8),
'GraphSAGE': GraphSAGE(num_features, 64, num_classes, num_layers=2),
# GIN requires batch vector for graph classification
}
# Compare
print(f"{'Model':<15} {'Parameters':<12} {'Output Shape':<15}")
print("-"*70)
for name, model in models.items():
model.eval()
with torch.no_grad():
out = model(x, edge_index)
num_params = sum(p.numel() for p in model.parameters())
print(f"{name:<15} {num_params:<12,} {str(tuple(out.shape)):<15}")
print()
# GIN for graph classification
print("GIN (Graph Classification):")
batch = torch.zeros(num_nodes, dtype=torch.long) # Single graph
gin = GIN(num_features, 64, num_classes, num_layers=3)
gin.eval()
with torch.no_grad():
out_gin = gin(x, edge_index, batch)
gin_params = sum(p.numel() for p in gin.parameters())
print(f" Parameters: {gin_params:,}")
print(f" Output shape: {tuple(out_gin.shape)} (graph-level)")
print()
def demo_pooling():
"""Demonstrate graph pooling methods."""
print("="*70)
print("Graph Pooling Demonstration")
print("="*70)
# Create graph
num_nodes = 20
in_channels = 16
x = torch.randn(num_nodes, in_channels)
edge_index = torch.tensor([
[0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9],
[1, 0, 2, 1, 3, 2, 4, 3, 6, 5, 7, 6, 8, 7, 9, 8]
])
print(f"Original graph: {num_nodes} nodes, {edge_index.size(1)} edges")
print()
# Top-K pooling
topk_pool = TopKPooling(in_channels, ratio=0.5)
x_topk, edge_topk, _, perm_topk = topk_pool(x, edge_index)
print(f"Top-K Pooling (ratio=0.5):")
print(f" Nodes: {num_nodes} → {x_topk.size(0)}")
print(f" Edges: {edge_index.size(1)} → {edge_topk.size(1)}")
print(f" Selected nodes: {perm_topk.tolist()[:10]}...")
print()
# SAG pooling
sag_pool = SAGPooling(in_channels, ratio=0.5)
x_sag, edge_sag, _, perm_sag = sag_pool(x, edge_index)
print(f"SAG Pooling (ratio=0.5):")
print(f" Nodes: {num_nodes} → {x_sag.size(0)}")
print(f" Edges: {edge_index.size(1)} → {edge_sag.size(1)}")
print(f" Selected nodes: {perm_sag.tolist()[:10]}...")
print()
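The Top-K layer demonstrated above follows the usual top-k pooling scheme: score each node by projecting its features onto a learnable vector, keep the ceil(ratio * N) highest-scoring nodes, and gate the kept features with the squashed scores so the score vector receives gradients. A minimal sketch of just that selection step (the `topk_select` helper and `p` vector are illustrative, independent of the TopKPooling class used here):

```python
import math
import torch

def topk_select(x: torch.Tensor, p: torch.Tensor, ratio: float):
    """Keep the ceil(ratio * N) nodes with the largest projection score."""
    score = (x @ p) / p.norm()                     # [N] scalar score per node
    k = max(1, math.ceil(ratio * x.size(0)))
    perm = score.topk(k).indices                   # indices of retained nodes
    # Gate kept features by tanh(score) so selection stays differentiable w.r.t. p.
    x_out = x[perm] * torch.tanh(score[perm]).unsqueeze(-1)
    return x_out, perm

x = torch.randn(20, 16)
p = torch.randn(16)
x_out, perm = topk_select(x, p, ratio=0.5)
print(x_out.shape, perm.shape)  # torch.Size([10, 16]) torch.Size([10])
```

SAG pooling differs only in how the score is computed: a small GNN over the graph replaces the single projection vector, so scores depend on neighborhood structure, not just node features.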
def demo_complexity_analysis():
"""Analyze computational complexity of different GNN layers."""
print("="*70)
print("Computational Complexity Analysis")
print("="*70)
configs = [
(100, 1000, 32, 64), # Small graph
(1000, 10000, 64, 128), # Medium graph
(10000, 100000, 128, 256) # Large graph
]
print(f"{'Graph Size':<20} {'GCN':<15} {'GAT (8 heads)':<20} {'GraphSAGE':<15}")
print("-"*70)
for num_nodes, num_edges, in_dim, out_dim in configs:
# GCN: O(E * in_dim * out_dim)
gcn_flops = num_edges * in_dim * out_dim
# GAT: O(E * heads * out_dim * (in_dim + 2*out_dim))
heads = 8
gat_flops = num_edges * heads * (in_dim * out_dim + 2 * out_dim * out_dim)
# GraphSAGE: Similar to GCN but with separate root transform
sage_flops = num_edges * in_dim * out_dim + num_nodes * in_dim * out_dim
graph_desc = f"N={num_nodes}, E={num_edges}"
print(f"{graph_desc:<20} {gcn_flops/1e6:>10.1f}M {gat_flops/1e6:>15.1f}M {sage_flops/1e6:>12.1f}M")
print()
print("Note: FLOPs are approximate (matrix multiplications only)")
print()
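The O(E * d_in * d_out) estimates above rest on the fact that sparse aggregation touches each edge exactly once. A sketch of that edge-wise sum aggregation with `index_add_` (a standalone helper, not the layer classes in this file):

```python
import torch

def aggregate_sum(h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Sum-aggregate neighbor features: out[v] = sum of h[u] over edges (u, v).
    One gather and one add per edge, so the cost is O(E * d)."""
    src, dst = edge_index
    out = torch.zeros_like(h)
    out.index_add_(0, dst, h[src])
    return out

# Tiny path graph 0-1-2, both directions
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
h = torch.eye(3)  # one-hot features make the aggregation easy to read off
print(aggregate_sum(h, edge_index))
```

With one-hot inputs, row v of the output is simply the indicator of N(v): node 1 (neighbors 0 and 2) gets [1, 0, 1].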
# ============================================================================
# 7. When to Use Which GNN?
# ============================================================================
def print_method_comparison():
"""Print comprehensive comparison of GNN methods."""
print("="*70)
print("GNN Method Comparison and Selection Guide")
print("="*70)
print()
comparison = """
┌───────────┬────────────────┬─────────────┬───────────┬──────────────┐
│ Method    │ Expressiveness │ Scalability │ Inductive │ Best For     │
├───────────┼────────────────┼─────────────┼───────────┼──────────────┤
│ GCN       │ Medium         │ High        │ No        │ Fast         │
│           │                │             │           │ baseline     │
├───────────┼────────────────┼─────────────┼───────────┼──────────────┤
│ GAT       │ High           │ Medium      │ Yes       │ Heterophily, │
│           │                │             │           │ interpret.   │
├───────────┼────────────────┼─────────────┼───────────┼──────────────┤
│ GraphSAGE │ Medium         │ Very High   │ Yes       │ Large-scale, │
│           │                │             │           │ dynamic      │
├───────────┼────────────────┼─────────────┼───────────┼──────────────┤
│ GIN       │ Highest        │ Medium      │ Yes       │ Graph class, │
│           │                │             │           │ chemistry    │
└───────────┴────────────────┴─────────────┴───────────┴──────────────┘
**Decision Guide:**
1. **Use GCN if:**
- Need fast baseline
- Graph fits in memory
- Node classification on homophilic graphs
- Simplicity preferred
2. **Use GAT if:**
- Need interpretability (attention weights)
- Graph has heterophily
- Different neighbor importance
- Smaller graphs (<100K nodes)
3. **Use GraphSAGE if:**
- Large graphs (>1M nodes)
- Inductive learning (new nodes)
- Production systems
- Limited memory
4. **Use GIN if:**
- Graph classification
- Maximum expressiveness needed
- Chemical/molecular graphs
- Small to medium graphs
**Performance Expectations:**
Node Classification (Cora):
- GCN: ~81% accuracy, 0.1s training
- GAT: ~83% accuracy, 0.5s training
- GraphSAGE: ~80% accuracy, 0.2s training
Graph Classification (MUTAG):
- GIN: ~89% accuracy (SOTA among MPNNs)
- GCN: ~85% accuracy
- GraphSAGE: ~84% accuracy
Link Prediction (Cora):
- GAT: Best (attention captures relationships)
- GCN: Good baseline
- GraphSAGE: Scalable option
"""
print(comparison)
print()
# ============================================================================
# Run Demonstrations
# ============================================================================
if __name__ == "__main__":
# Set random seed
torch.manual_seed(42)
# Run demos
demo_gnn_comparison()
demo_pooling()
demo_complexity_analysis()
print_method_comparison()
print("="*70)
print("GNN Implementations Complete")
print("="*70)
print()
print("Summary:")
print(" • GCN: Fast, simple, spectral motivation")
print(" • GAT: Attention-based, interpretable")
print(" • GraphSAGE: Scalable, inductive, sampling-based")
print(" • GIN: Maximally expressive (within WL limit)")
print(" • Pooling: Top-K and SAG for graph-level tasks")
print()
print("All implementations follow message-passing framework:")
print(" h_v^(k+1) = UPDATE(h_v^(k), AGG({h_u^(k) : u ∈ N(v)}))")
print()
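The closing formula can be written out as one generic message-passing step: mean-aggregate the neighbors, then apply a shared update to the concatenation [h_v || agg_v]. A minimal sketch (the single linear UPDATE and mean AGG are illustrative choices; concrete models substitute their own):

```python
import torch
import torch.nn as nn

class MessagePassingStep(nn.Module):
    """h_v^(k+1) = UPDATE(h_v^(k), AGG({h_u^(k) : u in N(v)})) with mean AGG."""
    def __init__(self, dim: int):
        super().__init__()
        self.update = nn.Linear(2 * dim, dim)  # UPDATE acts on [h_v || agg_v]

    def forward(self, h: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum neighbor features into each destination node, then divide by degree.
        agg = torch.zeros_like(h).index_add_(0, dst, h[src])
        deg = torch.zeros(h.size(0)).index_add_(0, dst, torch.ones_like(dst, dtype=h.dtype))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)  # mean AGG; isolated nodes keep zeros
        return torch.relu(self.update(torch.cat([h, agg], dim=-1)))

step = MessagePassingStep(8)
h = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
print(step(h, edge_index).shape)  # torch.Size([5, 8])
```

Every model in this file is a special case of this template: GCN fixes AGG to a degree-normalized sum, GAT makes the per-edge weights learned, GraphSAGE samples the neighborhood before aggregating, and GIN uses a sum plus an MLP update.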