Making GNN Networks
In this example, you will learn how to easily build a full multi-task GNN using any kind of GNN layer. This tutorial uses the architecture defined by the class FullGraphMultiTaskNetwork.
FullGraphMultiTaskNetwork is an architecture that takes node features and, optionally, edge features as input. It applies a pre-MLP to each set of features, passes the result through a main GNN, and finally applies graph output NNs to produce the final outputs at one or more task levels (graph, node, or edge).
The network is built from dictionaries of parameters that let you customize each part of the network.
%load_ext autoreload
%autoreload 2
import torch
from copy import deepcopy
from torch_geometric.data import Data, Batch
from graphium.nn.architectures import FullGraphMultiTaskNetwork
_ = torch.manual_seed(42)
We will first create some simple batched graphs that will be used across the examples. Here, bg is a batch containing 2 graphs with random node and edge features.
in_dim = 5 # Input node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
# Let's create 2 simple pyg graphs.
# start by specifying the edges with edge index
edge_idx1 = torch.tensor([[0, 1, 2],
                          [1, 2, 3]])
edge_idx2 = torch.tensor([[2, 0, 0, 1],
                          [0, 1, 2, 0]])
# specify the node features, convention with variable x
x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)
x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)
# specify the edge features in e
e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)
e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)
# make the pyg graph objects with our constructed features
g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)
g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)
# put the two graphs into a Batch graph
bg = Batch.from_data_list([g1, g2])
# The batched graph will show as a single graph with 7 nodes
print(bg)
DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])
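The printed shapes follow from how batching merges graphs into one disjoint graph: node indices of later graphs are shifted, a batch vector records which graph each node belongs to, and ptr stores cumulative node counts. The following is a minimal pure-Python sketch of that idea for the two graphs above. The helper batch_graphs is hypothetical and written only for illustration; it is not the PyG implementation.

```python
def batch_graphs(edge_lists, num_nodes_per_graph):
    """Merge several graphs into one disjoint graph by shifting node indices.

    edge_lists: one (sources, targets) pair of index lists per graph.
    num_nodes_per_graph: number of nodes in each graph.
    """
    sources, targets, batch, ptr = [], [], [], [0]
    for graph_id, ((src, dst), n) in enumerate(zip(edge_lists, num_nodes_per_graph)):
        offset = ptr[-1]  # number of nodes already placed before this graph
        sources += [s + offset for s in src]
        targets += [t + offset for t in dst]
        batch += [graph_id] * n  # graph membership of each node
        ptr.append(offset + n)   # cumulative node count
    return [sources, targets], batch, ptr


# Same connectivity as edge_idx1 (4 nodes) and edge_idx2 (3 nodes) above.
edge_index, batch, ptr = batch_graphs(
    edge_lists=[([0, 1, 2], [1, 2, 3]), ([2, 0, 0, 1], [0, 1, 2, 0])],
    num_nodes_per_graph=[4, 3],
)
print(edge_index)  # [[0, 1, 2, 6, 4, 4, 5], [1, 2, 3, 4, 5, 6, 4]]
print(batch)       # [0, 0, 0, 0, 1, 1, 1]
print(ptr)         # [0, 4, 7]
```

This matches the sizes in the printed DataBatch: 7 edges, 7 nodes, a batch vector of length 7, and a ptr of length 3 (number of graphs plus one).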
Building a network
To build the network, we must define the arguments to pass at the different steps:
pre_nn_kwargs: The parameters used by a feed-forward neural network on the input node features, before they are passed to the convolutional layers. See class FeedForwardNN for details on the required parameters. Ignored if set to None.
gnn_kwargs: The parameters used by a feed-forward graph neural network on the features after they have passed through the pre-processing NN. See class FeedForwardGraph for details on the required parameters.
task_heads_kwargs: The parameters that specify the task heads for possibly multiple tasks, each with a specified task_level (graph, node, or edge).
graph_output_nn_kwargs: The parameters used by the graph output NNs to process the features after the GNN layers. We need to specify one NN for each task_level occurring in task_heads_kwargs.
temp_dim_1 = 23
temp_dim_2 = 17
pre_nn_kwargs = {
    "in_dim": in_dim,
    "out_dim": temp_dim_1,
    "hidden_dims": 4,
    "depth": 2,
    "activation": 'relu',
    "last_activation": "none",
    "dropout": 0.2
}
gnn_kwargs = {
    "in_dim": temp_dim_1,
    "out_dim": temp_dim_2,
    "hidden_dims": 5,
    "depth": 1,
    "activation": 'gelu',
    "last_activation": None,
    "dropout": 0.1,
    "normalization": 'layer_norm',
    "last_normalization": 'layer_norm',
    "residual_type": 'simple',
    "virtual_node": None,
    "layer_type": 'pyg:gcn',
    "layer_kwargs": None
}
task_heads_kwargs = {
    "graph-task-1": {
        "task_level": 'graph',
        "out_dim": 3,
        "hidden_dims": 32,
        "depth": 2,
        "activation": 'relu',
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none"
    },
    "graph-task-2": {
        "task_level": 'graph',
        "out_dim": 4,
        "hidden_dims": 32,
        "depth": 2,
        "activation": 'relu',
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none"
    },
    "node-task-1": {
        "task_level": 'node',
        "out_dim": 2,
        "hidden_dims": 32,
        "depth": 2,
        "activation": 'relu',
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none"
    }
}
graph_output_nn_kwargs = {
    "graph": {
        "pooling": ['sum'],
        "out_dim": temp_dim_2,
        "hidden_dims": temp_dim_2,
        "depth": 1,
        "activation": 'relu',
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none"
    },
    "node": {
        "pooling": None,
        "out_dim": temp_dim_2,
        "hidden_dims": temp_dim_2,
        "depth": 1,
        "activation": 'relu',
        "last_activation": None,
        "dropout": 0.1,
        "normalization": None,
        "last_normalization": None,
        "residual_type": "none"
    }
}
gnn_net = FullGraphMultiTaskNetwork(
    gnn_kwargs=gnn_kwargs,
    pre_nn_kwargs=pre_nn_kwargs,
    task_heads_kwargs=task_heads_kwargs,
    graph_output_nn_kwargs=graph_output_nn_kwargs
)
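Because the four dictionaries are written separately, it can help to check that they fit together before constructing the network. The sketch below is a hypothetical helper, check_config, and not part of graphium. It assumes two things suggested by the description above: the pre-NN output dimension should equal the GNN input dimension, and every task_level used by a task head needs an entry in graph_output_nn_kwargs. The dictionaries are trimmed to the keys the check reads.

```python
def check_config(pre_nn_kwargs, gnn_kwargs, task_heads_kwargs, graph_output_nn_kwargs):
    """Return a list of problems found; an empty list means none was detected."""
    problems = []
    # The pre-NN output feeds the GNN, so the two dimensions should agree.
    if pre_nn_kwargs is not None and pre_nn_kwargs["out_dim"] != gnn_kwargs["in_dim"]:
        problems.append("pre_nn_kwargs['out_dim'] differs from gnn_kwargs['in_dim']")
    # Every task level used by a head needs a matching graph output NN.
    for name, head in task_heads_kwargs.items():
        level = head["task_level"]
        if level not in graph_output_nn_kwargs:
            problems.append(f"{name}: no graph output NN for task_level '{level}'")
    return problems


# Trimmed versions of the tutorial dictionaries (only the keys the check reads).
pre_nn = {"in_dim": 5, "out_dim": 23}
gnn = {"in_dim": 23, "out_dim": 17}
heads = {
    "graph-task-1": {"task_level": "graph", "out_dim": 3},
    "graph-task-2": {"task_level": "graph", "out_dim": 4},
    "node-task-1": {"task_level": "node", "out_dim": 2},
}
output_nns = {"graph": {"out_dim": 17}, "node": {"out_dim": 17}}

print(check_config(pre_nn, gnn, heads, output_nns))  # []

# Adding an edge-level head without an "edge" output NN is reported.
heads_with_edge = {**heads, "edge-task-1": {"task_level": "edge", "out_dim": 1}}
print(check_config(pre_nn, gnn, heads_with_edge, output_nns))
```

The second call reports the missing "edge" entry, which is the situation the graph_output_nn_kwargs description warns about.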
Applying the network
Once the network is defined, we only need to run the forward pass on the input graphs to get a prediction.
The model outputs a dictionary of outputs, one for each task specified in task_heads_kwargs.
graph = deepcopy(bg)
print(graph.feat.shape)
print("\n")
print(gnn_net)
print("\n")
out = gnn_net(graph)
for task in out.keys():
    print(task, out[task].shape)
torch.Size([7, 5])
FullGNN
---------
pre-NN(depth=2, ResidualConnectionNone)
[FCLayer[5 -> 4 -> 23]
GNN(depth=1, ResidualConnectionSimple(skip_steps=1))
GCNConvPyg[23 -> 17]
Task heads:
graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 3]
graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 4]
node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 2]
graph-task-1 torch.Size([2, 3])
graph-task-2 torch.Size([2, 4])
node-task-1 torch.Size([7, 2])
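The layer sizes in the printed summary can be read off the in_dim, hidden_dims, depth, and out_dim entries. The sketch below reproduces them under the assumption that an integer hidden_dims is repeated depth - 1 times between the input and output dimensions. That assumption is consistent with the summary above but is not taken from the graphium source, and layer_dims is a hypothetical helper.

```python
def layer_dims(in_dim, out_dim, hidden_dims, depth):
    """Expand a config into the list of feature sizes, one per layer boundary."""
    if isinstance(hidden_dims, int):
        # Assumed convention: one hidden size repeated for every hidden layer.
        hidden_dims = [hidden_dims] * (depth - 1)
    return [in_dim, *hidden_dims, out_dim]


def describe(dims):
    """Format sizes the way the printed network summary does."""
    return " -> ".join(str(d) for d in dims)


print(describe(layer_dims(5, 23, hidden_dims=4, depth=2)))   # 5 -> 4 -> 23   (pre-NN)
print(describe(layer_dims(23, 17, hidden_dims=5, depth=1)))  # 23 -> 17       (GNN)
print(describe(layer_dims(17, 3, hidden_dims=32, depth=2)))  # 17 -> 32 -> 3  (graph-task-1)
```

Under this reading, a network with depth 1 has no hidden layer, so the hidden_dims value of 5 in gnn_kwargs does not appear in the summary.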
Building a network (using additional edge features)
We can also use a GNN that takes edge features as additional inputs. For that, we only need to slightly modify the gnn_kwargs from above. Below, we define the new gnn_edge_kwargs, where we set the layer_type to one of the GNN layers that support edge features (e.g., pyg:gine or pyg:gps) and add the additional parameter in_dim_edges. Optionally, we can configure a preprocessing NN for the edge features via a dictionary pre_nn_edges_kwargs, analogous to pre_nn_kwargs above, or skip this step by setting it to None.
temp_dim_edges = 8
gnn_edge_kwargs = {
    "in_dim": temp_dim_1,
    "in_dim_edges": temp_dim_edges,
    "out_dim": temp_dim_2,
    "hidden_dims": 5,
    "depth": 1,
    "activation": 'gelu',
    "last_activation": None,
    "dropout": 0.1,
    "normalization": 'layer_norm',
    "last_normalization": 'layer_norm',
    "residual_type": 'simple',
    "virtual_node": None,
    "layer_type": 'pyg:gine',
    "layer_kwargs": None
}
pre_nn_edges_kwargs = {
    "in_dim": in_dim_edges,
    "out_dim": temp_dim_edges,
    "hidden_dims": 4,
    "depth": 2,
    "activation": 'relu',
    "last_activation": "none",
    "dropout": 0.2
}
gnn_net_edges = FullGraphMultiTaskNetwork(
    gnn_kwargs=gnn_edge_kwargs,
    pre_nn_kwargs=pre_nn_kwargs,
    pre_nn_edges_kwargs=pre_nn_edges_kwargs,
    task_heads_kwargs=task_heads_kwargs,
    graph_output_nn_kwargs=graph_output_nn_kwargs
)
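Since gnn_edge_kwargs differs from gnn_kwargs only in layer_type and in_dim_edges, one option is to derive it from a copy instead of retyping every key. The sketch below shows that pattern with trimmed dictionaries. It also checks the edge-feature dimensions, assuming that the GNN should receive the edge pre-NN output size when an edge pre-NN is configured, and the raw edge-feature size when it is set to None.

```python
from copy import deepcopy

# Trimmed copy of gnn_kwargs from the tutorial (only a few keys shown).
gnn_kwargs = {
    "in_dim": 23,
    "out_dim": 17,
    "depth": 1,
    "layer_type": "pyg:gcn",
    "layer_kwargs": None,
}
in_dim_edges = 13    # raw edge-feature size
temp_dim_edges = 8   # edge-feature size after the edge pre-NN

# Copy first so the original GCN configuration stays untouched.
gnn_edge_kwargs = deepcopy(gnn_kwargs)
gnn_edge_kwargs.update({"layer_type": "pyg:gine", "in_dim_edges": temp_dim_edges})

pre_nn_edges_kwargs = {"in_dim": in_dim_edges, "out_dim": temp_dim_edges}

# Edge size the GNN is expected to receive (assumption stated in the text above).
expected_edge_dim = (
    pre_nn_edges_kwargs["out_dim"] if pre_nn_edges_kwargs is not None else in_dim_edges
)
assert gnn_edge_kwargs["in_dim_edges"] == expected_edge_dim

print(gnn_kwargs["layer_type"], gnn_edge_kwargs["layer_type"])  # pyg:gcn pyg:gine
```

Using deepcopy rather than plain assignment matters here: both networks can then be built side by side from their own dictionaries.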
Applying the network (using additional edge features)
Again, we only need to run the forward pass on the input graphs to get a prediction.
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
print("\n")
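# Note: this prints gnn_net, the GCN-based network built earlier, so the summary below shows GCNConvPyg.
# Use print(gnn_net_edges) to inspect the edge-aware network instead.
print(gnn_net)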
print("\n")
out = gnn_net_edges(graph)
for task in out.keys():
    print(task, out[task].shape)
torch.Size([7, 5])
torch.Size([7, 13])
FullGNN
---------
pre-NN(depth=2, ResidualConnectionNone)
[FCLayer[5 -> 4 -> 23]
GNN(depth=1, ResidualConnectionSimple(skip_steps=1))
GCNConvPyg[23 -> 17]
Task heads:
graph-task-1: NN-graph-task-1(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 3]
graph-task-2: NN-graph-task-2(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 4]
node-task-1: NN-node-task-1(depth=2, ResidualConnectionNone)
[FCLayer[17 -> 32 -> 2]
graph-task-1 torch.Size([2, 3])
graph-task-2 torch.Size([2, 4])
node-task-1 torch.Size([7, 2])
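The output shapes follow from the task_level and out_dim of each head: graph-level heads produce one row per graph in the batch, and node-level heads produce one row per node. The sketch below computes the expected shapes for the batch used here (2 graphs, 7 nodes, 7 edges). The helper expected_shapes is hypothetical, and the one-row-per-edge rule for edge-level heads is an assumption, since no edge-level task is run in this tutorial.

```python
def expected_shapes(task_heads_kwargs, num_graphs, num_nodes, num_edges):
    """Map each task name to the (rows, out_dim) shape its head should return."""
    rows_per_level = {"graph": num_graphs, "node": num_nodes, "edge": num_edges}
    return {
        name: (rows_per_level[head["task_level"]], head["out_dim"])
        for name, head in task_heads_kwargs.items()
    }


# Trimmed task heads from the tutorial (only the keys that determine the shape).
heads = {
    "graph-task-1": {"task_level": "graph", "out_dim": 3},
    "graph-task-2": {"task_level": "graph", "out_dim": 4},
    "node-task-1": {"task_level": "node", "out_dim": 2},
}

shapes = expected_shapes(heads, num_graphs=2, num_nodes=7, num_edges=7)
for task, shape in shapes.items():
    print(task, shape)
# graph-task-1 (2, 3)
# graph-task-2 (2, 4)
# node-task-1 (7, 2)
```

A comparison such as tuple(out[task].shape) == shapes[task] could serve as a quick sanity check after a forward pass.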