Making GNN Networks
In this example, you will learn how to build a full GNN network using any kind of GNN layer. This tutorial uses the architecture defined by the class FullDGLNetwork, which is compatible with any layer that inherits from BaseDGLLayer.
FullDGLNetwork is an architecture that takes node features and (optionally) edge features as input. It applies a pre-MLP to both sets of features, passes the result into a main GNN network, and finally applies a post-MLP to produce the final output (either node predictions or graph property predictions).
The network is easy to build via dictionaries of parameters that let you customize each part of the network.
%load_ext autoreload
%autoreload 2
import torch
import dgl
from copy import deepcopy
from graphium.nn.dgl_layers import PNAMessagePassingLayer
from graphium.nn.architectures import FullDGLNetwork
_ = torch.manual_seed(42)
Using backend: pytorch
We will first create some simple batched graphs that will be used across the examples. Here, bg is a batch containing 2 graphs with random node features.
in_dim = 5 # Input node-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
# Let's create 2 simple graphs. Here the tensors represent the connectivity between nodes
g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
# We add some node features to the graphs
g1.ndata["feat"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)
g2.ndata["feat"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)
# We also add some edge features to the graphs
g1.edata["feat"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)
g2.edata["feat"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)
# Finally we batch the graphs in a way compatible with the DGL library
bg = dgl.batch([g1, g2])
bg = dgl.add_self_loop(bg)
# The batched graph will show as a single graph with 7 nodes
print(bg)
Graph(num_nodes=7, num_edges=14, ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float64)} edata_schemes={'feat': Scheme(shape=(13,), dtype=torch.float64)})
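The node and edge counts in the printed graph can be checked by hand from the two edge lists. The sketch below is plain Python and does not call DGL; the helper names are hypothetical. It assumes that the number of nodes is inferred as the highest node id plus one, and that adding self-loops appends one edge per node even where a self-loop already exists (as for node 0 of g2).

```python
# Edge lists copied from the tutorial: (source nodes, destination nodes)
g1_edges = ([0, 1, 2], [1, 2, 3])
g2_edges = ([0, 0, 0, 1], [0, 1, 2, 0])


def count_nodes(edges):
    """Number of nodes implied by an edge list: highest node id + 1."""
    src, dst = edges
    return max(src + dst) + 1


def count_edges(edges):
    """Number of edges: one per (source, destination) pair."""
    return len(edges[0])


num_nodes = count_nodes(g1_edges) + count_nodes(g2_edges)
num_edges_before = count_edges(g1_edges) + count_edges(g2_edges)
# Assumption: one self-loop is appended per node, duplicates included
num_edges_after = num_edges_before + num_nodes

print(num_nodes, num_edges_before, num_edges_after)  # prints: 7 7 14
```

The result agrees with num_nodes=7 and num_edges=14 in the output above.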
Building the network
To build the network, we must define the arguments to pass at the different steps:
pre_nn_kwargs: The parameters used by a feed-forward neural network on the input node features, before they are passed to the convolutional layers. See class FeedForwardNN for details on the required parameters. Will be ignored if set to None.
gnn_kwargs: The parameters used by a feed-forward graph neural network on the features after they have passed through the pre-processing neural network. See class FeedForwardDGL for details on the required parameters.
graph_output_nn_kwargs: The parameters used by a feed-forward neural network on the features after the GNN layers. See class FeedForwardNN for details on the required parameters. Will be ignored if set to None.
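Because the three stages run one after another, the out_dim of each stage has to match the in_dim of the next. The following is a minimal sketch of that consistency check in plain Python, using the same dimension values as the configuration in this tutorial. The helpers layer_dims and check_chain are hypothetical and not part of graphium, and the check assumes a single pooling that leaves the feature width unchanged.

```python
def layer_dims(kwargs):
    """Full list of layer widths: input, hidden layers, output."""
    return [kwargs["in_dim"], *kwargs["hidden_dims"], kwargs["out_dim"]]


def check_chain(*stages):
    """Raise if one stage's out_dim differs from the next stage's in_dim."""
    for prev, nxt in zip(stages, stages[1:]):
        if prev["out_dim"] != nxt["in_dim"]:
            raise ValueError(
                f"out_dim {prev['out_dim']} != in_dim {nxt['in_dim']}"
            )


# Only the dimension-related keys are reproduced here
pre = {"in_dim": 5, "hidden_dims": [4, 4, 4], "out_dim": 23}
gnn = {"in_dim": 23, "hidden_dims": [5] * 6, "out_dim": 17}
post = {"in_dim": 17, "hidden_dims": [6, 6], "out_dim": 11}

check_chain(pre, gnn, post)
for name, stage in [("pre-NN", pre), ("GNN", gnn), ("post-NN", post)]:
    dims = layer_dims(stage)
    print(name, "depth =", len(dims) - 1, ":", " -> ".join(map(str, dims)))
```

Each stage has one more layer than it has hidden dimensions, which gives depths of 4, 7 and 3 for these values.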
temp_dim_1 = 23
temp_dim_2 = 17
pre_nn_kwargs = {
    "in_dim": in_dim,
    "out_dim": temp_dim_1,
    "hidden_dims": [4, 4, 4],
    "activation": "relu",
    "last_activation": "none",
    "batch_norm": True,
    "dropout": 0.2,
}
graph_output_nn_kwargs = {
    "in_dim": temp_dim_2,
    "out_dim": out_dim,
    "hidden_dims": [6, 6],
    "activation": "relu",
    "last_activation": "sigmoid",
    "batch_norm": False,
    "dropout": 0.0,
}
layer_kwargs = {
    "aggregators": ["mean", "max", "sum"],
    "scalers": ["identity", "amplification"],
}
gnn_kwargs = {
    "in_dim": temp_dim_1,
    "out_dim": temp_dim_2,
    "hidden_dims": [5, 5, 5, 5, 5, 5],
    "residual_type": "densenet",
    "residual_skip_steps": 2,
    "layer_type": PNAMessagePassingLayer,
    "pooling": ["sum"],
    "activation": "relu",
    "last_activation": "none",
    "batch_norm": False,
    "dropout": 0.2,
    "in_dim_edges": in_dim_edges,
    "layer_kwargs": layer_kwargs,
}
gnn_net = FullDGLNetwork(
    pre_nn_kwargs=pre_nn_kwargs,
    gnn_kwargs=gnn_kwargs,
    graph_output_nn_kwargs=graph_output_nn_kwargs,
).to(float)
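To give an intuition for the "aggregators" and "scalers" entries of layer_kwargs, here is a rough plain-Python sketch of how a PNA-style layer might combine the messages arriving at one node. It follows the general definition from the PNA method (amplification scales by log(degree + 1) divided by a dataset-wide average) and is not graphium's implementation. The function pna_aggregate and the avg_log_degree argument are hypothetical names.

```python
import math


def pna_aggregate(messages, avg_log_degree):
    """Combine one node's incoming messages (a list of equal-length vectors)."""
    degree = len(messages)
    columns = list(zip(*messages))  # one tuple per feature dimension
    aggregated = {
        "mean": [sum(c) / degree for c in columns],
        "max": [max(c) for c in columns],
        "sum": [sum(c) for c in columns],
    }
    scalers = {
        "identity": 1.0,
        # Grows with the node degree, normalised by a dataset-wide average
        "amplification": math.log(degree + 1) / avg_log_degree,
    }
    out = []
    for scale in scalers.values():
        for values in aggregated.values():
            out.extend(scale * v for v in values)
    return out


messages = [[1.0, 2.0], [3.0, 0.0], [5.0, 4.0]]  # 3 neighbours, 2 features each
combined = pna_aggregate(messages, avg_log_degree=math.log(3))
print(len(combined))  # prints: 12 (3 aggregators * 2 scalers * 2 features)
```

Under these assumptions, every aggregator is paired with every scaler, so the choice of 3 aggregators and 2 scalers yields 6 aggregated copies of the features before the layer's own transformation.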
Applying the network
Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.
The network will handle the node and edge features depending on its parameters and layer type.
Choosing between graph property prediction and node property prediction depends on the parameter given by gnn_kwargs["pooling"].
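As an illustration of what a "sum" pooling does to the shapes, the sketch below reduces per-node feature vectors to one vector per graph. It is plain Python rather than the graphium or DGL implementation, and sum_pool and nodes_per_graph are hypothetical names.

```python
def sum_pool(node_feats, nodes_per_graph):
    """Sum node feature vectors graph by graph.

    node_feats: list of per-node vectors, graphs stored one after another.
    nodes_per_graph: number of nodes in each graph of the batch.
    """
    pooled, start = [], 0
    for n in nodes_per_graph:
        block = node_feats[start:start + n]
        pooled.append([sum(col) for col in zip(*block)])
        start += n
    return pooled


node_feats = [[1.0, 0.0], [2.0, 1.0], [3.0, 1.0], [4.0, 2.0]]

# Two graphs with 3 and 1 nodes: one output row per graph
print(sum_pool(node_feats, [3, 1]))  # prints: [[6.0, 2.0], [4.0, 2.0]]
# All nodes treated as a single graph: one output row in total
print(sum_pool(node_feats, [4]))     # prints: [[10.0, 4.0]]
```

The number of output rows follows the number of graphs the batch is known to contain. If the batch boundaries are not available at pooling time, all nodes are pooled together into a single row; whether dgl.add_self_loop keeps those boundaries is worth checking in the DGL documentation.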
graph = deepcopy(bg)
h_in = graph.ndata["feat"]
h_out = gnn_net(graph)
print(h_in.shape)
print(h_out.shape)
print("\n")
print(gnn_net)
torch.Size([7, 5])
torch.Size([1, 11])


DGL_GNN
---------
    pre-NN(depth=4, ResidualConnectionNone)
        [FCLayer[5 -> 4 -> 4 -> 4 -> 23]

    GNN(depth=7, ResidualConnectionDenseNet(skip_steps=2))
        PNAMessagePassingLayer[23 -> 5 -> 5 -> 5 -> 5 -> 5 -> 5 -> 17]
        -> Pooling(['sum']) -> FCLayer(17 -> 17, activation=None)

    post-NN(depth=3, ResidualConnectionNone)
        [FCLayer[17 -> 6 -> 6 -> 11]