Making GNN Networks¶
In this example, you will learn how to easily build a full multi-task GNN using any kind of GNN layer. This tutorial uses the architecture defined by the class FullGraphMultiTaskNetwork.
FullGraphMultiTaskNetwork is an architecture that takes node features and (optionally) edge features as input. It applies a pre-MLP to both sets of features, then passes them through a main GNN network, and finally applies graph output NNs to produce the final outputs, possibly at several task levels (graph, node, or edge level).
The network is easy to build via a dictionary of parameters that lets you customize each part of the network.
%load_ext autoreload
%autoreload 2
import torch
from copy import deepcopy
from torch_geometric.data import Data, Batch
from graphium.nn.architectures import FullGraphMultiTaskNetwork
_ = torch.manual_seed(42)
We will first create some simple batched graphs that will be used across the examples. Here, bg is a batch containing 2 graphs with random node and edge features.
in_dim = 5 # Input node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
# Let's create 2 simple pyg graphs.
# start by specifying the edges with edge index
edge_idx1 = torch.tensor([[0, 1, 2],
[1, 2, 3]])
edge_idx2 = torch.tensor([[2, 0, 0, 1],
[0, 1, 2, 0]])
# specify the node features, convention with variable x
x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)
x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)
# specify the edge features in e
e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)
e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)
# make the pyg graph objects with our constructed features
g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)
g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)
# put the two graphs into a Batch graph
bg = Batch.from_data_list([g1, g2])
# The batched graph will show as a single graph with 7 nodes
print(bg)
DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])
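As a quick sanity check, we can verify how the two graphs were merged: the batch vector of the DataBatch assigns each of the 7 nodes to its source graph, and to_data_list() recovers the individual graphs. Both are standard PyTorch Geometric attributes/methods, not specific to this tutorial.
# Check which graph each node belongs to, and recover the original graphs
print(bg.batch)           # tensor([0, 0, 0, 0, 1, 1, 1]): 4 nodes from g1, 3 from g2
print(bg.to_data_list())  # [Data(...), Data(...)], equivalent to [g1, g2]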
Building a network¶
To build the network, we must define the arguments to pass at the different steps:
pre_nn_kwargs: The parameters used by a feed-forward neural network on the input node features, before they are passed to the convolutional layers. See class FeedForwardNN for details on the required parameters. Ignored if set to None.
gnn_kwargs: The parameters used by a feed-forward graph neural network on the features after they have passed through the pre-processing NN. See class FeedForwardGraph for details on the required parameters.
task_heads_kwargs: The parameters used to specify the different task heads for possibly multiple tasks, each with a specified task_level (graph, node, or edge).
graph_output_nn_kwargs: The parameters used by the graph output NNs to process the features after the GNN layers. We need to specify a NN for each task_level occurring in the task_heads_kwargs.
temp_dim_1 = 23
temp_dim_2 = 17
pre_nn_kwargs = {
"in_dim": in_dim,
"out_dim": temp_dim_1,
"hidden_dims": 4,
"depth": 2,
"activation": 'relu',
"last_activation": "none",
"dropout": 0.2
}
gnn_kwargs = {
"in_dim": temp_dim_1,
"out_dim": temp_dim_2,
"hidden_dims": 5,
"depth": 1,
"activation": 'gelu',
"last_activation": None,
"dropout": 0.1,
"normalization": 'layer_norm',
"last_normalization": 'layer_norm',
"residual_type": 'simple',
"virtual_node": None,
"layer_type": 'pyg:gcn',
"layer_kwargs": None
}
task_heads_kwargs = {
"graph-task-1": {
"task_level": 'graph',
"out_dim": 3,
"hidden_dims": 32,
"depth": 2,
"activation": 'relu',
"last_activation": None,
"dropout": 0.1,
"normalization": None,
"last_normalization": None,
"residual_type": "none"
},
"graph-task-2": {
"task_level": 'graph',
"out_dim": 4,
"hidden_dims": 32,
"depth": 2,
"activation": 'relu',
"last_activation": None,
"dropout": 0.1,
"normalization": None,
"last_normalization": None,
"residual_type": "none"
},
"node-task-1": {
"task_level": 'node',
"out_dim": 2,
"hidden_dims": 32,
"depth": 2,
"activation": 'relu',
"last_activation": None,
"dropout": 0.1,
"normalization": None,
"last_normalization": None,
"residual_type": "none"
}
}
graph_output_nn_kwargs = {
"graph": {
"pooling": ['sum'],
"out_dim": temp_dim_2,
"hidden_dims": temp_dim_2,
"depth": 1,
"activation": 'relu',
"last_activation": None,
"dropout": 0.1,
"normalization": None,
"last_normalization": None,
"residual_type": "none"
},
"node": {
"pooling": None,
"out_dim": temp_dim_2,
"hidden_dims": temp_dim_2,
"depth": 1,
"activation": 'relu',
"last_activation": None,
"dropout": 0.1,
"normalization": None,
"last_normalization": None,
"residual_type": "none"
}
}
gnn_net = FullGraphMultiTaskNetwork(
gnn_kwargs=gnn_kwargs,
pre_nn_kwargs=pre_nn_kwargs,
task_heads_kwargs=task_heads_kwargs,
graph_output_nn_kwargs=graph_output_nn_kwargs
)
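Before applying the network, it can be useful to check its size. The snippet below is a quick sanity check using plain PyTorch; it is not required by graphium.
# Count the trainable parameters of the network we just built
n_params = sum(p.numel() for p in gnn_net.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {n_params}")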
Applying the network¶
Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction. The model outputs a dictionary of outputs, one for each task specified in task_heads_kwargs.
graph = deepcopy(bg)
print(graph.feat.shape)
print("\n")
print(gnn_net)
print("\n")
out = gnn_net(graph)
for task in out.keys():
print(task, out[task].shape)
torch.Size([7, 5])

FullGNN
---------
    pre-NN(depth=2, ResidualConnectionNone)
        [FCLayer[5 -> 4 -> 23]

    GNN(depth=1, ResidualConnectionSimple(skip_steps=1))
        GCNConvPyg[23 -> 17]

    Task heads:
    graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)
        [FCLayer[17 -> 32 -> 3]
    graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)
        [FCLayer[17 -> 32 -> 4]
    node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)
        [FCLayer[17 -> 32 -> 2]

graph-task-1 torch.Size([2, 3])
graph-task-2 torch.Size([2, 4])
node-task-1 torch.Size([7, 2])
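Since the forward pass returns one tensor per task, a multi-task training step simply computes one loss per dictionary key and combines them before backpropagating. The sketch below uses random regression targets shaped like the outputs above; the MSE loss and the unweighted sum are purely illustrative choices, not something prescribed by graphium.
# Illustrative multi-task loss with random regression targets
targets = {task: torch.randn_like(pred) for task, pred in out.items()}
loss = sum(torch.nn.functional.mse_loss(out[task], targets[task]) for task in out)
loss.backward()  # gradients flow through every task head and the shared GNN trunk
In practice, each task would typically get its own loss function and weighting.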
Building a network (using additional edge features)¶
We can further use a GNN that takes additional edge features as input. For that, we only need to slightly modify the gnn_kwargs from above. Below, we define the new gnn_edge_kwargs, where we set the layer_type to one of the GNN layers that support edge features (e.g., pyg:gine or pyg:gps) and add the additional parameter in_dim_edges. Optionally, we can configure a preprocessing NN for the edge features via a dictionary pre_nn_edges_kwargs, analogous to pre_nn_kwargs above, or skip this step by setting it to None.
temp_dim_edges = 8
gnn_edge_kwargs = {
"in_dim": temp_dim_1,
"in_dim_edges": temp_dim_edges,
"out_dim": temp_dim_2,
"hidden_dims": 5,
"depth": 1,
"activation": 'gelu',
"last_activation": None,
"dropout": 0.1,
"normalization": 'layer_norm',
"last_normalization": 'layer_norm',
"residual_type": 'simple',
"virtual_node": None,
"layer_type": 'pyg:gine',
"layer_kwargs": None
}
pre_nn_edges_kwargs = {
"in_dim": in_dim_edges,
"out_dim": temp_dim_edges,
"hidden_dims": 4,
"depth": 2,
"activation": 'relu',
"last_activation": "none",
"dropout": 0.2
}
gnn_net_edges = FullGraphMultiTaskNetwork(
gnn_kwargs=gnn_edge_kwargs,
pre_nn_kwargs=pre_nn_kwargs,
pre_nn_edges_kwargs=pre_nn_edges_kwargs,
task_heads_kwargs=task_heads_kwargs,
graph_output_nn_kwargs=graph_output_nn_kwargs
)
Applying the network (using additional edge features)¶
Again, we only need to run the forward pass on the input graphs to get a prediction.
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
print("\n")
print(gnn_net)
print("\n")
out = gnn_net_edges(graph)
for task in out.keys():
print(task, out[task].shape)
torch.Size([7, 5])
torch.Size([7, 13])
graph-task-1 torch.Size([2, 3])
graph-task-2 torch.Size([2, 4])
node-task-1 torch.Size([7, 2])
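At inference time, the usual PyTorch conventions apply: put the network in eval mode so that dropout is disabled, and run the forward pass under torch.no_grad(). A brief sketch:
# Deterministic forward pass without gradient tracking
gnn_net_edges.eval()
with torch.no_grad():
    preds = gnn_net_edges(deepcopy(bg))
print({task: pred.shape for task, pred in preds.items()})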