Using GNN layers
The current library implements multiple state-of-the-art graph neural networks. In this tutorial, you will learn how to use the GCN, GIN, GAT, Gated-GCN and PNA layers in a simple forward pass.
Other layers such as DGN require additional positional encoding to work.
%load_ext autoreload
%autoreload 2
import torch
import dgl
from copy import deepcopy
from graphium.nn.dgl_layers import (
    GCNLayer,
    GINLayer,
    GATLayer,
    GatedGCNLayer,
    PNAConvolutionalLayer,
    PNAMessagePassingLayer,
)
_ = torch.manual_seed(42)
Using backend: pytorch
We will first create some simple batched graphs that will be used across the examples. Here, bg is a batch of 2 graphs containing random node features bg.ndata["h"] and edge features bg.edata["e"].
in_dim = 5 # Input node-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
# Let's create 2 simple graphs. Each graph is defined by two tensors: the source and the destination node IDs of its edges
g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
# We add some node features to the graphs
g1.ndata["h"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)
g2.ndata["h"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)
# We also add some edge features to the graphs
g1.edata["e"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)
g2.edata["e"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)
# Finally we batch the graphs in a way compatible with the DGL library
bg = dgl.batch([g1, g2])
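# We add a self-loop to every node (DGL adds one even where a self-loop already exists)
bg = dgl.add_self_loop(bg)
# The batched graph will show as a single graph with 7 nodes and 14 edges (7 original edges + 7 self-loops)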
print(bg)
Graph(num_nodes=7, num_edges=14, ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)} edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})
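Batching merges the graphs into one larger, disconnected graph by shifting the node IDs of each graph so that they do not collide. The pure-Python sketch below reproduces the node and edge counts printed above; batch_graphs and add_self_loops are hypothetical helpers written for illustration, not DGL or graphium functions. It assumes, consistently with the 14 printed edges, that dgl.add_self_loop adds a self-loop to every node even when one already exists (node 0 of g2).

```python
def batch_graphs(graphs):
    """Merge graphs given as (num_nodes, src, dst) into one disconnected graph.

    Hypothetical helper mimicking the node relabelling performed by dgl.batch.
    """
    src, dst, offset = [], [], 0
    for num_nodes, g_src, g_dst in graphs:
        # Shift node IDs so that each graph occupies its own ID range
        src.extend(s + offset for s in g_src)
        dst.extend(d + offset for d in g_dst)
        offset += num_nodes
    return offset, src, dst


def add_self_loops(num_nodes, src, dst):
    """Append one (i, i) edge per node, even if node i already has a self-loop."""
    loops = list(range(num_nodes))
    return num_nodes, src + loops, dst + loops


g1 = (4, [0, 1, 2], [1, 2, 3])        # same edges as g1 above
g2 = (3, [0, 0, 0, 1], [0, 1, 2, 0])  # same edges as g2 above
num_nodes, src, dst = add_self_loops(*batch_graphs([g1, g2]))
print(num_nodes, len(src))  # 7 14
```

Only the counts are meant to match the DGL output; the sketch makes no claim about how DGL orders edges internally.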
GCN Layer
To use the GCN layer from the Kipf et al. paper, we simply create the layer with the desired attributes and apply it to the graph and its node features.
Kipf, Thomas N., and Max Welling. "Semi-supervised classification with graph convolutional networks." arXiv preprint arXiv:1609.02907 (2016).
# We first need to extract the node features from the graph.
# The GCN method doesn't support edge features, so we ignore them
graph = deepcopy(bg)
h_in = graph.ndata["h"]
# We create the layer
layer = GCNLayer(
    in_dim=in_dim, out_dim=out_dim,
    activation="relu", dropout=.3, normalization="batch_norm").to(float)
# We apply the forward pass on the node features
h_out = layer(graph, h_in)
# 7 is the number of nodes, 5 the number of input features and 11 the number of output features
print(layer)
print(h_in.shape)
print(h_out.shape)
GCNLayer(5 -> 11, activation=relu)
torch.Size([7, 5])
torch.Size([7, 11])
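For intuition, the propagation rule of Kipf et al. is H' = ReLU(D^-1/2 A D^-1/2 H W), where A already contains self-loops and D is its degree matrix. The sketch below implements that rule from scratch in pure Python on a tiny undirected path graph. It illustrates the paper's formula, not graphium's GCNLayer code, and it omits the dropout and batch normalization used above; the function name gcn_layer and the toy inputs are hypothetical.

```python
import math


def gcn_layer(num_nodes, edges, h, weight):
    """One GCN step: ReLU(D^-1/2 A D^-1/2 H W).

    edges: (src, dst) pairs that already include self-loops.
    h: one feature list per node; weight: in_dim x out_dim nested list.
    """
    out_deg = [0] * num_nodes
    in_deg = [0] * num_nodes
    for src, dst in edges:
        out_deg[src] += 1
        in_deg[dst] += 1
    out_dim = len(weight[0])
    # Linear transform H W, computed once per node
    hw = [[sum(x[k] * weight[k][j] for k in range(len(x))) for j in range(out_dim)] for x in h]
    out = [[0.0] * out_dim for _ in range(num_nodes)]
    for src, dst in edges:
        # Symmetric normalisation; equals D^-1/2 A D^-1/2 when the graph is undirected
        norm = 1.0 / math.sqrt(out_deg[src] * in_deg[dst])
        for j in range(out_dim):
            out[dst][j] += norm * hw[src][j]
    return [[max(0.0, v) for v in row] for row in out]  # ReLU


# Undirected path 0 - 1 - 2, stored as edges in both directions, plus self-loops
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 0), (1, 1), (2, 2)]
h = [[1.0], [2.0], [3.0]]
print(gcn_layer(3, edges, h, weight=[[1.0]]))
```

Node 1 has degree 3 and its neighbours have degree 2, so each neighbour contributes with weight 1/sqrt(6) and the self-loop with weight 1/3.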
GIN Layer
To use the GIN layer from the Xu et al. paper, the steps are identical to GCN.
Xu, Keyulu, et al. "How powerful are graph neural networks?." arXiv preprint arXiv:1810.00826 (2018).
graph = deepcopy(bg)
h_in = graph.ndata["h"]
layer = GINLayer(
    in_dim=in_dim, out_dim=out_dim,
    activation="relu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in)
print(layer)
print(h_in.shape)
print(h_out.shape)
GINLayer(5 -> 11, activation=relu)
torch.Size([7, 5])
torch.Size([7, 11])
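In Xu et al., the GIN update feeds (1 + eps) * h_i plus the unnormalized sum of the neighbour features into an MLP. The pure-Python sketch below follows that rule on the edges of g1; the helper names, the value of eps and the tiny two-layer MLP weights are illustrative assumptions, not graphium's GINLayer internals.

```python
def gin_aggregate(num_nodes, edges, h, eps=0.0):
    """Input of the GIN MLP: (1 + eps) * h_i + sum of h_j over in-neighbours j."""
    out = [[(1.0 + eps) * v for v in h[i]] for i in range(num_nodes)]
    for src, dst in edges:
        for k, v in enumerate(h[src]):
            out[dst][k] += v
    return out


def mlp(x, w1, w2):
    """Two-layer perceptron with a ReLU, standing in for GIN's learnable MLP.

    w1: in_dim x hidden_dim, w2: hidden_dim x out_dim (nested lists).
    """
    hidden = [max(0.0, sum(xi * w for xi, w in zip(x, col))) for col in zip(*w1)]
    return [sum(hi * w for hi, w in zip(hidden, col)) for col in zip(*w2)]


# Directed chain 0 -> 1 -> 2 -> 3 (the edges of g1), one feature per node
edges = [(0, 1), (1, 2), (2, 3)]
h = [[1.0], [2.0], [3.0], [4.0]]
agg = gin_aggregate(4, edges, h, eps=0.5)
print(agg)  # [[1.5], [4.0], [6.5], [9.0]]
w1, w2 = [[1.0, -1.0]], [[2.0], [1.0]]
print([mlp(x, w1, w2) for x in agg])  # [[3.0], [8.0], [13.0], [18.0]]
```

Unlike GCN, the neighbour sum is not normalized by degree, so node 0 (no in-neighbours) only keeps its own scaled features.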
GAT Layer
To use the GAT layer from the Velickovic et al. paper, the steps are identical to GCN, but the output dimension is multiplied by the number of heads, since the outputs of the attention heads are concatenated.
Velickovic, Petar, et al. "Graph attention networks." arXiv preprint arXiv:1710.10903 (2017).
graph = deepcopy(bg)
h_in = graph.ndata["h"]
layer = GATLayer(
    in_dim=in_dim, out_dim=out_dim, num_heads=5,
    activation="elu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in)
print(layer)
print(h_in.shape)
print(h_out.shape)
GATLayer(5 -> 11 * 5, activation=elu)
torch.Size([7, 5])
torch.Size([7, 55])
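The 11 * 5 = 55 output features come from running 5 independent attention heads and concatenating their outputs. The sketch below illustrates this with a from-scratch, pure-Python version of the attention of Velickovic et al.: each head projects the features with its own matrix W, scores every edge j -> i with LeakyReLU(a_dst . W h_i + a_src . W h_j), normalizes the scores with a softmax over the incoming edges of i, and sums the projected neighbour features. The random weights and helper names are assumptions for illustration; the ELU activation, dropout and normalization of the real layer are omitted.

```python
import math
import random


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(w, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [dot(row, x) for row in w]


def gat_layer(num_nodes, edges, h, heads):
    """heads: list of (W, a_src, a_dst) tuples, W being out_dim x in_dim.

    Returns the concatenated head outputs and the attention weights of each head.
    """
    out = [[] for _ in range(num_nodes)]
    all_alphas = []
    for w, a_src, a_dst in heads:
        z = [matvec(w, x) for x in h]  # projected features W h
        # Unnormalised attention score of every edge j -> i
        scores = {(j, i): leaky_relu(dot(a_dst, z[i]) + dot(a_src, z[j])) for j, i in edges}
        alphas = {}
        for i in range(num_nodes):
            incoming = [edge for edge in edges if edge[1] == i]
            denom = sum(math.exp(scores[edge]) for edge in incoming)
            for edge in incoming:
                alphas[edge] = math.exp(scores[edge]) / denom  # softmax over neighbours
        for i in range(num_nodes):
            head_out = [0.0] * len(w)
            for (j, dst), alpha in alphas.items():
                if dst == i:
                    head_out = [acc + alpha * v for acc, v in zip(head_out, z[j])]
            out[i].extend(head_out)  # concatenate this head's output
        all_alphas.append(alphas)
    return out, all_alphas


def rand_vec(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


random.seed(0)
in_dim, out_dim, num_heads = 5, 11, 5
edges = [(0, 1), (1, 2), (2, 3)] + [(i, i) for i in range(4)]  # chain with self-loops
h = [rand_vec(in_dim) for _ in range(4)]
heads = [([rand_vec(in_dim) for _ in range(out_dim)], rand_vec(out_dim), rand_vec(out_dim))
         for _ in range(num_heads)]
out, alphas = gat_layer(4, edges, h, heads)
print(len(out), len(out[0]))  # 4 55
```

Concatenation keeps each head's output in its own slice: the first 11 features of a node are exactly what a single-head layer with the first head's weights would produce.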
Gated-GCN Layer
To use the Gated-GCN layer from the Bresson et al. paper, the steps are different, since the layer requires edge features as inputs and outputs new edge features.
Bresson, Xavier, and Thomas Laurent. "Residual gated graph convnets." arXiv preprint arXiv:1711.07553 (2017).
# We first need to extract the node and edge features from the graph.
graph = deepcopy(bg)
h_in = graph.ndata["h"]
e_in = graph.edata["e"]
# We create the layer
layer = GatedGCNLayer(
    in_dim=in_dim, out_dim=out_dim,
    in_dim_edges=in_dim_edges, out_dim_edges=out_dim,
    activation="relu", dropout=.3, normalization="batch_norm").to(float)
# We apply the forward pass on the node and edge features
h_out, e_out = layer(graph, h_in, e_in)
# 7 is the number of nodes, 5 the number of input node features and 11 the number of output node features
# 14 is the number of edges, 13 the number of input edge features and 11 the number of output edge features
print(layer)
print(h_in.shape)
print(h_out.shape)
print(e_in.shape)
print(e_out.shape)
GatedGCNLayer(5 -> 11, activation=relu)
torch.Size([7, 5])
torch.Size([7, 11])
torch.Size([14, 13])
torch.Size([14, 11])
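The new edge features come from the gating mechanism. As a hedged sketch, the pure-Python code below implements a widely used edge-featured variant of the residual gated update: e'_ij = C e_ij + D h_i + E h_j, gates sigmoid(e'_ij), and h'_i = A h_i + sum_j sigmoid(e'_ij) * B h_j / (sum_j sigmoid(e'_ij) + eps). Whether GatedGCNLayer computes exactly this, and where it applies its normalization, activation and residual connections, is an assumption; the matrix names A to E and the random weights are illustrative. The shapes mirror the example above (5 input node features, 13 input edge features, 11 outputs for both).

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(w, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def vec_add(*vecs):
    return [sum(vals) for vals in zip(*vecs)]


def gated_gcn(num_nodes, edges, h, e, A, B, C, D, E, eps=1e-6):
    """Edge-gated update (before normalization, activation and residuals).

    edges: (j, i) = (src, dst) pairs; e: one feature list per edge.
    Returns (new node features, new edge features).
    """
    # New edge features: e'_ij = C e_ij + D h_i + E h_j
    e_new = [vec_add(matvec(C, e[k]), matvec(D, h[i]), matvec(E, h[j]))
             for k, (j, i) in enumerate(edges)]
    gates = [[sigmoid(v) for v in ek] for ek in e_new]
    out_dim = len(A)
    num = [[0.0] * out_dim for _ in range(num_nodes)]
    den = [[0.0] * out_dim for _ in range(num_nodes)]
    for (j, i), gate in zip(edges, gates):
        num[i] = vec_add(num[i], [g * v for g, v in zip(gate, matvec(B, h[j]))])
        den[i] = vec_add(den[i], gate)
    # h'_i = A h_i + (gated sum of B h_j) / (sum of gates + eps)
    h_new = [vec_add(matvec(A, h[i]), [n / (d + eps) for n, d in zip(num[i], den[i])])
             for i in range(num_nodes)]
    return h_new, e_new


def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
edges = [(0, 1), (1, 2), (0, 0), (1, 1), (2, 2)]
h = rand_matrix(3, 5)            # 3 nodes, in_dim = 5
e = rand_matrix(len(edges), 13)  # in_dim_edges = 13
A, B, D, E = (rand_matrix(11, 5) for _ in range(4))  # node projections, out_dim = 11
C = rand_matrix(11, 13)          # edge projection, out_dim_edges = 11
h_out, e_out = gated_gcn(3, edges, h, e, A, B, C, D, E)
print(len(h_out), len(h_out[0]), len(e_out), len(e_out[0]))  # 3 11 5 11
```

Because the gates are normalized by their sum, the neighbour term is a gated weighted average of B h_j, which is why edge features directly control how much each neighbour contributes.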
PNA
PNA is a multi-aggregator method proposed by Corso et al. It supports two types of aggregation: convolutional (PNA-conv) and message passing (PNA-msgpass).
Corso, Gabriele, Luca Cavalleri, Dominique Beaini, Pietro Lio, and Petar Velickovic. "Principal Neighbourhood Aggregation for Graph Nets." arXiv preprint arXiv:2004.05718 (2020). https://arxiv.org/abs/2004.05718
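In the PNA paper, each aggregator is combined with each scaler and the results are concatenated: with the 4 aggregators and 3 scalers used in the examples below, every node gets 4 x 3 = 12 aggregated values per feature. The scalers depend on the node degree d through log(d + 1), normalized by delta, the average of log(d + 1) over the training degrees: amplification multiplies by log(d + 1) / delta and attenuation by delta / log(d + 1). The pure-Python sketch below shows this for a single node with one scalar feature; the helper names, the training degrees and the use of the population standard deviation are assumptions, and the learned layers PNA applies around the aggregation are omitted.

```python
import math
import statistics

AGGREGATORS = {
    "mean": statistics.fmean,
    "max": max,
    "min": min,
    "std": statistics.pstdev,  # population standard deviation
}


def degree_scalers(degree, delta):
    """Scalers (log(d + 1) / delta) ** alpha for alpha = 0, 1 and -1."""
    ratio = math.log(degree + 1) / delta
    return {"identity": 1.0, "amplification": ratio, "attenuation": 1.0 / ratio}


def pna_aggregate(neighbour_values, delta, aggregators, scalers):
    """Concatenate every (scaler, aggregator) combination for one node."""
    s = degree_scalers(len(neighbour_values), delta)
    return [s[name] * AGGREGATORS[agg](neighbour_values) for name in scalers for agg in aggregators]


# delta is the mean of log(d + 1) over the node degrees seen during training
train_degrees = [1, 2, 3, 3, 4]
delta = statistics.fmean([math.log(d + 1) for d in train_degrees])
out = pna_aggregate([1.0, 2.0, 3.0], delta,
                    aggregators=["mean", "max", "min", "std"],
                    scalers=["identity", "amplification", "attenuation"])
print(len(out))  # 12
```

Here the node has degree 3, which is above the training average, so amplification enlarges the aggregated values and attenuation shrinks them by the same factor.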
PNA-conv
First, let's focus on PNA-conv. In this case, it is used exactly like the GCN and GIN layers. Although not shown in this example, PNA-conv also supports edge features by concatenating them to the node features during convolution.
graph = deepcopy(bg)
h_in = graph.ndata["h"]
# We create the layer, specifying the aggregators and scalers
layer = PNAConvolutionalLayer(
    in_dim=in_dim, out_dim=out_dim,
    aggregators=["mean", "max", "min", "std"],
    scalers=["identity", "amplification", "attenuation"],
    activation="relu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in)
print(layer)
print(h_in.shape)
print(h_out.shape)
PNAConvolutionalLayer(5 -> 11, activation=relu)
torch.Size([7, 5])
torch.Size([7, 11])
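PNA-conv can use edge features by concatenating them to the node features during convolution. As a hypothetical illustration of that idea (exactly where graphium performs the concatenation is an assumption), the sketch below builds, for each edge, the source node features with the edge features appended, and then applies several aggregators to each feature dimension. Scalers and learned projections are left out, and the helper names are illustrative.

```python
import statistics


def conv_messages(edges, h, e=None):
    """Group the incoming messages of each node.

    The message of edge (src -> dst) is h[src], with the edge features appended
    when e is given (hypothetical sketch of the concatenation).
    """
    inbox = {}
    for k, (src, dst) in enumerate(edges):
        message = list(h[src]) + (list(e[k]) if e is not None else [])
        inbox.setdefault(dst, []).append(message)
    return inbox


def multi_aggregate(inbox, aggregators):
    """Apply each aggregator to each feature dimension, then concatenate."""
    return {
        dst: [agg(column) for agg in aggregators for column in zip(*messages)]
        for dst, messages in inbox.items()
    }


edges = [(0, 1), (2, 1)]
h = [[1.0], [0.0], [3.0]]
e = [[10.0], [20.0]]
aggregators = [statistics.fmean, max]
print(multi_aggregate(conv_messages(edges, h), aggregators))     # {1: [2.0, 3.0]}
print(multi_aggregate(conv_messages(edges, h, e), aggregators))  # {1: [2.0, 15.0, 3.0, 20.0]}
```

Appending the edge features widens every message, so the aggregated vector grows with both the number of aggregators and the edge-feature dimension.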
PNA-msgpass
PNA message passing is typically more powerful than the convolutional variant; it accepts edge features as inputs but does not output updated edge features. Its usage is very similar to PNA-conv. Here, we also show how to specify the input edge-feature dimension and pass the edge features to the layer.
graph = deepcopy(bg)
h_in = graph.ndata["h"]
e_in = graph.edata["e"]
# We create the layer, specifying the aggregators, the scalers and the input edge-feature dimension
layer = PNAMessagePassingLayer(
    in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,
    aggregators=["mean", "max", "min", "std"],
    scalers=["identity", "amplification", "attenuation"],
    activation="relu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in, e_in)
print(layer)
print(h_in.shape)
print(h_out.shape)
PNAMessagePassingLayer(5 -> 11, activation=relu)
torch.Size([7, 5])
torch.Size([7, 11])
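In the message-passing formulation of Corso et al., a function M builds each message from the receiving node, the edge and the sending node, and a function U combines the node's own features with the aggregated messages. Only node features come out, which matches the single h_out returned above. In the pure-Python sketch below, M and U are replaced by a hypothetical toy function (a sum) and scalers are omitted; it assumes every node has at least one incoming edge, which the self-loops added earlier guarantee.

```python
import statistics


def pna_message_passing(num_nodes, edges, h, e, message_fn, update_fn, aggregators):
    """Returns new node features only; edge features are consumed, not updated.

    Assumes every node has at least one incoming edge (e.g. a self-loop).
    """
    inbox = [[] for _ in range(num_nodes)]
    for k, (src, dst) in enumerate(edges):
        # M(h_dst, e_src->dst, h_src), applied here to the concatenated features
        inbox[dst].append(message_fn(h[dst] + e[k] + h[src]))
    h_new = []
    for i in range(num_nodes):
        aggregated = [agg(column) for agg in aggregators for column in zip(*inbox[i])]
        h_new.append(update_fn(h[i] + aggregated))  # U(h_i, aggregated messages)
    return h_new


def toy_sum(values):
    """Hypothetical stand-in for the learned M and U networks."""
    return [sum(values)]


edges = [(0, 1), (0, 0), (1, 1)]
h = [[1.0], [2.0]]
e = [[0.5], [0.5], [0.5]]
print(pna_message_passing(2, edges, h, e, toy_sum, toy_sum, [statistics.fmean, max]))  # [[6.0], [10.5]]
```

Compared with the convolutional variant, each message here already depends on the receiving node, which is one reason the message-passing form can be more expressive.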