import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using: {device}")

1. Message Passing Framework

Graph Representation

\(G = (V, E)\) with node features \(X \in \mathbb{R}^{n \times d}\)

Adjacency matrix \(A \in \{0,1\}^{n \times n}\)

Message Passing

\[h_v^{(l+1)} = \text{UPDATE}^{(l)}\left(h_v^{(l)}, \text{AGGREGATE}^{(l)}(\{h_u^{(l)} : u \in \mathcal{N}(v)\})\right)\]

Aggregate: Combine neighbor features

Update: Combine with own features
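As a concrete illustration, one aggregate/update round can be written in a few lines of PyTorch. The 3-node path graph and the additive UPDATE rule below are toy choices for the sketch, not a specific published layer:

```python
import torch

# One message-passing round on a 3-node path graph 0-1-2.
# AGGREGATE = mean over neighbors, UPDATE = own features + aggregated message.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])

deg = A.sum(1, keepdim=True)   # node degrees (no isolated nodes here)
agg = (A @ H) / deg            # AGGREGATE: mean of neighbor features
H_next = H + agg               # UPDATE: combine with own features
print(H_next)
```

Note that `agg` is a pure function of the neighbor multiset, which is what makes the whole layer permutation equivariant.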


2. Graph Convolutional Network (GCN)

Spectral Motivation

Laplacian: \(L = D - A\)

Normalized: \(\tilde{L} = I - D^{-1/2}AD^{-1/2}\)

GCN Layer

\[H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(l)}W^{(l)}\right)\]

where \(\tilde{A} = A + I\) (add self-loops), \(\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\)

Advanced Message Passing Theory

1. Theoretical Foundations

Permutation Invariance Property:

Graph neural networks must satisfy permutation invariance: swapping node indices shouldn't change the output structure.

For node embeddings: \(f(PX, PAP^T) = Pf(X, A)\) where \(P\) is a permutation matrix.

For graph-level predictions: \(f(PX, PAP^T) = f(X, A)\) (permutation-equivariant node embeddings → invariant graph representation)

Proof: Message passing with symmetric aggregation (sum, mean, max) preserves this property:

  • Aggregation: \(\text{AGG}(\{h_u : u \in \mathcal{N}(v)\})\) is order-independent

  • Update: Applied identically to all nodes

  • Therefore: \(h_{Pv}^{(l+1)} = \text{UPDATE}(h_{Pv}^{(l)}, \text{AGG}(\{h_{Pu}^{(l)}\}))\) for any permutation \(P\)
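This can be checked numerically with a parameter-free layer (a minimal sketch; `gcn_prop` is an illustrative mean-aggregation rule chosen so the check is exact):

```python
import torch

torch.manual_seed(0)
n, d = 6, 4
X = torch.randn(n, d)
A = (torch.rand(n, n) > 0.5).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)

def gcn_prop(X, A):
    # Parameter-free mean aggregation with self-loops: D^{-1} (A + I) X
    A_hat = A + torch.eye(A.size(0))
    D_inv = torch.diag(1.0 / A_hat.sum(1))
    return D_inv @ A_hat @ X

P = torch.eye(n)[torch.randperm(n)]    # random permutation matrix
lhs = gcn_prop(P @ X, P @ A @ P.T)     # f(PX, PAP^T)
rhs = P @ gcn_prop(X, A)               # P f(X, A)
print(torch.allclose(lhs, rhs, atol=1e-6))
```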

Weisfeiler-Lehman Isomorphism Test:

The expressive power of message passing GNNs is bounded by the 1-WL test:

  1. Initialize: \(h_v^{(0)} = x_v\) (node features)

  2. Iterate: \(h_v^{(l+1)} = \text{HASH}(h_v^{(l)}, \{\{h_u^{(l)} : u \in \mathcal{N}(v)\}\})\)

  3. Two graphs are 1-WL equivalent if they produce the same multiset of node labels

Limitations: GNNs cannot distinguish certain graph structures (e.g., regular graphs with same degree).

Higher-order GNNs: the k-WL test uses k-tuples of nodes → more expressive but computationally expensive.
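The classic failure case is easy to reproduce: a 6-cycle and two disjoint triangles are both 2-regular, so 1-WL refinement (a toy sketch below using Python's built-in `hash` and uniform initial colors) assigns them identical color multisets:

```python
from collections import Counter

def wl_colors(adj, n_iter=3):
    # 1-WL refinement: re-color each node by hashing its own color together
    # with the multiset of its neighbors' colors.
    colors = {v: 0 for v in adj}
    for _ in range(n_iter):
        colors = {v: hash((colors[v], tuple(sorted(colors[u] for u in adj[v]))))
                  for v in adj}
    return Counter(colors.values())    # multiset of final colors

six_cycle = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}

# Both graphs are 2-regular, so 1-WL (and hence plain message-passing GNNs
# without extra features) cannot tell them apart.
print(wl_colors(six_cycle) == wl_colors(two_triangles))
```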

2. Spectral Graph Theory Foundations

Graph Laplacian:

Combinatorial Laplacian: \(L = D - A\)

Properties:

  • Symmetric positive semi-definite: \(x^T L x = \frac{1}{2}\sum_{i,j} A_{ij}(x_i - x_j)^2 \geq 0\)

  • Eigenvalues: \(0 = \lambda_1 \leq \lambda_2 \leq ... \leq \lambda_n \leq 2d_{\max}\)

  • Smallest eigenvalue \(\lambda_1 = 0\) with eigenvector \(\mathbf{1}\) (constant)

  • Number of zero eigenvalues = number of connected components
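The last property is easy to verify numerically (a small sketch; the graph below is two disjoint triangles, so exactly two zero eigenvalues are expected):

```python
import torch

# Two disjoint triangles: 2 connected components.
A = torch.zeros(6, 6)
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5)]:
    A[i, j] = A[j, i] = 1.0

L = torch.diag(A.sum(1)) - A               # combinatorial Laplacian L = D - A
eigvals = torch.linalg.eigvalsh(L)
n_zero = int((eigvals.abs() < 1e-6).sum())
print(f"eigenvalues: {eigvals.tolist()}")
print(f"zero eigenvalues: {n_zero}")       # = number of components
```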

Normalized Laplacian:

\[\mathcal{L} = I - D^{-1/2}AD^{-1/2} = D^{-1/2}LD^{-1/2}\]

Eigenvalues: \(0 \leq \lambda_i \leq 2\) (better numerical properties)

Graph Fourier Transform:

For signal \(x \in \mathbb{R}^n\) on graph nodes:

  • Forward: \(\hat{x} = U^T x\) where \(U\) contains eigenvectors of \(L\)

  • Inverse: \(x = U\hat{x}\)

Graph convolution in spectral domain: \(g *_G x = U(U^T g \odot U^T x)\)

where \(\odot\) is element-wise product.

Problem: Computing eigenvectors is \(O(n^3)\), impractical for large graphs.
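On small graphs the transform pair can still be computed with a dense eigendecomposition, which at least verifies the definitions (a sketch on a random 8-node graph):

```python
import torch

torch.manual_seed(0)
n = 8
A = (torch.rand(n, n) > 0.6).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)
L = torch.diag(A.sum(1)) - A

# L is symmetric, so eigh returns an orthonormal eigenvector basis U
lam, U = torch.linalg.eigh(L)

x = torch.randn(n)        # a signal on the nodes
x_hat = U.T @ x           # forward graph Fourier transform
x_rec = U @ x_hat         # inverse transform recovers the signal exactly
print(torch.allclose(x, x_rec, atol=1e-5))
```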

3. Spectral Convolution Approximations

ChebNet (Defferrard et al., 2016):

Approximate spectral filters using Chebyshev polynomials:

\[g_\theta(\Lambda) \approx \sum_{k=0}^{K-1} \theta_k T_k(\tilde{\Lambda})\]

where:

  • \(\Lambda\): diagonal matrix of eigenvalues

  • \(T_k\): Chebyshev polynomial of order \(k\)

  • \(\tilde{\Lambda} = \frac{2}{\lambda_{\max}}\Lambda - I\) (rescaled to \([-1, 1]\))

Recursive computation: \(T_k(x) = 2xT_{k-1}(x) - T_{k-2}(x)\) with \(T_0(x)=1, T_1(x)=x\)

This avoids eigendecomposition! Apply to Laplacian directly:

\[g_\theta * x = \sum_{k=0}^{K-1} \theta_k T_k(\tilde{L})x\]

where \(\tilde{L} = \frac{2}{\lambda_{\max}}L - I\)

Complexity: \(O(K|E|)\) where \(K\) is filter order, \(|E|\) is number of edges.

Localization: \(K\)-order filter uses \(K\)-hop neighborhood (spatially localized).
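The recursion can be sanity-checked against the closed form \(T_k(\cos t) = \cos(kt)\) (a scalar sketch; the same recursion applied to \(\tilde{L}\) gives the matrix basis used by ChebNet):

```python
import math

def cheb(K, x):
    # T_0 = 1, T_1 = x, T_k = 2x T_{k-1} - T_{k-2}
    T = [1.0, x]
    for _ in range(2, K + 1):
        T.append(2 * x * T[-1] - T[-2])
    return T[:K + 1]

t = 0.7
recursive = cheb(4, math.cos(t))                  # T_0..T_4 at cos(t)
closed_form = [math.cos(k * t) for k in range(5)] # cos(k t)
print([round(a - b, 12) for a, b in zip(recursive, closed_form)])
```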

GCN as First-Order Approximation:

Kipf & Welling (2017) showed GCN is a special case with \(K=1, \lambda_{\max}=2\):

\[g_\theta * x \approx \theta_0 x + \theta_1(L - I)x = \theta_0 x - \theta_1 D^{-1/2}AD^{-1/2}x\]

Simplification: Set \(\theta = \theta_0 = -\theta_1\) to reduce parameters:

\[g_\theta * x = \theta(I + D^{-1/2}AD^{-1/2})x\]

Renormalization trick (add self-loops):

\[\tilde{A} = A + I, \quad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}\]
\[H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(l)}W^{(l)}\right)\]

Why it works:

  • Spectral range: the unrenormalized filter \(I + D^{-1/2}AD^{-1/2}\) has eigenvalues in \([0, 2]\), which destabilizes deep stacks; the renormalized \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) has eigenvalues in \((-1, 1]\)

  • Prevents numerical instability from repeated multiplication

  • Self-loops preserve node's own features
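The two spectra can be compared numerically (a sketch on a random graph; the `clamp(min=1)` only guards against isolated nodes):

```python
import torch

torch.manual_seed(0)
n = 10
A = (torch.rand(n, n) > 0.5).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)

# Before renormalization: I + D^{-1/2} A D^{-1/2}, eigenvalues up to 2
d = A.sum(1)
D_is = torch.diag(d.clamp(min=1).pow(-0.5))
M1 = torch.eye(n) + D_is @ A @ D_is

# After renormalization: D~^{-1/2} (A + I) D~^{-1/2}, eigenvalues in (-1, 1]
A_t = A + torch.eye(n)
Dt_is = torch.diag(A_t.sum(1).pow(-0.5))
M2 = Dt_is @ A_t @ Dt_is

e1 = torch.linalg.eigvalsh(M1)
e2 = torch.linalg.eigvalsh(M2)
print(f"I + D^-1/2 A D^-1/2 spectrum: [{e1.min():.3f}, {e1.max():.3f}]")
print(f"renormalized spectrum:        [{e2.min():.3f}, {e2.max():.3f}]")
```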

4. GraphSAGE: Inductive Learning

Problem with GCN: Transductive (requires full graph during training).

GraphSAGE Solution (Hamilton et al., 2017):

Sample and aggregate from fixed-size neighborhoods:

\[h_v^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(h_v^{(l)}, \text{AGG}(\{h_u^{(l)} : u \in \mathcal{S}(\mathcal{N}(v))\})\right)\right)\]

where \(\mathcal{S}(\mathcal{N}(v))\) is a sampled subset of neighbors.

Aggregation Functions:

  1. Mean: \(\text{AGG} = \frac{1}{|\mathcal{S}|}\sum_{u \in \mathcal{S}} h_u^{(l)}\)

    • Simple, symmetric

    • GCN is special case with full neighborhood

  2. LSTM: \(\text{AGG} = \text{LSTM}(\{h_u^{(l)}\})\)

    • More expressive

    • Requires random permutation (not truly permutation invariant)

  3. Pooling: \(\text{AGG} = \max(\{\sigma(W_{\text{pool}}h_u^{(l)} + b) : u \in \mathcal{S}\})\)

    • Element-wise max after MLP

    • Captures diverse set of features

Neighbor Sampling:

For \(L\)-layer model, uniformly sample \(S_l\) neighbors at each layer:

  • Computational complexity: \(O(S_1 \cdot S_2 \cdots S_L)\) per node

  • Fixed computational budget (vs. \(O(d^L)\) for full neighborhood)

Training:

Unsupervised: Graph-based loss encourages nearby nodes to have similar embeddings:

\[J(z_u) = -\log(\sigma(z_u^T z_v)) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log(\sigma(-z_u^T z_{v_n}))\]

where \(v\) is node co-occurring on random walk, \(v_n\) is negative sample.

Supervised: Standard cross-entropy for node classification.
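The unsupervised objective is only a few lines (a sketch with toy embeddings; the positive pair and the uniform negative draws below stand in for random-walk co-occurrence and the noise distribution \(P_n\)):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = F.normalize(torch.randn(20, 8), dim=1)   # toy node embeddings

def sage_unsup_loss(z, u, v, neg_idx, Q=1.0):
    # J(z_u) = -log sigma(z_u . z_v) - Q * E_neg[log sigma(-z_u . z_neg)]
    pos = -F.logsigmoid(z[u] @ z[v])
    neg = -F.logsigmoid(-(z[neg_idx] @ z[u])).mean()
    return pos + Q * neg

# Node 1 stands in for a random-walk co-occurrence of node 0;
# 5 negatives are drawn uniformly over the nodes.
loss = sage_unsup_loss(z, u=0, v=1, neg_idx=torch.randint(0, 20, (5,)))
print(f"loss: {loss.item():.4f}")
```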

5. Graph Attention Networks (GAT)

Motivation: Not all neighbors are equally important.

Attention Mechanism (Veličković et al., 2018):

Compute attention coefficients:

\[e_{ij} = a(W\mathbf{h}_i, W\mathbf{h}_j)\]

Typically: \(a([Wh_i \| Wh_j]) = \text{LeakyReLU}(\mathbf{a}^T [Wh_i \| Wh_j])\)

where \(\mathbf{a} \in \mathbb{R}^{2F'}\) is learnable attention vector.

Normalization:

\[\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp(e_{ik})}\]

Aggregation:

\[h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \alpha_{ij} W h_j\right)\]

Multi-Head Attention:

Concatenate \(K\) attention heads:

\[h_i' = \|_{k=1}^K \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k h_j\right)\]

For final layer, average instead:

\[h_i' = \sigma\left(\frac{1}{K}\sum_{k=1}^K \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k h_j\right)\]

Properties:

  • Computationally efficient: \(O(|V|FF' + |E|F')\) per layer

  • Applicable to different graph structures (no normalization by degree needed)

  • Implicitly specifies different weights to different neighbors

  • Can be applied inductively to unseen graphs

Comparison to Self-Attention in Transformers:

  • GAT: Attention over graph neighbors (sparse)

  • Transformer: Attention over all positions (dense)

  • GAT complexity: \(O(|E|)\) vs. Transformer: \(O(n^2)\)

6. Graph Pooling Strategies

Problem: Generate graph-level representation from node embeddings.

1. Global Pooling:

Simple aggregation over all nodes:

  • Mean: \(h_G = \frac{1}{|V|}\sum_{v \in V} h_v\)

  • Max: \(h_G = \max_{v \in V} h_v\) (element-wise)

  • Sum: \(h_G = \sum_{v \in V} h_v\)

Permutation invariant but loses structural information.
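Both claims are easy to check (a sketch with random node embeddings; note the adjacency matrix never enters the readout, which is exactly the lost structural information):

```python
import torch

torch.manual_seed(0)
H = torch.randn(7, 4)          # node embeddings for one graph
perm = torch.randperm(7)       # arbitrary relabeling of the nodes

readouts = {
    'sum': lambda M: M.sum(0),
    'mean': lambda M: M.mean(0),
    'max': lambda M: M.max(0).values,
}
# Every readout is unchanged under relabeling (permutation invariance).
same = all(torch.allclose(f(H), f(H[perm]), atol=1e-6)
           for f in readouts.values())
print(same)
```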

2. Set2Set (Vinyals et al., 2015):

Use LSTM with attention to aggregate:

\[q_t^* = \text{LSTM}(q_{t-1}^*, m_{t-1})\]
\[e_{i,t} = f(h_i, q_t^*), \quad \alpha_{i,t} = \text{softmax}(e_{i,t})\]
\[m_t = \sum_i \alpha_{i,t} h_i\]

Read graph multiple times, final state is graph embedding.

3. DiffPool (Ying et al., 2018):

Learnable hierarchical pooling:

\[S^{(l)} = \text{softmax}(\text{GNN}_{pool}^{(l)}(A^{(l)}, X^{(l)})) \in \mathbb{R}^{n_l \times n_{l+1}}\]

Cluster assignment matrix \(S^{(l)}\): each row assigns node to one of \(n_{l+1}\) clusters.

Coarsened graph:

\[X^{(l+1)} = S^{(l)T} Z^{(l)}\]
\[A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)}\]

where \(Z^{(l)} = \text{GNN}_{embed}^{(l)}(A^{(l)}, X^{(l)})\) are node embeddings.

Auxiliary losses:

  • Link prediction: \(L_{LP} = \|A - SS^T\|_F\) (preserve graph structure)

  • Entropy: \(L_E = -\frac{1}{n}\sum_{i,j} S_{ij}\log(S_{ij})\) (encourage confident assignments)

4. TopK Pooling (Gao & Ji, 2019):

Score-based selection:

\[y = \frac{Xp}{\|p\|}\]
\[\text{idx} = \text{top-k}(y), \quad \tilde{X} = X_{\text{idx}} \odot \sigma(y_{\text{idx}})\]
\[\tilde{A} = A_{\text{idx}, \text{idx}}\]

Keep top \(k\) nodes based on learned projection \(p\), gate features by score.
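A minimal sketch of the selection step, with a random vector standing in for the learned projection \(p\):

```python
import torch

torch.manual_seed(0)
N, F_dim, k = 8, 5, 3
X = torch.randn(N, F_dim)
A = (torch.rand(N, N) > 0.5).float()
A = ((A + A.T) > 0).float()

p = torch.randn(F_dim)                  # stands in for the learned projection
y = (X @ p) / p.norm()                  # scalar score per node
idx = y.topk(k).indices                 # indices of the k highest scores

X_pool = X[idx] * torch.sigmoid(y[idx]).unsqueeze(1)  # gate kept features by score
A_pool = A[idx][:, idx]                               # induced subgraph
print(X_pool.shape, A_pool.shape)
```

Gating by \(\sigma(y)\) is what makes the (non-differentiable) top-k selection still pass gradients to \(p\).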

7. Training Techniques and Considerations

Over-smoothing Problem:

With many layers, node embeddings become indistinguishable:

\[\lim_{l \to \infty} h_i^{(l)} = c \quad \forall i\]

Causes:

  • Repeated averaging of neighbor features

  • The propagation matrix has dominant eigenvalue 1 (the Laplacian's zero eigenvalue), so repeated multiplication collapses features onto the corresponding eigenvector

Solutions:

  1. Residual connections: \(h^{(l+1)} = h^{(l)} + \text{GNN}^{(l)}(h^{(l)}, A)\)

  2. Jumping knowledge: Concatenate representations from all layers

  3. PairNorm: Normalize to maintain distance between nodes

  4. Depth-adaptive: Learn to stop propagation dynamically
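The collapse itself is easy to demonstrate by applying the GCN propagation matrix repeatedly with no weights or nonlinearities (a sketch on a random dense graph):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 12
A = (torch.rand(n, n) > 0.4).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)

A_hat = A + torch.eye(n)
D_is = torch.diag(A_hat.sum(1).pow(-0.5))
S = D_is @ A_hat @ D_is                 # GCN propagation matrix

H = torch.randn(n, 4)
cos_before = F.cosine_similarity(H[0], H[1], dim=0).item()
for _ in range(50):
    H = S @ H                           # propagation only: no weights, no nonlinearity
cos_after = F.cosine_similarity(H[0], H[1], dim=0).item()

# All rows align with the dominant eigenvector (proportional to sqrt(degree)),
# so the two nodes' embeddings become parallel, i.e. indistinguishable in direction.
print(f"cosine(node 0, node 1): before={cos_before:.3f}, after={cos_after:.6f}")
```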

Graph Normalization:

  • BatchNorm: Normalize across batch, doesn't work well for graphs

  • GraphNorm: Normalize each graph separately: \(h_i = \frac{h_i - \mu_G}{\sqrt{\sigma_G^2 + \epsilon}}\)

  • Layer normalization: Normalize each node's features

Mini-batch Training:

Neighbor sampling (GraphSAGE):

  • Sample fixed number of neighbors per layer

  • Trade-off: accuracy vs. computational cost

Cluster-GCN:

  • Partition graph into clusters

  • Sample cluster as mini-batch (preserve local structure)

8. Comparison Table

| Method    | Aggregation     | Complexity                 | Inductive | Attention | Year |
|-----------|-----------------|----------------------------|-----------|-----------|------|
| GCN       | Mean (spectral) | \(O(\lvert E\rvert F)\)    | ❌        | ❌        | 2017 |
| GraphSAGE | Mean/Max/LSTM   | \(O(S^L F)\)               | ✅        | ❌        | 2017 |
| GAT       | Weighted sum    | \(O(\lvert E\rvert F)\)    | ✅        | ✅        | 2018 |
| GIN       | Sum             | \(O(\lvert E\rvert F)\)    | ✅        | ❌        | 2019 |
| ChebNet   | Chebyshev poly. | \(O(K\lvert E\rvert F)\)   | ❌        | ❌        | 2016 |

When to use:

  • GCN: Small-medium graphs, transductive setting, baseline

  • GraphSAGE: Large graphs, inductive learning, scalability

  • GAT: Heterogeneous graphs, need interpretability

  • GIN: Maximum expressiveness (1-WL equivalent)

  • Spectral methods: Regular graph structure, frequency analysis

9. Applications and Benchmarks

Node Classification:

  • Cora, CiteSeer, PubMed (citation networks)

  • Reddit (social network)

  • PPI (protein-protein interaction)

Graph Classification:

  • MUTAG, PROTEINS (molecular properties)

  • COLLAB, IMDB (social networks)

  • NCI1 (chemical compounds)

Link Prediction:

  • Knowledge graphs (FB15k, WN18)

  • Recommendation systems

Graph Generation:

  • Molecule design (ZINC, QM9)

  • Social network synthesis

Common challenges:

  • Label scarcity (semi-supervised learning)

  • Heterogeneous graphs (multiple node/edge types)

  • Dynamic graphs (temporal evolution)

  • Scaling to billions of nodes

class GCNLayer(nn.Module):
    """Graph Convolutional Layer."""
    
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, A):
        # Add self-loops
        A_hat = A + torch.eye(A.size(0), device=A.device)
        
        # Degree matrix
        D_hat = torch.diag(A_hat.sum(1))
        D_hat_inv_sqrt = torch.pow(D_hat, -0.5)
        D_hat_inv_sqrt[torch.isinf(D_hat_inv_sqrt)] = 0
        
        # Normalized adjacency
        A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt
        
        # Propagate
        return self.linear(A_norm @ X)

class GCN(nn.Module):
    """2-layer GCN."""
    
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super().__init__()
        self.gc1 = GCNLayer(in_dim, hidden_dim)
        self.gc2 = GCNLayer(hidden_dim, out_dim)
        self.dropout = dropout
    
    def forward(self, X, A):
        h = F.relu(self.gc1(X, A))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.gc2(h, A)

# Test
n_nodes = 10
in_dim = 5
hidden_dim = 16
out_dim = 3

X = torch.randn(n_nodes, in_dim)
A = torch.randint(0, 2, (n_nodes, n_nodes)).float()
A = ((A + A.T) > 0).float()  # Symmetrize, keeping entries binary
A.fill_diagonal_(0)  # Self-loops are added inside the layer

model = GCN(in_dim, hidden_dim, out_dim)
out = model(X, A)
print(f"Output shape: {out.shape}")

3. Graph Attention Network (GAT)

Attention Mechanism

Compute attention coefficients:

\[e_{ij} = \text{LeakyReLU}(a^T[Wh_i \| Wh_j])\]
\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}\]

Update

\[h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}Wh_j\right)\]

Multi-head: Average or concatenate multiple attention heads.

class GATLayer(nn.Module):
    """Graph Attention Layer."""
    
    def __init__(self, in_features, out_features, n_heads=1, dropout=0.6, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_heads = n_heads
        self.dropout = dropout
        
        self.W = nn.Parameter(torch.zeros(n_heads, in_features, out_features))
        self.a = nn.Parameter(torch.zeros(n_heads, 2 * out_features, 1))
        
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.a)
    
    def forward(self, X, A):
        N = X.size(0)
        
        # Linear transformation for each head
        h = torch.matmul(X.unsqueeze(0), self.W)  # (n_heads, N, out_features)
        
        # Attention mechanism
        h_i = h.unsqueeze(2).repeat(1, 1, N, 1)  # (n_heads, N, N, out_features)
        h_j = h.unsqueeze(1).repeat(1, N, 1, 1)
        
        # Concatenate
        h_cat = torch.cat([h_i, h_j], dim=-1)  # (n_heads, N, N, 2*out_features)
        
        # Compute attention
        e = self.leakyrelu(torch.matmul(h_cat, self.a).squeeze(-1))  # (n_heads, N, N)
        
        # Mask non-neighbors; include self-loops to match the softmax over N(i) ∪ {i}
        A_self = A + torch.eye(N, device=A.device)
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(A_self.unsqueeze(0) > 0, e, zero_vec)
        
        # Normalize
        attention = F.softmax(attention, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        
        # Aggregate
        h_prime = torch.matmul(attention, h)  # (n_heads, N, out_features)
        
        # Average heads
        return h_prime.mean(dim=0)

class GAT(nn.Module):
    """2-layer GAT."""
    
    def __init__(self, in_dim, hidden_dim, out_dim, n_heads=4, dropout=0.6):
        super().__init__()
        self.gat1 = GATLayer(in_dim, hidden_dim, n_heads, dropout)
        self.gat2 = GATLayer(hidden_dim, out_dim, 1, dropout)
    
    def forward(self, X, A):
        h = F.elu(self.gat1(X, A))
        return self.gat2(h, A)

# Test
gat = GAT(in_dim, hidden_dim, out_dim, n_heads=4)
out = gat(X, A)
print(f"GAT output: {out.shape}")
"""
Advanced GNN Implementations: Spectral Methods, GraphSAGE, and Graph Pooling
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import eigh
from collections import defaultdict

# ============================================================================
# 1. Spectral Graph Convolution (ChebNet)
# ============================================================================

class ChebConv(nn.Module):
    """
    Chebyshev Spectral Graph Convolution.
    Approximates spectral filters using Chebyshev polynomials.
    """
    def __init__(self, in_features, out_features, K=3):
        super().__init__()
        self.K = K
        self.linear = nn.Linear(in_features * K, out_features)
        
    def chebyshev_basis(self, L, K):
        """
        Compute Chebyshev polynomial basis T_k(L).
        
        Recursive: T_k(x) = 2x*T_{k-1}(x) - T_{k-2}(x)
        with T_0(x) = I, T_1(x) = x
        """
        N = L.shape[0]
        # Rescale Laplacian to [-1, 1]
        lambda_max = 2.0  # Approximation for normalized Laplacian
        L_scaled = (2.0 / lambda_max) * L - torch.eye(N, device=L.device)
        
        # Initialize
        Tx = [torch.eye(N, device=L.device), L_scaled]
        
        # Compute T_k recursively
        for k in range(2, K):
            Tx.append(2 * L_scaled @ Tx[-1] - Tx[-2])
        
        return Tx[:K]
    
    def forward(self, X, L):
        """
        Args:
            X: Node features (N, in_features)
            L: Normalized Laplacian (N, N)
        """
        # Compute Chebyshev basis
        Tx = self.chebyshev_basis(L, self.K)
        
        # Apply each basis to features
        out = []
        for k in range(self.K):
            out.append(Tx[k] @ X)
        
        # Concatenate and linear transform
        out = torch.cat(out, dim=-1)
        return self.linear(out)


def compute_normalized_laplacian(A):
    """
    Compute normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
    """
    # Degree matrix
    D = torch.diag(A.sum(1))
    D_inv_sqrt = torch.pow(D, -0.5)
    D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0
    
    # Normalized Laplacian
    I = torch.eye(A.size(0), device=A.device)
    L = I - D_inv_sqrt @ A @ D_inv_sqrt
    
    return L


# Test ChebConv
print("=" * 60)
print("1. Spectral Graph Convolution (ChebNet)")
print("=" * 60)

n_nodes = 20
in_dim = 5
out_dim = 8
K = 3

X = torch.randn(n_nodes, in_dim)
A = torch.randint(0, 2, (n_nodes, n_nodes)).float()
A = ((A + A.T) > 0).float()  # Symmetrize, keeping entries binary
A.fill_diagonal_(0)  # No self-loops

L = compute_normalized_laplacian(A)

cheb_conv = ChebConv(in_dim, out_dim, K=K)
out = cheb_conv(X, L)

print(f"Input shape: {X.shape}")
print(f"Laplacian eigenvalues (first 5): {torch.linalg.eigvalsh(L)[:5].numpy()}")
print(f"Output shape: {out.shape}")
print(f"Chebyshev order K: {K}")

# ============================================================================
# 2. GraphSAGE with Multiple Aggregators
# ============================================================================

class GraphSAGELayer(nn.Module):
    """
    GraphSAGE Layer with different aggregation functions.
    """
    def __init__(self, in_features, out_features, aggregator='mean'):
        super().__init__()
        self.aggregator = aggregator
        
        if aggregator == 'mean':
            self.linear = nn.Linear(in_features * 2, out_features)
        elif aggregator == 'pool':
            self.pool_linear = nn.Linear(in_features, in_features)
            self.linear = nn.Linear(in_features * 2, out_features)
        elif aggregator == 'lstm':
            self.lstm = nn.LSTM(in_features, in_features, batch_first=True)
            self.linear = nn.Linear(in_features * 2, out_features)
    
    def aggregate_mean(self, X, neighbors_list):
        """Mean aggregation"""
        agg = []
        for neighbors in neighbors_list:
            if len(neighbors) > 0:
                agg.append(X[neighbors].mean(0))
            else:
                agg.append(torch.zeros(X.size(1), device=X.device))
        return torch.stack(agg)
    
    def aggregate_pool(self, X, neighbors_list):
        """Max pooling aggregation after MLP"""
        agg = []
        for neighbors in neighbors_list:
            if len(neighbors) > 0:
                neighbor_features = X[neighbors]
                pooled = F.relu(self.pool_linear(neighbor_features))
                agg.append(pooled.max(0)[0])
            else:
                agg.append(torch.zeros(X.size(1), device=X.device))
        return torch.stack(agg)
    
    def aggregate_lstm(self, X, neighbors_list):
        """LSTM aggregation (order-sensitive, so neighbors are shuffled)"""
        agg = []
        for neighbors in neighbors_list:
            if len(neighbors) > 0:
                perm = torch.randperm(len(neighbors))
                neighbor_features = X[neighbors][perm].unsqueeze(0)  # (1, n_nbrs, F)
                _, (h_n, _) = self.lstm(neighbor_features)
                agg.append(h_n.squeeze(0).squeeze(0))  # drop layer and batch dims -> (F,)
            else:
                agg.append(torch.zeros(X.size(1), device=X.device))
        return torch.stack(agg)
    
    def forward(self, X, A, sample_size=None):
        """
        Args:
            X: Node features (N, in_features)
            A: Adjacency matrix (N, N)
            sample_size: Number of neighbors to sample (None = all)
        """
        N = X.size(0)
        
        # Build neighbor lists with optional sampling
        neighbors_list = []
        for i in range(N):
            neighbors = A[i].nonzero(as_tuple=True)[0].tolist()
            if sample_size and len(neighbors) > sample_size:
                neighbors = np.random.choice(neighbors, sample_size, replace=False).tolist()
            neighbors_list.append(neighbors)
        
        # Aggregate
        if self.aggregator == 'mean':
            h_agg = self.aggregate_mean(X, neighbors_list)
        elif self.aggregator == 'pool':
            h_agg = self.aggregate_pool(X, neighbors_list)
        elif self.aggregator == 'lstm':
            h_agg = self.aggregate_lstm(X, neighbors_list)
        
        # Concatenate with self features and transform
        h = torch.cat([X, h_agg], dim=1)
        return F.relu(self.linear(h))


# Test GraphSAGE aggregators
print("\n" + "=" * 60)
print("2. GraphSAGE Aggregators")
print("=" * 60)

aggregators = ['mean', 'pool', 'lstm']
for agg in aggregators:
    sage_layer = GraphSAGELayer(in_dim, out_dim, aggregator=agg)
    out = sage_layer(X, A, sample_size=5)
    print(f"{agg.upper()} aggregator output shape: {out.shape}")

# ============================================================================
# 3. Graph Pooling: DiffPool
# ============================================================================

class DiffPoolLayer(nn.Module):
    """
    Differentiable Graph Pooling (DiffPool).
    Learns soft cluster assignments.
    """
    def __init__(self, in_features, num_clusters, hidden_dim=16):
        super().__init__()
        # GNN for node embeddings
        self.embed_gnn = GCNLayer(in_features, hidden_dim)
        # GNN for cluster assignments
        self.pool_gnn = GCNLayer(in_features, num_clusters)
    
    def forward(self, X, A):
        """
        Args:
            X: Node features (N, in_features)
            A: Adjacency matrix (N, N)
        
        Returns:
            X_pooled: Cluster features (num_clusters, hidden_dim)
            A_pooled: Cluster adjacency (num_clusters, num_clusters)
            link_loss: Auxiliary loss for preserving structure
            entropy_loss: Auxiliary loss for confident assignments
        """
        # Compute embeddings
        Z = self.embed_gnn(X, A)  # (N, hidden_dim)
        
        # Compute soft assignments
        S = F.softmax(self.pool_gnn(X, A), dim=1)  # (N, num_clusters)
        
        # Coarsen graph
        X_pooled = S.T @ Z  # (num_clusters, hidden_dim)
        A_pooled = S.T @ A @ S  # (num_clusters, num_clusters)
        
        # Auxiliary losses
        # Link prediction loss: ||A - SS^T||_F
        link_loss = torch.norm(A - S @ S.T, p='fro')
        
        # Entropy loss: minimizing the mean row entropy of S pushes
        # assignments toward one-hot (confident) clusters
        entropy_loss = -(S * torch.log(S + 1e-10)).sum(1).mean()
        
        return X_pooled, A_pooled, link_loss, entropy_loss


class GCNLayer(nn.Module):
    """Basic GCN layer (reused from earlier)"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, A):
        A_hat = A + torch.eye(A.size(0), device=A.device)
        D_hat = torch.diag(A_hat.sum(1))
        D_hat_inv_sqrt = torch.pow(D_hat, -0.5)
        D_hat_inv_sqrt[torch.isinf(D_hat_inv_sqrt)] = 0
        A_norm = D_hat_inv_sqrt @ A_hat @ D_hat_inv_sqrt
        return self.linear(A_norm @ X)


# Test DiffPool
print("\n" + "=" * 60)
print("3. DiffPool Graph Pooling")
print("=" * 60)

num_clusters = 5
diffpool = DiffPoolLayer(in_dim, num_clusters, hidden_dim=8)

X_pooled, A_pooled, link_loss, entropy_loss = diffpool(X, A)

print(f"Original graph: {X.shape[0]} nodes")
print(f"Pooled graph: {X_pooled.shape[0]} clusters")
print(f"Pooled features shape: {X_pooled.shape}")
print(f"Pooled adjacency shape: {A_pooled.shape}")
print(f"Link prediction loss: {link_loss.item():.4f}")
print(f"Entropy loss: {entropy_loss.item():.4f}")

# ============================================================================
# 4. Visualization: Spectral Properties
# ============================================================================

print("\n" + "=" * 60)
print("4. Spectral Analysis Visualization")
print("=" * 60)

# Compute eigenvalues and eigenvectors
L_np = L.numpy()
eigenvalues, eigenvectors = eigh(L_np)

# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Eigenvalue spectrum
axes[0, 0].plot(eigenvalues, 'o-', markersize=6)
axes[0, 0].axhline(y=0, color='r', linestyle='--', alpha=0.3)
axes[0, 0].set_xlabel('Index')
axes[0, 0].set_ylabel('Eigenvalue')
axes[0, 0].set_title('Normalized Laplacian Eigenvalues')
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: First few eigenvectors
for i in range(min(4, len(eigenvalues))):
    axes[0, 1].plot(eigenvectors[:, i], label=f'λ={eigenvalues[i]:.3f}', alpha=0.7)
axes[0, 1].set_xlabel('Node index')
axes[0, 1].set_ylabel('Eigenvector value')
axes[0, 1].set_title('First 4 Eigenvectors')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Adjacency matrix heatmap
im1 = axes[1, 0].imshow(A.numpy(), cmap='Blues', aspect='auto')
axes[1, 0].set_title('Adjacency Matrix')
axes[1, 0].set_xlabel('Node')
axes[1, 0].set_ylabel('Node')
plt.colorbar(im1, ax=axes[1, 0])

# Plot 4: Normalized Laplacian heatmap
im2 = axes[1, 1].imshow(L.numpy(), cmap='RdBu_r', aspect='auto', vmin=-0.5, vmax=1.5)
axes[1, 1].set_title('Normalized Laplacian')
axes[1, 1].set_xlabel('Node')
axes[1, 1].set_ylabel('Node')
plt.colorbar(im2, ax=axes[1, 1])

plt.tight_layout()
plt.savefig('gnn_spectral_analysis.png', dpi=150, bbox_inches='tight')
print("Saved: gnn_spectral_analysis.png")
plt.show()

print(f"\nSpectral gap (λ2 - λ1): {eigenvalues[1] - eigenvalues[0]:.4f}")
print(f"Number of zero eigenvalues: {(np.abs(eigenvalues) < 1e-8).sum()}")
print("(Zero eigenvalues = number of connected components)")

# ============================================================================
# 5. Comparison: GCN vs GraphSAGE vs ChebNet
# ============================================================================

print("\n" + "=" * 60)
print("5. GNN Methods Comparison")
print("=" * 60)

# Create models
gcn = GCNLayer(in_dim, out_dim)
sage = GraphSAGELayer(in_dim, out_dim, aggregator='mean')
cheb = ChebConv(in_dim, out_dim, K=3)

# Forward pass
out_gcn = gcn(X, A)
out_sage = sage(X, A, sample_size=None)  # Full neighborhood
out_cheb = cheb(X, L)

# Compare outputs
print(f"\nGCN output: mean={out_gcn.mean().item():.4f}, std={out_gcn.std().item():.4f}")
print(f"GraphSAGE output: mean={out_sage.mean().item():.4f}, std={out_sage.std().item():.4f}")
print(f"ChebNet output: mean={out_cheb.mean().item():.4f}, std={out_cheb.std().item():.4f}")

# Parameter count
gcn_params = sum(p.numel() for p in gcn.parameters())
sage_params = sum(p.numel() for p in sage.parameters())
cheb_params = sum(p.numel() for p in cheb.parameters())

print(f"\nParameter count:")
print(f"  GCN: {gcn_params}")
print(f"  GraphSAGE (mean): {sage_params}")
print(f"  ChebNet (K={K}): {cheb_params}")

# Visualize output distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (name, output) in zip(axes, [('GCN', out_gcn), ('GraphSAGE', out_sage), ('ChebNet', out_cheb)]):
    ax.hist(output.detach().numpy().flatten(), bins=30, alpha=0.7, edgecolor='black')
    ax.set_xlabel('Activation value')
    ax.set_ylabel('Frequency')
    ax.set_title(f'{name} Output Distribution')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('gnn_comparison.png', dpi=150, bbox_inches='tight')
print("\nSaved: gnn_comparison.png")
plt.show()

Training on Karate Club

Zachary's Karate Club is a classic graph dataset with 34 nodes (club members) and edges representing social interactions. After a dispute, the club split into two factions, making this a natural node classification benchmark. The message-passing mechanism propagates label information through the graph structure, which in the semi-supervised setting allows accurate classification even with very sparse supervision. (For simplicity, the code below trains on all labels; masking out all but a few labels per faction turns it into that semi-supervised setting.) This is directly analogous to how GNNs are used in real applications like social network analysis, fraud detection, and molecular property prediction.

# Load Karate Club graph
G = nx.karate_club_graph()
n = len(G.nodes())

# Adjacency
A = torch.tensor(nx.to_numpy_array(G), dtype=torch.float32)

# Features: one-hot node IDs
X = torch.eye(n)

# Labels: two communities
labels = torch.tensor([G.nodes[i]['club'] == 'Mr. Hi' for i in G.nodes()], dtype=torch.long)

# Train GCN
model = GCN(n, 16, 2, dropout=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(X, A)
    loss = F.cross_entropy(out, labels)
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Evaluate
model.eval()
with torch.no_grad():
    pred = model(X, A).argmax(dim=1)
    acc = (pred == labels).float().mean()
    print(f"\nAccuracy: {acc:.3f}")
# Visualize
pos = nx.spring_layout(G, seed=42)
colors = ['red' if labels[i] == 0 else 'blue' for i in range(n)]
pred_colors = ['red' if pred[i] == 0 else 'blue' for i in range(n)]

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# True labels
nx.draw(G, pos, node_color=colors, with_labels=True, ax=axes[0], node_size=500)
axes[0].set_title('True Communities', fontsize=12)

# Predictions
nx.draw(G, pos, node_color=pred_colors, with_labels=True, ax=axes[1], node_size=500)
axes[1].set_title('GCN Predictions', fontsize=12)

plt.tight_layout()
plt.show()

Summary

Message Passing:

\[h_v^{(l+1)} = \text{UPDATE}(h_v^{(l)}, \text{AGG}(\{h_u^{(l)}\}))\]

GCN:

\[H^{(l+1)} = \sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{(l)}W)\]

Spectral convolution on graphs.

GAT:

\[h_i' = \sigma\left(\sum_j \alpha_{ij}Wh_j\right)\]

Attention weights neighbors adaptively.

Applications:

  • Node classification

  • Link prediction

  • Graph classification

  • Molecular property prediction

  • Social network analysis

Variants:

  • GraphSAGE (sampling)

  • GIN (graph isomorphism)

  • PNA (principal aggregation)

Next Steps:

  • 14_efficient_transformers.ipynb - Attention mechanisms

  • Explore graph pooling

  • Apply to molecular graphs

Advanced Graph Neural Networks TheoryΒΆ

1. Graph Representation Learning FoundationsΒΆ

1.1 Graph DefinitionΒΆ

Graph: G = (V, E) where:

  • V = {v₁, vβ‚‚, …, vβ‚™} (nodes/vertices)

  • E βŠ† V Γ— V (edges)

Types:

  • Undirected: (u, v) ∈ E ⟺ (v, u) ∈ E

  • Directed: Edges have direction

  • Weighted: Edge weights w: E → ℝ

  • Attributed: Node features X ∈ ℝ^(N×D), edge features E_feat ∈ ℝ^(|E|×D_e)

Adjacency matrix: A ∈ {0,1}^(N×N)

A[i,j] = 1 if (v_i, v_j) ∈ E, else 0

Degree matrix: D ∈ ℝ^(N×N) (diagonal)

D[i,i] = Σⱼ A[i,j]  (degree of node i)
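As a concrete sketch, A and D for a small toy graph (the edge list here is purely illustrative) can be built directly:

```python
import numpy as np

# Toy undirected graph: 4 nodes, 4 edges (illustrative)
edges = [(0, 1), (1, 2), (2, 0), (2, 3)]
N = 4

A = np.zeros((N, N), dtype=int)
for u, v in edges:
    A[u, v] = 1
    A[v, u] = 1  # undirected: A is symmetric

D = np.diag(A.sum(axis=1))  # D[i,i] = degree of node i

print(np.diag(D))  # [2 2 3 1]
```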

1.2 The Message Passing FrameworkΒΆ

Core idea: Node representations via neighborhood aggregation.

General message passing:

h_v^(k+1) = UPDATE^(k)(h_v^(k), AGGREGATE^(k)({h_u^(k) : u ∈ N(v)}))

where:

  • h_v^(k): Node v representation at layer k

  • N(v): Neighbors of v

  • AGGREGATE: Permutation-invariant function (sum, mean, max)

  • UPDATE: Combines old state with aggregated messages

Key property: Permutation invariance

AGGREGATE(π(S)) = AGGREGATE(S) for any permutation π
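This property can be checked numerically: a minimal sketch with a single linear sum-aggregation layer on a random graph (f here is illustrative, not a full GNN):

```python
import torch

torch.manual_seed(0)
N, d = 5, 3
X = torch.randn(N, d)
A = (torch.rand(N, N) < 0.4).float()
A = ((A + A.T) > 0).float()          # make symmetric (undirected)
A.fill_diagonal_(0)
W = torch.randn(d, d)

def f(X, A):
    return (A @ X) @ W               # one sum-aggregation layer

P = torch.eye(N)[torch.randperm(N)]  # random permutation matrix
# Equivariance: f(PX, P A P^T) = P f(X, A)
assert torch.allclose(f(P @ X, P @ A @ P.T), P @ f(X, A), atol=1e-5)
```

Since P^T P = I, the permutation passes straight through the aggregation: (P A P^T)(P X) W = P (A X W).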

1.3 Why Graphs are Different from GridsΒΆ

CNNs assume:

  • Fixed neighborhood size (e.g., 3Γ—3)

  • Spatial locality and grid structure

  • Translational equivariance

Graphs:

  • Variable neighborhood sizes (d(v) ≠ constant)

  • No spatial coordinates

  • Permutation invariance required

Consequence: Can’t directly apply convolutions. Need graph-specific operations.

2. Graph Convolutional Networks (GCN)ΒΆ

2.1 Spectral MotivationΒΆ

Graph Laplacian:

L = D - A  (unnormalized)
L_norm = I - D^(-1/2) A D^(-1/2)  (normalized)

Eigendecomposition:

L = UΛU^T

where the columns of U are eigenvectors (the graph Fourier basis) and Λ is the diagonal matrix of eigenvalues.

Graph Fourier transform:

x̂ = U^T x  (transform to frequency domain)

Spectral convolution:

x ★ y = U((U^T x) ⊙ (U^T y))

Polynomial filter (Chebyshev): Avoid expensive eigendecomposition:

g_θ ★ x ≈ Σₖ θₖ Tₖ(L̃) x

where Tₖ is Chebyshev polynomial, L̃ = 2L/λ_max - I.

2.2 GCN Layer (Kipf & Welling, 2017)ΒΆ

Simplification: First-order approximation (K=1):

H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))

where:

  • Ã = A + I (add self-loops)

  • D̃[i,i] = Σⱼ Ã[i,j]

  • W^(l): Learnable weight matrix

  • σ: Activation function

Intuition:

  1. Self-loops: Include own features

  2. D̃^(-1/2) Ã D̃^(-1/2): Symmetric normalization (average neighbors)

  3. W: Learnable transformation

  4. σ: Non-linearity

Per-node update:

h_v^(l+1) = σ(Σ_{u∈N(v)∪{v}} (1/√(d̃_u d̃_v)) h_u^(l) W^(l))

where d̃_u = D̃[u,u] are degrees including self-loops.
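In dense form, one layer of this propagation takes only a few lines (shapes and values below are illustrative):

```python
import torch

def gcn_layer(X, A, W):
    A_tilde = A + torch.eye(A.size(0))           # add self-loops
    d = A_tilde.sum(dim=1)
    D_inv_sqrt = torch.diag(d.pow(-0.5))
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt    # symmetric normalization
    return torch.relu(A_hat @ X @ W)

torch.manual_seed(0)
X = torch.randn(4, 3)
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 1., 0.],
                  [1., 1., 0., 1.],
                  [0., 0., 1., 0.]])
W = torch.randn(3, 2)
H = gcn_layer(X, A, W)
print(H.shape)  # torch.Size([4, 2])
```

The dense form is only practical for small graphs; the sparse edge-index version appears in the implementation section below.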

2.3 GCN PropertiesΒΆ

Computational complexity:

  • Matrix multiplication: O(|E| · D · F) where D, F are feature dimensions

  • Sparse adjacency β†’ efficient for sparse graphs

Receptive field:

  • K-layer GCN sees K-hop neighborhood

  • Node v at layer K influenced by all nodes ≤ K hops away

Over-smoothing:

  • Problem: As K → ∞, all node features → same value

  • Proof sketch: Repeated averaging makes all nodes similar

  • Solution: Shallow networks (2-3 layers), skip connections, normalization

2.4 Mathematical AnalysisΒΆ

Propagation as low-pass filter: GCN smooths features across edges:

h^(l+1) = Ã_norm h^(l) W

where Ã_norm has eigenvalues in [0, 1].

Repeated application:

h^(K) = (Ã_norm)^K h^(0) W_combined

As K increases, (Ã_norm)^K emphasizes low-frequency components → over-smoothing.
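The effect is easy to observe numerically. A sketch with mean-style (row-normalized) propagation on a random feature matrix; the exact numbers are incidental, only the shrinking spread matters:

```python
import torch

torch.manual_seed(0)
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 1., 0.],
                  [1., 1., 0., 1.],
                  [0., 0., 1., 0.]])
A_tilde = A + torch.eye(4)
A_rw = A_tilde / A_tilde.sum(dim=1, keepdim=True)  # mean over N(v) plus self

X = torch.randn(4, 2)
spreads = []
for K in [1, 4, 16, 64]:
    H = torch.matrix_power(A_rw, K) @ X
    spreads.append((H - H.mean(dim=0)).norm().item())  # how much rows still differ

print([round(s, 6) for s in spreads])  # shrinks toward 0 as K grows
```

With self-loops the propagation is an aperiodic random walk on a connected graph, so all rows converge to the same vector and the spread decays geometrically.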

3. Graph Attention Networks (GAT)ΒΆ

3.1 Attention Mechanism for GraphsΒΆ

Motivation: Not all neighbors equally important. Learn attention weights.

Attention coefficient (unnormalized):

e_ij = LeakyReLU(a^T [W h_i || W h_j])

where:

  • W: Learnable transformation

  • a: Learnable attention vector

  • || : Concatenation

Normalized attention (softmax over neighbors):

α_ij = softmax_j(e_ij) = exp(e_ij) / Σ_{k∈N(i)} exp(e_ik)

Aggregation:

h_i^(l+1) = Οƒ(Ξ£_{j∈N(i)} Ξ±_ij W^(l) h_j^(l))

3.2 Multi-Head AttentionΒΆ

K independent attention heads:

h_i^(l+1) = ||_{k=1}^K σ(Σ_{j∈N(i)} α_ij^k W^k h_j^(l))

Last layer (average instead of concat):

h_i^(L) = σ((1/K) Σ_{k=1}^K Σ_{j∈N(i)} α_ij^k W^k h_j^(L-1))

Advantages:

  • Stabilizes learning

  • Multiple attention patterns (different heads attend to different aspects)

3.3 Computational ComplexityΒΆ

Per-edge attention computation:

  • Attention score: O(F) where F is feature dimension

  • Total: O(|E| · F) for all edges

Comparison with GCN:

  • GCN: Fixed normalization 1/√(d_i d_j)

  • GAT: Learned attention α_ij

  • GAT more flexible but higher computational cost

3.4 Theoretical PropertiesΒΆ

Inductive learning: GAT can generalize to unseen nodes/graphs (unlike GCN which needs full graph).

Attention weights interpretability:

  • Visualize α_ij to understand which neighbors are important

  • Useful for explainability

Expressiveness: GAT's learned, feature-dependent weighting generalizes GCN's fixed normalization, but like GCN it remains bounded by the 1-WL test.

4. GraphSAGE (SAmple and aggreGatE)ΒΆ

4.1 MotivationΒΆ

Problem: GCN requires full graph in memory (batch training on large graphs infeasible).

Solution: Sample fixed-size neighborhood for each node.

4.2 AlgorithmΒΆ

Sampling: For each node v, sample S neighbors from N(v) uniformly.

Aggregation functions:

  1. Mean aggregator:

h_v^(k) = σ(W · MEAN({h_u^(k-1), ∀u ∈ N(v)}) + B h_v^(k-1))

  2. LSTM aggregator:

h_v^(k) = σ(W · LSTM({h_u^(k-1), ∀u ∈ π(N(v))}) + B h_v^(k-1))

where π is a random permutation (the LSTM is not permutation-invariant, so the input order is randomized).

  3. Max-pooling aggregator:

h_v^(k) = σ(W · MAX({ReLU(W_pool h_u^(k-1) + b), ∀u ∈ N(v)}) + B h_v^(k-1))

Normalization: L2-normalize embeddings at each layer:

h_v^(k) ← h_v^(k) / ||h_v^(k)||_2

4.3 Mini-Batch TrainingΒΆ

Idea: For batch of nodes, only compute embeddings for their sampled neighborhoods.

Multi-layer sampling:

  • Layer K: Sample S_K neighbors per node

  • Layer K-1: Sample S_{K-1} neighbors per (already sampled) node

  • …

  • Recursive neighborhood sampling

Computational complexity:

  • K layers, sample S neighbors per layer

  • Each node in batch has S^K neighbors in computation graph

  • Total: O(B · S^K · F) where B is batch size

4.4 Loss FunctionsΒΆ

Supervised (node classification):

L = -Σ_v y_v log(softmax(h_v^(K)))

Unsupervised (graph structure):

L = -log(σ(h_u^T h_v)) - Q · E_{v_n~P_n(v)} log(σ(-h_u^T h_{v_n}))

where:

  • (u, v) are co-occurring nodes (e.g., random walk)

  • v_n are negative samples

  • Q: Number of negative samples
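The unsupervised objective can be sketched with random embeddings standing in for real GNN outputs (Q = 5 and the dimensions here are illustrative; the sum over sampled negatives approximates the Q-scaled expectation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
h_u = torch.randn(16)        # anchor node embedding
h_v = torch.randn(16)        # co-occurring (positive) node
h_neg = torch.randn(5, 16)   # Q = 5 negative samples

pos_term = F.logsigmoid(h_u @ h_v)             # log sigma(h_u^T h_v)
neg_term = F.logsigmoid(-(h_neg @ h_u)).sum()  # sum over negatives ~ Q * E[...]
loss = -(pos_term + neg_term)
print(loss.item() > 0)  # True: a negative log-likelihood
```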

5. Message Passing Neural Networks (MPNN)ΒΆ

5.1 Unified FrameworkΒΆ

General message passing (Gilmer et al., 2017):

Message phase:

m_v^(k+1) = Σ_{u∈N(v)} M_k(h_v^(k), h_u^(k), e_{uv})

Update phase:

h_v^(k+1) = U_k(h_v^(k), m_v^(k+1))

where:

  • M_k: Message function (can depend on edge features e_{uv})

  • U_k: Update function (e.g., GRU)

Readout (graph-level):

y = R({h_v^(K) : v ∈ G})

5.2 Specific InstancesΒΆ

GCN as MPNN:

M(h_v, h_u, e_{uv}) = (1/√(d_u d_v)) h_u W
U(h_v, m_v) = σ(m_v)

GAT as MPNN:

M(h_v, h_u, e_{uv}) = α_{vu} W h_u
U(h_v, m_v) = σ(m_v)

MPNN with edge features:

M(h_v, h_u, e_{uv}) = NN(concat(h_v, h_u, e_{uv}))

5.3 Expressiveness via WL TestΒΆ

Weisfeiler-Lehman (WL) test: Graph isomorphism test.

1-WL algorithm:

  1. Initialize: Label each node with its features

  2. Iterate: Update label = hash(old_label, multiset of neighbor labels)

  3. Repeat until convergence

Theorem (Xu et al., 2019 - GIN): Standard MPNNs are at most as powerful as 1-WL test.

Limitations:

  • Cannot distinguish certain non-isomorphic graphs

  • Example: two disjoint triangles and a single 6-cycle are both 2-regular, so 1-WL assigns every node the same label and cannot tell the graphs apart

Going beyond WL:

  • Higher-order GNNs (k-WL)

  • Subgraph GNNs

  • Graph transformers with positional encodings

6. Graph Isomorphism Network (GIN)ΒΆ

6.1 Maximally Expressive GNNΒΆ

Theorem: GIN can distinguish any graphs distinguishable by 1-WL test.

GIN update rule:

h_v^(k+1) = MLP^(k)((1 + ε^(k)) h_v^(k) + Σ_{u∈N(v)} h_u^(k))

where:

  • ε: Learnable scalar (or fixed)

  • MLP: Multi-layer perceptron with at least one hidden layer

Why this specific form?

  • Sum aggregation (provably best for WL expressiveness)

  • (1+ε) self-weight distinguishes nodes with same neighbor multiset but different features

  • MLP provides expressiveness (can model injective functions)

6.2 Theoretical JustificationΒΆ

Injective aggregation: To distinguish different multisets, aggregation must be injective.

Theorem: Sum aggregation (composed with a suitable feature map) is injective over multisets drawn from a countable domain.

Other aggregators:

  • Mean: Not injective (e.g., {1} and {1,1} have the same mean)

  • Max: Not injective (e.g., {1,3} and {3,3} have the same max)

  • Sum: Injective over countable domains ✓
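These collisions are easy to verify directly on toy scalar multisets:

```python
# Mean confuses multisets of different sizes; sum does not
a, b = [1.0], [1.0, 1.0]
assert sum(a) / len(a) == sum(b) / len(b)  # both means are 1.0
assert sum(a) != sum(b)                    # sums differ: 1.0 vs 2.0

# Max confuses repeated maxima; sum does not
c, d = [1.0, 3.0], [3.0, 3.0]
assert max(c) == max(d)                    # both maxima are 3.0
assert sum(c) != sum(d)                    # sums differ: 4.0 vs 6.0
print("sum separates both pairs")
```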

6.3 Practical ConsiderationsΒΆ

MLP architecture:

  • At least 2 layers (to ensure universality)

  • Batch normalization between layers

  • Dropout for regularization

ε initialization:

  • Learnable: ε = 0 initially, learned during training

  • Fixed: ε = 0 or small positive value

7. Advanced GNN ArchitecturesΒΆ

7.1 PNA (Principal Neighbourhood Aggregation)ΒΆ

Motivation: Different aggregators capture different information.

Multi-aggregator:

h_v = MLP(||_{agg∈{mean,max,min,std}} agg({h_u : u∈N(v)}))

Scalers (degree-aware):

h_v = MLP(||_{agg,scaler} scaler(agg(...)))

where scalers ∈ {amplification, attenuation, identity}:

  • Amplification: d_v · agg (emphasize high-degree nodes)

  • Attenuation: (1/d_v) · agg (normalize by degree)

7.2 DeeperGCNΒΆ

Problem: Deep GNNs suffer from over-smoothing, gradient vanishing.

Solutions:

  1. Skip connections (residual):

h_v^(l+1) = h_v^(l) + GCN_layer(h_v^(l))
  2. Normalization (layer norm, batch norm):

h_v^(l+1) = LayerNorm(GCN_layer(h_v^(l)))
  3. Pre-activation:

h_v^(l+1) = h_v^(l) + σ(GCN(LayerNorm(h_v^(l))))

Result: Can train 28+ layer GNNs with good performance.

7.3 Graph TransformerΒΆ

Self-attention on graphs:

Challenges:

  • No positional information (unlike sequences)

  • Quadratic complexity in number of nodes

Solutions:

  1. Laplacian positional encoding:

PE_v = [u_1(v), u_2(v), ..., u_k(v)]

where u_i are eigenvectors of graph Laplacian.

  2. Sparse attention (only on edges):

Attention only between (u,v) if (u,v) ∈ E
  3. Relative positional encoding: Encode shortest path distance between nodes.

Graphormer (NeurIPS 2021):

  • Centrality encoding (degree)

  • Spatial encoding (shortest path)

  • Edge encoding in attention bias

8. Pooling and Hierarchical GNNsΒΆ

8.1 Graph-Level RepresentationsΒΆ

Need: Map variable-size graphs to fixed-size vectors.

Simple pooling:

h_G = Σ_v h_v  (sum pooling)
h_G = mean_v h_v  (mean pooling)
h_G = max_v h_v  (max pooling)
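For a mini-batch of graphs packed into one node tensor, these readouts reduce over a `batch` vector mapping each node to its graph id (the PyTorch Geometric convention; values here are illustrative):

```python
import torch

# 4 node embeddings from two graphs: nodes 0-1 belong to graph 0, nodes 2-3 to graph 1
h = torch.tensor([[1., 0.], [0., 1.], [2., 2.], [4., 0.]])
batch = torch.tensor([0, 0, 1, 1])
num_graphs = 2

sum_pool = torch.zeros(num_graphs, h.size(1)).index_add_(0, batch, h)
counts = torch.zeros(num_graphs).index_add_(0, batch, torch.ones(h.size(0)))
mean_pool = sum_pool / counts.unsqueeze(1)

print(sum_pool.tolist())   # [[1.0, 1.0], [6.0, 2.0]]
print(mean_pool.tolist())  # [[0.5, 0.5], [3.0, 1.0]]
```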

8.2 Differentiable Pooling (DiffPool)ΒΆ

Idea: Learn soft cluster assignment.

Cluster assignment matrix: S ∈ ℝ^(N×K)

S[i,k] = probability node i assigned to cluster k

Coarsened features:

X^(new) = S^T X  (K × D)
A^(new) = S^T A S  (K × K)

Loss:

L_pool = ||A - SS^T||_F^2  (link prediction)
       + entropy(S)  (entropy regularization)
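A minimal sketch of one coarsening step. The soft assignment S is random here; in DiffPool it is produced by a separate GNN:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D, K = 6, 4, 2
X = torch.randn(N, D)
A = (torch.rand(N, N) < 0.4).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)

S = F.softmax(torch.randn(N, K), dim=1)  # soft assignment, each row sums to 1

X_new = S.T @ X       # coarsened features, K x D
A_new = S.T @ A @ S   # coarsened adjacency, K x K

link_loss = (A - S @ S.T).norm(p='fro') ** 2
entropy_loss = -(S * (S + 1e-12).log()).sum(dim=1).mean()
print(X_new.shape, A_new.shape)  # torch.Size([2, 4]) torch.Size([2, 2])
```

Because S is soft and differentiable, the whole pooling step trains end-to-end with the rest of the network.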

8.3 Top-K PoolingΒΆ

Select top-k nodes based on learned scores:

Score function:

y = X p / ||p||  (project features to scalar)

Select top-k:

idx = top_k(y)
X^(new) = X[idx]
A^(new) = A[idx, idx]

Gating:

X^(new) = X[idx] ⊙ σ(y[idx])
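The same steps fit in a few lines of tensor code (the projection p and the graph are random placeholders for learned and real values):

```python
import torch

torch.manual_seed(0)
N, D, k = 5, 3, 2
X = torch.randn(N, D)
A = (torch.rand(N, N) < 0.5).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)
p = torch.randn(D)

y = X @ p / p.norm()                 # scalar score per node
idx = y.topk(k).indices              # indices of the k best-scoring nodes

X_new = X[idx] * torch.sigmoid(y[idx]).unsqueeze(1)  # gated features
A_new = A[idx][:, idx]               # induced subgraph adjacency
print(X_new.shape, A_new.shape)      # torch.Size([2, 3]) torch.Size([2, 2])
```

The sigmoid gate is what lets gradients reach p: the hard top-k selection itself is not differentiable.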

8.4 Self-Attention Pooling (SAGPool)ΒΆ

Attention-based node selection:

Z = GNN(X, A)
y = σ(Z p / ||p||)
idx = top_k(y)
X^(new) = X[idx] ⊙ y[idx]
A^(new) = A[idx, idx]

9. Heterogeneous and Dynamic GraphsΒΆ

9.1 Heterogeneous GraphsΒΆ

Definition: Multiple node/edge types.

  • V = V₁ ∪ V₂ ∪ … ∪ Vₘ (node types)

  • E = E₁ ∪ E₂ ∪ … ∪ Eₙ (edge types)

Example (citation network):

  • Nodes: Papers, Authors, Venues

  • Edges: Writes, Publishes, Cites

Relational GCN (R-GCN):

h_v^(l+1) = σ(Σ_{r∈R} Σ_{u∈N_r(v)} (1/|N_r(v)|) W_r^(l) h_u^(l) + W_0^(l) h_v^(l))

where:

  • r: Relation type

  • N_r(v): Neighbors of v via relation r

  • W_r: Relation-specific weight matrix

Heterogeneous GAT (HAN):

  • Node-level attention (within meta-path)

  • Semantic-level attention (across meta-paths)

9.2 Dynamic GraphsΒΆ

Temporal graphs: Edges/nodes change over time.

Discrete-time: Sequence of graph snapshots G₁, G₂, …, Gₜ

Continuous-time: Edge stream: (u_i, v_i, t_i)

Temporal GNN approaches:

  1. Snapshot-based: Apply GNN to each snapshot independently

  2. RNN-based: GNN + LSTM/GRU across time

  3. Temporal random walk: Sample time-respecting walks

  4. Temporal attention: Attend to past events

TGAT (Temporal Graph Attention):

h_v(t) = Attention({h_u(t_u) : (u,v,t_u) ∈ E, t_u < t})

10. Training TechniquesΒΆ

10.1 NormalizationΒΆ

Batch normalization (problematic for graphs):

  • Variable graph sizes β†’ inconsistent statistics

  • Small graphs β†’ noisy estimates

Better alternatives:

  1. Layer normalization:

h_v ← (h_v - μ_v) / σ_v

Normalize each node independently.

  2. Graph normalization:

h_G ← (h_G - μ_G) / σ_G

Normalize across graph.

  3. Pair normalization: Normalize across node pairs (for message passing).

10.2 RegularizationΒΆ

Dropout:

  • Node dropout: Drop nodes

  • Edge dropout: Drop edges (augmentation + regularization)

  • Message dropout: Drop messages

Graph augmentation:

  • Add/remove random edges

  • Mask node features

  • Subgraph sampling

10.3 Loss FunctionsΒΆ

Node classification:

L = -Σ_v∈V_train y_v log ŷ_v

Link prediction:

L = -Σ_{(u,v)∈E} log σ(h_u^T h_v) - Σ_{(u,v)∉E} log(1 - σ(h_u^T h_v))

Graph classification:

L = -Σ_G y_G log ŷ_G

Contrastive (self-supervised):

L = -log(exp(sim(h_i, h_j)/τ) / Σ_k exp(sim(h_i, h_k)/τ))
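With positive pairs on the diagonal of the similarity matrix, this InfoNCE-style loss reduces to a cross-entropy. A sketch with random embeddings standing in for two augmented views of the same graphs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z1 = F.normalize(torch.randn(8, 16), dim=1)  # view-1 embeddings of 8 graphs
z2 = F.normalize(torch.randn(8, 16), dim=1)  # view-2 embeddings (positives)
tau = 0.5

sim = z1 @ z2.T / tau                # pairwise similarities / temperature
labels = torch.arange(8)             # positive pair i <-> i on the diagonal
loss = F.cross_entropy(sim, labels)  # -log softmax at each positive pair
print(loss.item() > 0)  # True
```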

11. Complexity and ScalabilityΒΆ

11.1 Computational ComplexityΒΆ

GCN forward pass:

  • Dense: O(N² D F) (matrix multiplication)

  • Sparse: O(E D F) (leverage sparsity)

Memory:

  • Store adjacency: O(E) (sparse) or O(NΒ²) (dense)

  • Node features: O(N D)

  • Activations: O(N F) per layer

11.2 Scalability ChallengesΒΆ

Full-batch training:

  • Need entire graph in memory

  • Infeasible for graphs with millions of nodes

Solutions:

  1. Mini-batching (GraphSAGE):

    • Sample neighborhoods

    • Complexity: O(B S^L D F)

  2. Cluster-GCN:

    • Partition graph into clusters

    • Train on subgraphs

    • Reduces memory and enables parallelism

  3. Graph sampling:

    • Layer-wise sampling

    • Importance sampling

    • Variance reduction techniques

11.3 Large-Scale GNN SystemsΒΆ

Distributed training:

  • Partition nodes across machines

  • Communication overhead for neighbor features

Libraries:

  • DGL (Deep Graph Library): Flexible message passing

  • PyG (PyTorch Geometric): Rich model zoo

  • GraphStorm (AWS): Distributed training

  • Graph-Learn (Alibaba): Industrial scale

12. ApplicationsΒΆ

12.1 Node ClassificationΒΆ

Task: Predict node labels given graph structure and partial labels.

Examples:

  • Social networks: User attributes

  • Citation networks: Paper topics

  • Molecules: Atom properties

Evaluation:

  • Semi-supervised: Train on few labeled, test on rest

  • Metrics: Accuracy, F1-score, Micro/Macro-F1

12.2 Graph ClassificationΒΆ

Task: Classify entire graphs.

Examples:

  • Molecule property prediction (toxicity, solubility)

  • Social network analysis (community detection)

  • Program verification (code graphs)

Approach:

Graph → GNN → Pooling → Readout → Classifier

12.3 Temporal PredictionΒΆ

Traffic forecasting:

  • Nodes: Road segments

  • Edges: Connectivity

  • Features: Traffic speed/volume

  • Predict: Future traffic

Financial networks:

  • Nodes: Assets

  • Edges: Correlations

  • Predict: Price movements

13. Benchmarks and DatasetsΒΆ

13.1 Standard BenchmarksΒΆ

Node classification:

  • Cora, Citeseer, Pubmed (citation networks)

  • Reddit, Yelp (social networks)

  • Ogbn-arxiv, ogbn-products (OGB)

Graph classification:

  • MUTAG, PTC, PROTEINS (bioinformatics)

  • COLLAB, IMDB (social networks)

  • Ogbg-molhiv, ogbg-molpcba (molecules)

Link prediction:

  • WN18RR, FB15k-237 (knowledge graphs)

  • Ogbl-ppa, ogbl-collab (OGB)

13.2 Open Graph Benchmark (OGB)ΒΆ

Large-scale, diverse, realistic:

  • Millions of nodes/edges

  • Realistic splits (temporal, scaffold)

  • Standardized evaluation

Example tasks:

  • ogbn-arxiv: 169K papers, 40-way classification

  • ogbg-molhiv: 41K molecules, binary classification

  • ogbl-citation2: 2.9M papers, link prediction

14. Recent Advances (2020-2024)ΒΆ

14.1 Equivariant GNNsΒΆ

SE(3)-equivariant networks:

  • Preserve 3D rotations and translations

  • Critical for molecular modeling

  • E(n)-GNN, SchNet, DimeNet++

14.2 Graph TransformersΒΆ

Graphormer (NeurIPS 2021):

  • Won OGB-LSC molecular prediction

  • Centrality + spatial encoding

SAN (Spectral Attention Network):

  • Laplacian positional encoding

  • Full attention with efficiency tricks

14.3 Self-Supervised LearningΒΆ

Contrastive methods:

  • GraphCL: Augmentation-based contrastive

  • BGRL: Bootstrap graph latents

  • SimGRACE: Perturb encoder

Generative:

  • GPT-GNN: Generative pre-training

  • GraphMAE: Masked autoencoder

14.4 Graph Foundation ModelsΒΆ

Motivation: Pre-train on many graphs, fine-tune on target.

Approaches:

  • Graph-level pre-training (molecule datasets)

  • Prompt-based fine-tuning

  • Transfer across domains

15. Limitations and Future DirectionsΒΆ

15.1 Known LimitationsΒΆ

Expressiveness:

  • Limited by 1-WL test

  • Can’t count triangles, cycles

Over-smoothing:

  • Deep GNNs make all nodes similar

  • Limits depth to 2-3 layers

Scalability:

  • Full-batch infeasible for billion-node graphs

  • Distributed training communication overhead

Heterophily:

  • GNNs assume homophily (similar nodes connect)

  • Fail when dissimilar nodes connect

15.2 Open ProblemsΒΆ

Theoretical:

  • Better understanding of GNN expressiveness

  • Generalization bounds

  • When do GNNs work vs. fail?

Architectural:

  • GNNs beyond 1-WL

  • Long-range dependencies (beyond k-hop)

  • Handling heterophily

Scalability:

  • Billion-node, trillion-edge graphs

  • Streaming graphs

  • Efficient inference

15.3 Future DirectionsΒΆ

Foundation models:

  • Universal graph encoders

  • Zero-shot on new graphs

Neurosymbolic:

  • Combine GNNs with logic

  • Interpretable reasoning

Geometric deep learning:

  • Unified theory (symmetry, invariance)

  • Beyond graphs (manifolds, meshes)

16. Key TakeawaysΒΆ

  1. Message passing is core: All GNNs aggregate neighbor information iteratively.

  2. Expressiveness vs. efficiency tradeoff:

    • GIN: Maximally expressive (within WL limit) but expensive

    • GCN: Less expressive but fast and simple

  3. Depth is tricky:

    • Over-smoothing limits to 2-3 layers

    • Solutions: Skip connections, normalization, attention

  4. Scalability crucial:

    • Sampling (GraphSAGE)

    • Clustering (Cluster-GCN)

    • Distributed systems

  5. Inductive vs. transductive:

    • GCN: Transductive (needs full graph)

    • GraphSAGE, GAT: Inductive (generalize to new nodes)

  6. Application-specific design:

    • Molecules: E(3)-equivariance

    • Social: Heterogeneous graphs

    • Temporal: Dynamic GNNs

  7. Pre-training helps: Like NLP/vision, pre-training improves few-shot performance.

17. ReferencesΒΆ

Foundational:

  • Kipf & Welling (2017): “Semi-Supervised Classification with GCNs” (ICLR)

  • Veličković et al. (2018): “Graph Attention Networks” (ICLR)

  • Hamilton et al. (2017): “Inductive Representation Learning on Large Graphs” (NeurIPS)

  • Gilmer et al. (2017): “Neural Message Passing for Quantum Chemistry” (ICML)

  • Xu et al. (2019): “How Powerful are GNNs?” (ICLR)

Advanced architectures:

  • Corso et al. (2020): “Principal Neighbourhood Aggregation” (NeurIPS)

  • Li et al. (2020): “DeeperGCN” (arXiv)

  • Ying et al. (2021): “Graphormer” (NeurIPS)

Theory:

  • Morris et al. (2019): “Weisfeiler and Leman Go Neural” (AAAI)

  • Bronstein et al. (2021): “Geometric Deep Learning” (Book/Tutorial)

Surveys:

  • Wu et al. (2021): “A Comprehensive Survey on GNNs”

  • Zhou et al. (2020): “Graph Neural Networks: A Review of Methods and Applications”

"""
Complete Graph Neural Network Implementations
==============================================
Includes: GCN, GAT, GraphSAGE, GIN, message passing framework, pooling methods.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import math

# ============================================================================
# 1. Graph Convolutional Network (GCN) Layer
# ============================================================================

class GCNConv(MessagePassing):
    """
    Graph Convolutional Network layer (Kipf & Welling, 2017).
    
    Implements: H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))
    where à = A + I (add self-loops), D̃ is degree matrix of Ã.
    
    Args:
        in_channels: Input feature dimension
        out_channels: Output feature dimension
        bias: Whether to add bias
        normalize: Whether to apply symmetric normalization
    """
    def __init__(self, in_channels, out_channels, bias=True, normalize=True):
        super(GCNConv, self).__init__(aggr='add')  # Sum aggregation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        
        # Learnable weight matrix
        self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
        
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        """Glorot initialization."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, out_channels]
        """
        # Add self-loops
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # Transform features
        x = torch.matmul(x, self.weight)
        
        # Compute normalization
        if self.normalize:
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        else:
            norm = None
        
        # Propagate (message passing)
        out = self.propagate(edge_index, x=x, norm=norm)
        
        if self.bias is not None:
            out += self.bias
        
        return out
    
    def message(self, x_j, norm):
        """Construct messages: 1/√(d_i d_j) * h_j."""
        return norm.view(-1, 1) * x_j if norm is not None else x_j


class GCN(nn.Module):
    """
    Multi-layer GCN for node classification.
    
    Args:
        num_features: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of GCN layers
        dropout: Dropout rate
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_layers=2, dropout=0.5):
        super(GCN, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(num_features, hidden_dim))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
        self.convs.append(GCNConv(hidden_dim, num_classes))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)


# ============================================================================
# 2. Graph Attention Network (GAT) Layer
# ============================================================================

class GATConv(MessagePassing):
    """
    Graph Attention Network layer (Veličković et al., 2018).
    
    Implements: h_i = σ(Σ_{j∈N(i)} α_ij W h_j)
    where α_ij = softmax_j(LeakyReLU(a^T [W h_i || W h_j]))
    
    Args:
        in_channels: Input feature dimension
        out_channels: Output feature dimension
        heads: Number of attention heads
        concat: Whether to concatenate or average multi-head outputs
        dropout: Attention dropout rate
        bias: Whether to add bias
    """
    def __init__(self, in_channels, out_channels, heads=1, concat=True, 
                 dropout=0.0, bias=True):
        super(GATConv, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        
        # Learnable parameters
        self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels))
        
        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.att)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, heads * out_channels] or [N, out_channels]
        """
        # Transform features
        x = torch.matmul(x, self.weight).view(-1, self.heads, self.out_channels)
        
        # Propagate
        out = self.propagate(edge_index, x=x)
        
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
        
        if self.bias is not None:
            out += self.bias
        
        return out
    
    def message(self, x_i, x_j, edge_index_i):
        """
        Compute attention coefficients and messages.
        
        Args:
            x_i: Target node features [E, heads, out_channels]
            x_j: Source node features [E, heads, out_channels]
            edge_index_i: Target node indices [E]
        """
        # Concatenate features
        alpha = torch.cat([x_i, x_j], dim=-1)  # [E, heads, 2*out_channels]
        
        # Compute attention scores
        alpha = (alpha * self.att).sum(dim=-1)  # [E, heads]
        alpha = F.leaky_relu(alpha, negative_slope=0.2)
        
        # Softmax over neighbors
        alpha = softmax(alpha, edge_index_i)
        
        # Attention dropout
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        # Weight messages by attention
        return x_j * alpha.unsqueeze(-1)


def softmax(src, index):
    """Numerically stable softmax of per-edge scores, grouped by target node."""
    out = (src - src.max()).exp()
    
    # Sum exponentials per target node (size must be a Python int, not a tensor)
    num_nodes = int(index.max()) + 1
    out_sum = torch.zeros(num_nodes, out.size(1), device=out.device)
    out_sum.scatter_add_(0, index.unsqueeze(-1).expand_as(out), out)
    
    return out / (out_sum[index] + 1e-16)


class GAT(nn.Module):
    """
    Multi-layer GAT for node classification.
    
    Args:
        num_features: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of GAT layers
        heads: Number of attention heads
        dropout: Dropout rate
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_layers=2, 
                 heads=8, dropout=0.6):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        # First layer
        self.convs.append(GATConv(num_features, hidden_dim, heads=heads, concat=True))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True))
        # Output layer (average instead of concat)
        self.convs.append(GATConv(hidden_dim * heads, num_classes, heads=1, concat=False))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)


# ============================================================================
# 3. GraphSAGE Layer
# ============================================================================

class SAGEConv(MessagePassing):
    """
    GraphSAGE layer (Hamilton et al., 2017).
    
    Implements: h_v = σ(W · [h_v || AGG({h_u : u ∈ N(v)})])
    
    Args:
        in_channels: Input feature dimension
        out_channels: Output feature dimension
        aggr: Aggregation method ('mean', 'max', 'add')
        normalize: Whether to L2-normalize output
        bias: Whether to add bias
    """
    def __init__(self, in_channels, out_channels, aggr='mean', normalize=True, bias=True):
        super(SAGEConv, self).__init__(aggr=aggr)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        
        # Learnable weights
        self.weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
        self.root_weight = nn.Parameter(torch.Tensor(in_channels, out_channels))
        
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        nn.init.xavier_uniform_(self.root_weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, out_channels]
        """
        # Aggregate neighbors
        out = self.propagate(edge_index, x=x)
        out = torch.matmul(out, self.weight)
        
        # Add root node features
        out += torch.matmul(x, self.root_weight)
        
        if self.bias is not None:
            out += self.bias
        
        if self.normalize:
            out = F.normalize(out, p=2, dim=-1)
        
        return out
    
    def message(self, x_j):
        """Messages are just neighbor features."""
        return x_j


class GraphSAGE(nn.Module):
    """
    Multi-layer GraphSAGE for node classification.
    
    Args:
        num_features: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of SAGE layers
        dropout: Dropout rate
        aggr: Aggregation method
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_layers=2, 
                 dropout=0.5, aggr='mean'):
        super(GraphSAGE, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(num_features, hidden_dim, aggr=aggr))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_dim, hidden_dim, aggr=aggr))
        self.convs.append(SAGEConv(hidden_dim, num_classes, aggr=aggr, normalize=False))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)


# ============================================================================
# 4. Graph Isomorphism Network (GIN) Layer
# ============================================================================

class GINConv(MessagePassing):
    """
    Graph Isomorphism Network layer (Xu et al., 2019).
    
    Implements: h_v = MLP((1 + ε) h_v + Σ_{u∈N(v)} h_u)
    
    Args:
        mlp: MLP (torch.nn.Sequential)
        eps: Initial epsilon value
        train_eps: Whether epsilon is learnable
    """
    def __init__(self, mlp, eps=0.0, train_eps=True):
        super(GINConv, self).__init__(aggr='add')  # Sum aggregation
        self.nn = mlp  # argument named `mlp` so it does not shadow torch.nn
        self.initial_eps = eps
        
        if train_eps:
            self.eps = nn.Parameter(torch.tensor([eps]))
        else:
            self.register_buffer('eps', torch.tensor([eps]))
    
    def forward(self, x, edge_index):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
        Returns:
            Updated node features [N, out_channels]
        """
        # Aggregate neighbors
        out = self.propagate(edge_index, x=x)
        
        # Add self-features with (1 + ε)
        out += (1 + self.eps) * x
        
        # Apply MLP
        return self.nn(out)
    
    def message(self, x_j):
        return x_j
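# Why GIN uses sum rather than mean aggregation: sum is injective on neighbor
# multisets in cases where mean is not. A toy illustration with two multisets,
# {1.0, 1.0} and {1.0}, that share a mean but not a sum:

```python
import torch

a = torch.tensor([[1.0], [1.0]])
b = torch.tensor([[1.0]])

mean_equal = torch.equal(a.mean(dim=0), b.mean(dim=0))  # mean collapses them
sum_equal = torch.equal(a.sum(dim=0), b.sum(dim=0))     # sum keeps them apart
```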


class GIN(nn.Module):
    """
    Multi-layer GIN for graph classification.
    
    Args:
        num_features: Input feature dimension
        hidden_dim: Hidden layer dimension
        num_classes: Number of output classes
        num_layers: Number of GIN layers
        dropout: Dropout rate
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_layers=5, dropout=0.5):
        super(GIN, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        # First layer
        mlp = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.convs.append(GINConv(mlp, train_eps=True))
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        
        # Hidden layers
        for _ in range(num_layers - 1):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        
        # Graph-level readout
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x, edge_index, batch):
        """
        Args:
            x: Node features [N, num_features]
            edge_index: Edge indices [2, E]
            batch: Batch vector [N] (which graph each node belongs to)
        Returns:
            Graph-level predictions [batch_size, num_classes]
        """
        # Node-level forward
        for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Graph-level readout (sum pooling)
        x = global_add_pool(x, batch)
        
        # MLP for classification
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)


def global_add_pool(x, batch):
    """Sum pooling over nodes in each graph."""
    size = int(batch.max().item()) + 1
    out = torch.zeros(size, x.size(1), device=x.device)
    return out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)
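# A quick, hand-computable sanity check of the scatter-based sum pooling
# (the function is repeated here so the snippet runs on its own):

```python
import torch

def global_add_pool(x, batch):
    """Sum node features within each graph."""
    size = int(batch.max().item()) + 1
    out = torch.zeros(size, x.size(1), device=x.device)
    return out.scatter_add_(0, batch.unsqueeze(-1).expand_as(x), x)

x = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
batch = torch.tensor([0, 0, 1])  # nodes 0, 1 -> graph 0; node 2 -> graph 1
pooled = global_add_pool(x, batch)
# pooled: [[4., 6.], [5., 6.]]
```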


# ============================================================================
# 5. Pooling Layers
# ============================================================================

class TopKPooling(nn.Module):
    """
    Top-K pooling layer.
    
    Selects top-k nodes based on learned projection scores.
    
    Args:
        in_channels: Input feature dimension
        ratio: Pooling ratio (fraction of nodes to keep)
        min_score: Minimum score threshold
    """
    def __init__(self, in_channels, ratio=0.5, min_score=None):
        super(TopKPooling, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        
        self.score_layer = nn.Linear(in_channels, 1)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.score_layer.weight)
        nn.init.zeros_(self.score_layer.bias)
    
    def forward(self, x, edge_index, batch=None):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
            batch: Batch vector [N]
        Returns:
            x: Pooled features
            edge_index: Pooled edges
            batch: Pooled batch
            perm: Selected node indices
        """
        # Compute scores; squeeze only the last dim so a 1-node graph still works
        score = self.score_layer(x).squeeze(-1)
        
        # Select top-k nodes
        if batch is None:
            perm = topk(score, self.ratio, self.min_score)
        else:
            perm = topk_per_graph(score, batch, self.ratio, self.min_score)
        
        # Filter nodes
        x = x[perm] * torch.sigmoid(score[perm]).view(-1, 1)
        
        # Filter edges
        edge_index, _ = filter_adj(edge_index, perm, num_nodes=score.size(0))
        
        # Update batch
        if batch is not None:
            batch = batch[perm]
        
        return x, edge_index, batch, perm


def topk(x, ratio, min_score=None):
    """Select top-k elements."""
    num_nodes = x.size(0)
    k = max(1, int(ratio * num_nodes))
    _, perm = torch.topk(x, k)
    return perm


def topk_per_graph(x, batch, ratio, min_score=None):
    """Select top-k elements per graph."""
    perm = []
    for i in range(int(batch.max()) + 1):
        mask = batch == i
        indices = mask.nonzero(as_tuple=False).view(-1)
        scores = x[mask]
        k = max(1, int(ratio * scores.size(0)))
        _, local_perm = torch.topk(scores, k)
        perm.append(indices[local_perm])
    return torch.cat(perm, dim=0)
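# A small worked example of per-graph top-k selection (function repeated so the
# snippet runs on its own): with ratio=0.5 and 3 nodes per graph, k = int(1.5) = 1,
# so the single highest-scoring node in each graph is kept.

```python
import torch

def topk_per_graph(x, batch, ratio, min_score=None):
    perm = []
    for i in range(int(batch.max()) + 1):
        mask = batch == i
        indices = mask.nonzero(as_tuple=False).view(-1)
        scores = x[mask]
        k = max(1, int(ratio * scores.size(0)))
        _, local_perm = torch.topk(scores, k)
        perm.append(indices[local_perm])
    return torch.cat(perm, dim=0)

scores = torch.tensor([0.9, 0.1, 0.5, 0.8, 0.2, 0.7])
batch = torch.tensor([0, 0, 0, 1, 1, 1])
perm = topk_per_graph(scores, batch, ratio=0.5)
# keeps node 0 (score 0.9) from graph 0 and node 3 (score 0.8) from graph 1
```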


def filter_adj(edge_index, perm, num_nodes):
    """Filter adjacency to only include selected nodes."""
    mask = edge_index.new_full((num_nodes,), -1)
    mask[perm] = torch.arange(perm.size(0), device=perm.device)
    
    row, col = edge_index
    row, col = mask[row], mask[col]
    mask_edges = (row >= 0) & (col >= 0)
    
    return torch.stack([row[mask_edges], col[mask_edges]], dim=0), mask_edges
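# A small example of how filter_adj relabels surviving nodes to a compact index
# range and drops edges touching removed nodes (function repeated, with the edge
# mask renamed `keep`, so the snippet runs on its own):

```python
import torch

def filter_adj(edge_index, perm, num_nodes):
    mask = edge_index.new_full((num_nodes,), -1)
    mask[perm] = torch.arange(perm.size(0), device=perm.device)
    row, col = edge_index
    row, col = mask[row], mask[col]
    keep = (row >= 0) & (col >= 0)
    return torch.stack([row[keep], col[keep]], dim=0), keep

# Path graph 0-1-2-3; keep nodes 1 and 2, which get relabeled to 0 and 1
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
perm = torch.tensor([1, 2])
new_ei, kept = filter_adj(edge_index, perm, num_nodes=4)
# only the edge between old nodes 1 and 2 survives, in both directions
```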


class SAGPooling(nn.Module):
    """
    Self-Attention Graph Pooling (SAGPool).
    
    Combines GNN with top-k pooling for attention-based selection.
    
    Args:
        in_channels: Input feature dimension
        ratio: Pooling ratio
        gnn: GNN layer for computing attention
    """
    def __init__(self, in_channels, ratio=0.5, gnn=None):
        super(SAGPooling, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        
        if gnn is None:
            self.gnn = GCNConv(in_channels, 1)
        else:
            self.gnn = gnn
    
    def forward(self, x, edge_index, batch=None):
        """
        Args:
            x: Node features [N, in_channels]
            edge_index: Edge indices [2, E]
            batch: Batch vector [N]
        Returns:
            Pooled x, edge_index, batch, perm
        """
        # Compute attention scores via GNN; squeeze only the last dim
        score = self.gnn(x, edge_index).squeeze(-1)
        
        # Select top-k
        if batch is None:
            perm = topk(score, self.ratio)
        else:
            perm = topk_per_graph(score, batch, self.ratio)
        
        # Gate features by attention
        x = x[perm] * torch.tanh(score[perm]).view(-1, 1)
        
        # Filter edges
        edge_index, _ = filter_adj(edge_index, perm, num_nodes=score.size(0))
        
        if batch is not None:
            batch = batch[perm]
        
        return x, edge_index, batch, perm


# ============================================================================
# 6. Demonstration: Compare GNN Architectures
# ============================================================================

def demo_gnn_comparison():
    """Compare GCN, GAT, GraphSAGE, GIN on synthetic graph."""
    print("="*70)
    print("GNN Architecture Comparison")
    print("="*70)
    
    # Create synthetic graph
    num_nodes = 100
    num_features = 32
    num_classes = 7
    
    # Random features
    x = torch.randn(num_nodes, num_features)
    
    # Random edges (Erdős-Rényi graph)
    edge_prob = 0.1
    edge_index = []
    for i in range(num_nodes):
        for j in range(i+1, num_nodes):
            if torch.rand(1).item() < edge_prob:
                edge_index.append([i, j])
                edge_index.append([j, i])
    edge_index = torch.tensor(edge_index).t().contiguous()
    
    num_edges = edge_index.size(1)
    
    print(f"Graph: {num_nodes} nodes, {num_edges} edges")
    print(f"Features: {num_features}D, Classes: {num_classes}")
    print()
    
    # Initialize models
    models = {
        'GCN': GCN(num_features, 64, num_classes, num_layers=2),
        'GAT': GAT(num_features, 8, num_classes, num_layers=2, heads=8),
        'GraphSAGE': GraphSAGE(num_features, 64, num_classes, num_layers=2),
        # GIN requires batch vector for graph classification
    }
    
    # Compare
    print(f"{'Model':<15} {'Parameters':<12} {'Output Shape':<15}")
    print("-"*70)
    
    for name, model in models.items():
        model.eval()
        with torch.no_grad():
            out = model(x, edge_index)
        
        num_params = sum(p.numel() for p in model.parameters())
        
        print(f"{name:<15} {num_params:<12,} {str(tuple(out.shape)):<15}")
    
    print()
    
    # GIN for graph classification
    print("GIN (Graph Classification):")
    batch = torch.zeros(num_nodes, dtype=torch.long)  # Single graph
    gin = GIN(num_features, 64, num_classes, num_layers=3)
    gin.eval()
    with torch.no_grad():
        out_gin = gin(x, edge_index, batch)
    
    gin_params = sum(p.numel() for p in gin.parameters())
    print(f"  Parameters: {gin_params:,}")
    print(f"  Output shape: {tuple(out_gin.shape)} (graph-level)")
    print()


def demo_pooling():
    """Demonstrate graph pooling methods."""
    print("="*70)
    print("Graph Pooling Demonstration")
    print("="*70)
    
    # Create graph
    num_nodes = 20
    in_channels = 16
    x = torch.randn(num_nodes, in_channels)
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8, 8, 9],
        [1, 0, 2, 1, 3, 2, 4, 3, 6, 5, 7, 6, 8, 7, 9, 8]
    ])
    
    print(f"Original graph: {num_nodes} nodes, {edge_index.size(1)} edges")
    print()
    
    # Top-K pooling
    topk_pool = TopKPooling(in_channels, ratio=0.5)
    x_topk, edge_topk, _, perm_topk = topk_pool(x, edge_index)
    
    print(f"Top-K Pooling (ratio=0.5):")
    print(f"  Nodes: {num_nodes} β†’ {x_topk.size(0)}")
    print(f"  Edges: {edge_index.size(1)} β†’ {edge_topk.size(1)}")
    print(f"  Selected nodes: {perm_topk.tolist()[:10]}...")
    print()
    
    # SAG pooling
    sag_pool = SAGPooling(in_channels, ratio=0.5)
    x_sag, edge_sag, _, perm_sag = sag_pool(x, edge_index)
    
    print(f"SAG Pooling (ratio=0.5):")
    print(f"  Nodes: {num_nodes} β†’ {x_sag.size(0)}")
    print(f"  Edges: {edge_index.size(1)} β†’ {edge_sag.size(1)}")
    print(f"  Selected nodes: {perm_sag.tolist()[:10]}...")
    print()


def demo_complexity_analysis():
    """Analyze computational complexity of different GNN layers."""
    print("="*70)
    print("Computational Complexity Analysis")
    print("="*70)
    
    configs = [
        (100, 1000, 32, 64),      # Small graph
        (1000, 10000, 64, 128),   # Medium graph
        (10000, 100000, 128, 256) # Large graph
    ]
    
    print(f"{'Graph Size':<20} {'GCN':<15} {'GAT (8 heads)':<20} {'GraphSAGE':<15}")
    print("-"*70)
    
    for num_nodes, num_edges, in_dim, out_dim in configs:
        # GCN: O(E * in_dim * out_dim)
        gcn_flops = num_edges * in_dim * out_dim
        
        # GAT: O(E * heads * out_dim * (in_dim + 2*out_dim))
        heads = 8
        gat_flops = num_edges * heads * (in_dim * out_dim + 2 * out_dim * out_dim)
        
        # GraphSAGE: Similar to GCN but with separate root transform
        sage_flops = num_edges * in_dim * out_dim + num_nodes * in_dim * out_dim
        
        graph_desc = f"N={num_nodes}, E={num_edges}"
        print(f"{graph_desc:<20} {gcn_flops/1e6:>10.1f}M {gat_flops/1e6:>15.1f}M {sage_flops/1e6:>12.1f}M")
    
    print()
    print("Note: FLOPs are approximate (matrix multiplications only)")
    print()


# ============================================================================
# 7. When to Use Which GNN?
# ============================================================================

def print_method_comparison():
    """Print comprehensive comparison of GNN methods."""
    print("="*70)
    print("GNN Method Comparison and Selection Guide")
    print("="*70)
    print()
    
    comparison = """
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Method       β”‚ Expressiveness β”‚ Scalability β”‚ Inductive  β”‚ Best For    β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ GCN          β”‚ Medium      β”‚ High         β”‚ No           β”‚ Fast        β”‚
β”‚              β”‚             β”‚              β”‚              β”‚ baseline    β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ GAT          β”‚ High        β”‚ Medium       β”‚ Yes          β”‚ Heterophily,β”‚
β”‚              β”‚             β”‚              β”‚              β”‚ interpret.  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ GraphSAGE    β”‚ Medium      β”‚ Very High    β”‚ Yes          β”‚ Large-scale,β”‚
β”‚              β”‚             β”‚              β”‚              β”‚ dynamic     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ GIN          β”‚ Highest     β”‚ Medium       β”‚ Yes          β”‚ Graph class,β”‚
β”‚              β”‚             β”‚              β”‚              β”‚ chemistry   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

**Decision Guide:**

1. **Use GCN if:**
   - Need fast baseline
   - Graph fits in memory
   - Node classification on homophilic graphs
   - Simplicity preferred

2. **Use GAT if:**
   - Need interpretability (attention weights)
   - Graph has heterophily
   - Different neighbor importance
   - Smaller graphs (<100K nodes)

3. **Use GraphSAGE if:**
   - Large graphs (>1M nodes)
   - Inductive learning (new nodes)
   - Production systems
   - Limited memory

4. **Use GIN if:**
   - Graph classification
   - Maximum expressiveness needed
   - Chemical/molecular graphs
   - Small to medium graphs

**Performance Expectations:**

Node Classification (Cora):
- GCN: ~81% accuracy, 0.1s training
- GAT: ~83% accuracy, 0.5s training
- GraphSAGE: ~80% accuracy, 0.2s training

Graph Classification (MUTAG):
- GIN: ~89% accuracy (SOTA among MPNNs)
- GCN: ~85% accuracy
- GraphSAGE: ~84% accuracy

Link Prediction (Cora):
- GAT: Best (attention captures relationships)
- GCN: Good baseline
- GraphSAGE: Scalable option
"""
    
    print(comparison)
    print()


# ============================================================================
# Run Demonstrations
# ============================================================================

if __name__ == "__main__":
    # Set random seed
    torch.manual_seed(42)
    
    # Run demos
    demo_gnn_comparison()
    demo_pooling()
    demo_complexity_analysis()
    print_method_comparison()
    
    print("="*70)
    print("GNN Implementations Complete")
    print("="*70)
    print()
    print("Summary:")
    print("  β€’ GCN: Fast, simple, spectral motivation")
    print("  β€’ GAT: Attention-based, interpretable")
    print("  β€’ GraphSAGE: Scalable, inductive, sampling-based")
    print("  β€’ GIN: Maximally expressive (within WL limit)")
    print("  β€’ Pooling: Top-K and SAG for graph-level tasks")
    print()
    print("All implementations follow message-passing framework:")
    print("  h_v^(k+1) = UPDATE(h_v^(k), AGG({h_u^(k) : u ∈ N(v)}))")
    print()