Using GNN layers
Graphium implements several state-of-the-art graph neural network layers. In this tutorial, you will learn how to use the GCN, GIN, GINE, GPS, Gated-GCN and PNA layers in a simple forward pass.
%load_ext autoreload
%autoreload 2
import torch
import torch_geometric as pyg
from torch_geometric.data import Data, Batch
from copy import deepcopy
from graphium.nn.pyg_layers import (
    GCNConvPyg,
    GINConvPyg,
    GatedGCNPyg,
    GINEConvPyg,
    GPSLayerPyg,
    PNAMessagePassingPyg,
)
_ = torch.manual_seed(42)
We first create some simple batched graphs that are used across the examples. Here, bg is a batch of 2 graphs with random node and edge features.
in_dim = 5 # Input node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
out_dim_edges = 15 # Desired output edge-feature dimensions
# Let's create 2 simple pyg graphs.
# start by specifying the edges with edge index
edge_idx1 = torch.tensor([[0, 1, 2],
                          [1, 2, 3]])
edge_idx2 = torch.tensor([[2, 0, 0, 1],
                          [0, 1, 2, 0]])
# specify the node features (named x by convention)
x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)
x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)
# specify the edge features in e
e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)
e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)
# make the pyg graph objects with our constructed features
g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)
g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)
# put the two graphs into a Batch graph
bg = Batch.from_data_list([g1, g2])
# The batch appears as a single disconnected graph with 7 nodes (4 + 3) and 7 edges (3 + 4)
print(bg)
DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])
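To see why the batch shows up as one graph with 7 nodes and 7 edges, here is a minimal pure-Python sketch of the collation step. It mimics what Batch.from_data_list does with the edge indices: the graphs are stacked into one disconnected graph, and each graph's edge indices are shifted by the number of nodes placed before it. The batch list records which graph every node came from, and ptr holds the cumulative node counts. The helper collate_sketch is hypothetical and only illustrative. PyG's real implementation also concatenates features and handles many other attribute types.

```python
def collate_sketch(edge_lists, num_nodes):
    """Hypothetical helper: merge graphs into one disjoint graph, PyG-style.

    edge_lists: one (sources, targets) pair of lists per graph.
    num_nodes: the number of nodes in each graph.
    """
    src, dst, batch, ptr = [], [], [], [0]
    for graph_id, ((s, d), n) in enumerate(zip(edge_lists, num_nodes)):
        offset = ptr[-1]  # nodes already placed by earlier graphs
        src += [i + offset for i in s]
        dst += [j + offset for j in d]
        batch += [graph_id] * n  # graph id of every node of this graph
        ptr.append(offset + n)  # cumulative node count
    return [src, dst], batch, ptr


edge_index, batch, ptr = collate_sketch(
    [([0, 1, 2], [1, 2, 3]), ([2, 0, 0, 1], [0, 1, 2, 0])],
    num_nodes=[4, 3],
)
print(edge_index)  # [[0, 1, 2, 6, 4, 4, 5], [1, 2, 3, 4, 5, 6, 4]]
print(batch)  # [0, 0, 0, 0, 1, 1, 1]
print(ptr)  # [0, 4, 7]
```

The indices 6, 4, 4, 5 in the second half of the source row are the second graph's nodes 2, 0, 0, 1 shifted by the 4 nodes of the first graph.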
GCN Layer
To use the GCN layer from Kipf and Welling, the steps are simple: we create the layer with the desired attributes and apply it to the graph.
Kipf, Thomas N., and Max Welling. "Semi-supervised classification with graph convolutional networks." arXiv preprint arXiv:1609.02907 (2016).
# The GCN method doesn't support edge features, so we ignore them
graph = deepcopy(bg)
print(graph.feat.shape)
# We create the layer
layer = GCNConvPyg(
    in_dim=in_dim, out_dim=out_dim,
    activation="relu", dropout=.3, normalization="batch_norm")
# We apply the forward pass to the node features
graph = layer(graph)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
print(layer)
print(graph.feat.shape)
torch.Size([7, 5])
GCNConvPyg(5 -> 11, activation=relu)
torch.Size([7, 11])
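To give some intuition for what a GCN layer computes, the sketch below applies the Kipf and Welling propagation rule to a single scalar feature per node. After self-loops are added, each node sums its neighbours' features, each weighted by 1/sqrt(d_i * d_j), where d is the degree including the self-loop. This is a hedged illustration, not GCNConvPyg's code. The learned weight matrix, activation, dropout and batch norm are left out, and edges are treated as undirected for simplicity. The name gcn_propagate is hypothetical.

```python
import math


def gcn_propagate(x, edges, num_nodes):
    """Hypothetical helper: one GCN propagation step on scalar features.

    Computes h_i = sum over j in (neighbours of i plus i itself) of
    x_j / sqrt(d_i * d_j), where d is the degree after adding self-loops.
    No weights or activation are applied.
    """
    # self-loops first, then add both directions of every edge (undirected)
    nbrs = {i: {i} for i in range(num_nodes)}
    for u, v in edges:
        nbrs[u].add(v)
        nbrs[v].add(u)
    deg = {i: len(nbrs[i]) for i in range(num_nodes)}
    return [
        sum(x[j] / math.sqrt(deg[i] * deg[j]) for j in nbrs[i])
        for i in range(num_nodes)
    ]


# path graph 0-1-2-3 (the first graph of the batch), all features set to 1
h = gcn_propagate([1.0, 1.0, 1.0, 1.0], [(0, 1), (1, 2), (2, 3)], num_nodes=4)
print([round(v, 4) for v in h])  # [0.9082, 1.0749, 1.0749, 0.9082]
```

The symmetric normalisation keeps high-degree nodes from dominating. With constant inputs, the two inner nodes end up only slightly larger than the two end nodes.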
GIN Layer
To use the GIN layer from Xu et al., the steps are identical to those for GCN.
Xu, Keyulu, et al. "How powerful are graph neural networks?" arXiv preprint arXiv:1810.00826 (2018).
graph = deepcopy(bg)
print(graph.feat.shape)
# We create the layer
layer = GINConvPyg(
    in_dim=in_dim, out_dim=out_dim,
    activation="relu", dropout=.3, normalization="batch_norm")
# We apply the forward pass to the node features
graph = layer(graph)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
print(layer)
print(graph.feat.shape)
torch.Size([7, 5])
GINConvPyg(5 -> 11, activation=relu)
torch.Size([7, 11])
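The GIN update of Xu et al. computes MLP((1 + eps) * x_i + sum of neighbour features). The sketch below shows only the aggregation inside the MLP, on scalar features, next to a plain mean aggregator for comparison. The helpers gin_aggregate and mean_neighbours are hypothetical. Edges are (source, target) pairs with messages flowing from source to target, as in PyG's edge_index.

```python
def gin_aggregate(x, edges, eps=0.0):
    """Hypothetical helper: (1 + eps) * x_i + sum of x_j over incoming edges j -> i."""
    out = [(1.0 + eps) * xi for xi in x]
    for src, dst in edges:
        out[dst] += x[src]
    return out


def mean_neighbours(x, edges):
    """Hypothetical helper: mean of incoming neighbour features (0.0 if none)."""
    sums, counts = [0.0] * len(x), [0] * len(x)
    for src, dst in edges:
        sums[dst] += x[src]
        counts[dst] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


x = [1.0] * 5
edges = [(1, 0), (2, 0), (4, 3)]  # node 0 has two neighbours, node 3 has one
print(gin_aggregate(x, edges))  # [3.0, 1.0, 1.0, 2.0, 1.0]
print(mean_neighbours(x, edges))  # [1.0, 0.0, 0.0, 1.0, 0.0]
```

Nodes 0 and 3 receive identical neighbour features, so the mean cannot tell them apart. The sum, by contrast, reflects that node 0 has two neighbours and node 3 only one. This kind of distinction is what motivates sum aggregation in Xu et al.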
GINE Layer
To use the GINE layer from Hu et al., we must provide edge features as inputs in addition to node features, so the layer also needs their input dimension, in_dim_edges.
Hu, Weihua, et al. "Strategies for Pre-training Graph Neural Networks." arXiv preprint arXiv:1905.12265 (2019).
# The GINE method uses edge features, so we also pass their input dimension
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
# We create the layer
layer = GINEConvPyg(
    in_dim=in_dim, out_dim=out_dim,
    in_dim_edges=in_dim_edges,
    activation="relu", dropout=.3, normalization="batch_norm")
# We apply the forward pass, which uses both node and edge features
graph = layer(graph)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
# 7 is the number of edges and 13 the number of input edge features
print(layer)
print(graph.feat.shape)
torch.Size([7, 5])
torch.Size([7, 13])
GINEConvPyg(5 -> 11, activation=relu)
torch.Size([7, 11])
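GINE, from Hu et al., changes the GIN message so that each neighbour contributes ReLU(x_j + e_ji) rather than x_j. This folds the edge feature into the message. The scalar sketch below assumes the edge features already have the node-feature size. In the example above they do not (13 versus 5), so a real layer has to project them first. PyG's GINEConv, for instance, adds a linear map when given edge_dim, and we assume that in_dim_edges serves a similar purpose here. The helper gine_aggregate is hypothetical.

```python
def gine_aggregate(x, edges, edge_feat, eps=0.0):
    """Hypothetical helper: (1 + eps) * x_i + sum of ReLU(x_j + e_ji) over edges j -> i.

    Assumes edge features already have the same size as node features
    (a single scalar here).
    """
    out = [(1.0 + eps) * xi for xi in x]
    for (src, dst), e in zip(edges, edge_feat):
        out[dst] += max(0.0, x[src] + e)  # ReLU(x_j + e_ji)
    return out


x = [1.0, -2.0, 0.5]
edges = [(1, 0), (2, 0)]
edge_feat = [1.0, 0.5]
print(gine_aggregate(x, edges, edge_feat))  # [2.0, -2.0, 0.5]
```

The ReLU clips the message from node 1 to zero because -2.0 + 1.0 is negative. As a result, only node 2 contributes to node 0's update.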
GPS Layer
To use the GPS layer from Rampášek et al., we also need to provide edge features as inputs. GPS is a hybrid approach that uses a message-passing GNN and a transformer-style attention module in conjunction. Therefore, we further need to specify the GNN type (mpnn_type) and the attention type (attn_type) used in the layer.
Rampášek, Ladislav, et al. "Recipe for a General, Powerful, Scalable Graph Transformer." arXiv preprint arXiv:2205.12454 (2022).
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
# We create the layer
layer = GPSLayerPyg(
    in_dim=in_dim, out_dim=out_dim,
    in_dim_edges=in_dim_edges,
    mpnn_type="pyg:gine", attn_type="full-attention",
    activation="relu", dropout=.3, normalization="batch_norm")
# We apply the forward pass, which uses both node and edge features
graph = layer(graph)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
# 7 is the number of edges and 13 the number of input edge features; the edge-feature shape is unchanged
print(layer)
print(graph.feat.shape)
print(graph.edge_feat.shape)
torch.Size([7, 5])
torch.Size([7, 13])
GPSLayerPyg(5 -> 11, activation=relu)
torch.Size([7, 11])
torch.Size([7, 13])
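In the GPS recipe of Rampášek et al., each layer runs two branches side by side. One is a local message-passing layer (here GINE, selected with mpnn_type="pyg:gine") and the other is a global attention block. Their outputs are summed and passed through an MLP. The sketch below illustrates only the global part: single-head dot-product self-attention with identity query, key and value projections. Each node attends to every node of its own graph, but not to nodes of other graphs in the batch. Masking by the batch vector is our assumption of sensible behaviour for "full-attention" on batched graphs, not a statement about GPSLayerPyg's internals. The helper full_attention is hypothetical.

```python
import math


def full_attention(x, batch):
    """Hypothetical helper: single-head dot-product self-attention.

    Query, key and value projections are the identity, and each node attends
    only to nodes with the same batch id (i.e. the same graph).
    x: list of feature vectors; batch: graph id of every node.
    """
    d = len(x[0])
    out = []
    for i, xi in enumerate(x):
        same = [j for j in range(len(x)) if batch[j] == batch[i]]
        # scaled dot-product scores between node i and every node of its graph
        scores = [sum(a * b for a, b in zip(xi, x[j])) / math.sqrt(d) for j in same]
        top = max(scores)  # subtract the max before exp for numerical stability
        w = [math.exp(s - top) for s in scores]
        z = sum(w)
        # softmax-weighted average of the value vectors
        out.append([sum(wk * x[j][c] for wk, j in zip(w, same)) / z for c in range(d)])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
batch = [0, 0, 1]  # nodes 0 and 1 form graph 0, node 2 is graph 1
h = full_attention(x, batch)
print([round(v, 3) for v in h[0]])  # [0.67, 0.33]
print(h[2])  # [1.0, 1.0]
```

Node 2 is alone in its graph, so it can only attend to itself, and its features come back unchanged. Node 0 weights itself more heavily than node 1 because its own dot-product score is higher.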
Gated-GCN Layer
To use the Gated-GCN layer from Bresson and Laurent, the steps differ: the layer not only requires edge features as inputs but also outputs new edge features. Therefore, we further need to specify the number of output edge features, out_dim_edges.
Bresson, Xavier, and Thomas Laurent. "Residual gated graph convnets." arXiv preprint arXiv:1711.07553 (2017).
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
# We create the layer
layer = GatedGCNPyg(
    in_dim=in_dim, out_dim=out_dim,
    in_dim_edges=in_dim_edges, out_dim_edges=out_dim_edges,
    activation="relu", dropout=.3, normalization="batch_norm")
# We apply the forward pass, which updates both node and edge features
graph = layer(graph)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
# 7 is the number of edges, 13 the number of input edge features and 15 the number of output edge features
print(layer)
print(graph.feat.shape)
print(graph.edge_feat.shape)
torch.Size([7, 5])
torch.Size([7, 13])
GatedGCNPyg()
torch.Size([7, 11])
torch.Size([7, 15])
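The sketch below shows, on scalar features, why a Gated-GCN layer produces new edge features. Each edge first gets an updated feature computed from its own feature and its two endpoints. A sigmoid of that value becomes a gate, and the gates are normalised over each node's incoming edges. The node update then adds the gated neighbour messages to a self term. This follows the edge-gated formulation commonly used for Gated-GCN with edge features. The scalars A, B, C, U and V are hypothetical stand-ins for learned weight matrices, and the residual connection, activation, dropout and batch norm are omitted. The helper gated_gcn_sketch is hypothetical, not GatedGCNPyg's code.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_gcn_sketch(h, edges, e, A=1.0, B=1.0, C=1.0, U=1.0, V=1.0, eps=1e-6):
    """Hypothetical scalar Gated-GCN step; returns (new node feats, new edge feats).

    edges are directed (j, i) pairs: messages flow from node j to node i.
    A, B, C, U, V are illustrative stand-ins for learned weights.
    """
    # 1) edge update: each edge mixes its own feature with both endpoints
    e_new = [C * ek + A * h[i] + B * h[j] for (j, i), ek in zip(edges, e)]
    gates = [sigmoid(v) for v in e_new]
    # 2) normalise the gates over the incoming edges of each target node
    denom = [eps] * len(h)
    for (j, i), g in zip(edges, gates):
        denom[i] += g
    # 3) node update: self term plus gated sum of neighbour messages
    h_new = [U * hi for hi in h]
    for (j, i), g in zip(edges, gates):
        h_new[i] += g / denom[i] * V * h[j]
    return h_new, e_new


h = [1.0, 2.0, 3.0]
edges = [(1, 0), (2, 0)]  # node 0 receives messages from nodes 1 and 2
h_new, e_new = gated_gcn_sketch(h, edges, e=[0.0, 0.0])
print(e_new)  # [3.0, 4.0]
print(round(h_new[0], 2))  # 3.51
```

Node 0's update is its self term plus a gate-weighted average of its neighbours' features, so it lands between 3.0 and 4.0. Nodes 1 and 2 have no incoming edges and keep only their self term. The new edge features are a by-product of computing the gates, which is why a layer of this kind can return them.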
PNA
PNA is a multi-aggregator method proposed by Corso et al. It comes in two variants: a convolutional one (PNA-conv) and a message-passing one (PNA-msgpass). Here, we use PNA-msgpass, which is typically the more powerful of the two. It accepts edge features as inputs but does not output new edge features. We also need to specify the aggregators and scalers specific to this layer.
Corso, Gabriele, et al. "Principal Neighbourhood Aggregation for Graph Nets." arXiv preprint arXiv:2004.05718 (2020).
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
# We create the layer, and need to specify the aggregators and scalers
layer = PNAMessagePassingPyg(
    in_dim=in_dim, out_dim=out_dim,
    in_dim_edges=in_dim_edges,
    aggregators=["mean", "max", "min", "std"],
    scalers=["identity", "amplification", "attenuation"],
    activation="relu", dropout=.3, normalization="batch_norm")
graph = layer(graph)
print(layer)
print(graph.feat.shape)
torch.Size([7, 5])
torch.Size([7, 13])
PNAMessagePassingPyg()
torch.Size([7, 11])
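The aggregators and scalers passed to PNAMessagePassingPyg combine multiplicatively. Each of the 4 aggregators is computed over a node's incoming messages and then multiplied by each of the 3 degree scalers, giving 12 values per feature. Following Corso et al., the scalers are (log(d + 1) / delta) ** alpha, with alpha = 0 (identity), 1 (amplification) and -1 (attenuation). Here d is the node degree, and delta is the average of log(d + 1) over the training graphs. The scalar sketch below assumes delta is given and that the node has at least one message. The helper pna_aggregate is hypothetical. In the paper, the concatenated results are then mixed by a learned transformation.

```python
import math


def pna_aggregate(msgs, delta):
    """Hypothetical helper: 4 aggregators x 3 scalers for one node.

    msgs: scalar messages from the node's neighbours (at least one).
    delta: average log(degree + 1) over the training graphs (assumed given).
    Returns 12 values, grouped by scaler: identity, amplification, attenuation.
    """
    n = len(msgs)  # the node degree d
    mean = sum(msgs) / n
    std = math.sqrt(sum((m - mean) ** 2 for m in msgs) / n)  # population std
    aggs = [mean, max(msgs), min(msgs), std]
    s = math.log(n + 1) / delta
    scalers = [1.0, s, 1.0 / s]  # alpha = 0, 1, -1
    return [a * sc for sc in scalers for a in aggs]


out = pna_aggregate([1.0, 3.0], delta=math.log(3))
print(out)  # [2.0, 3.0, 1.0, 1.0, 2.0, 3.0, 1.0, 1.0, 2.0, 3.0, 1.0, 1.0]
```

With delta = log(3) and two messages, the scaler equals 1, so all three blocks coincide. A smaller delta, such as log(2), would make the amplified block larger and the attenuated block smaller than the identity block.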