Creating GNN layers¶
One of the primary advantages of this library is that GNN layers are independent of the model architecture, allowing more flexibility in the code by easily swapping different layer types as hyper-parameters. However, this requires that the layers be implemented using the PyG library and inherit from the class BaseGraphStructure, which standardizes the inputs, outputs and properties of the layers. The architecture can then be handled independently using the class FeedForwardPyg or other custom classes.
We will first start with a simple layer that does not use edges, then move to a more complex layer that uses edges.
Since these examples are built on top of PyG, we recommend looking at their Docs and Tutorials for more info.
%load_ext autoreload
%autoreload 2
import torch
from torch import Tensor
from torch_geometric.data import Data, Batch
from torch_geometric.nn.conv import MessagePassing
from torch_scatter import scatter
from torch_geometric.typing import (
OptPairTensor,
OptTensor
)
from copy import deepcopy
from typing import Callable, Union, Optional, List
from graphium.nn.base_graph_layer import BaseGraphStructure
from graphium.nn.base_layers import FCLayer
from graphium.utils.decorators import classproperty
_ = torch.manual_seed(42)
Define synthetic test graph¶
We define below a small batched graph on which we can test the created layers. You can also use synthetic graphs like this to define unit tests (a small example is sketched after testing the simple layer below).
in_dim = 5 # Input node-feature dimensions
out_dim = 11 # Desired output node-feature dimensions
in_dim_edges = 13 # Input edge-feature dimensions
# Let's create 2 simple pyg graphs.
# start by specifying the edges with edge index
edge_idx1 = torch.tensor([[0, 1, 2],
[1, 2, 3]])
edge_idx2 = torch.tensor([[2, 0, 0, 1],
[0, 1, 2, 0]])
# specify the node features, convention with variable x
x1 = torch.randn(edge_idx1.max() + 1, in_dim, dtype=torch.float32)
x2 = torch.randn(edge_idx2.max() + 1, in_dim, dtype=torch.float32)
# specify the edge features in e
e1 = torch.randn(edge_idx1.shape[-1], in_dim_edges, dtype=torch.float32)
e2 = torch.randn(edge_idx2.shape[-1], in_dim_edges, dtype=torch.float32)
# make the pyg graph objects with our constructed features
g1 = Data(feat=x1, edge_index=edge_idx1, edge_feat=e1)
g2 = Data(feat=x2, edge_index=edge_idx2, edge_feat=e2)
# put the two graphs into a Batch graph
bg = Batch.from_data_list([g1, g2])
# The batched graph will show as a single graph with 7 nodes
print(bg)
DataBatch(edge_index=[2, 7], feat=[7, 5], edge_feat=[7, 13], batch=[7], ptr=[3])
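The DataBatch above stacks both graphs into one disconnected graph with 7 nodes. As a quick sanity check (a small sketch using standard PyG Batch attributes), we can look at how the nodes are assigned to the two graphs and recover the original graphs:
# `batch` maps each of the 7 nodes to its graph index (0 or 1),
# and `ptr` gives the node offsets of each graph in the batched tensors
print(bg.batch)  # tensor([0, 0, 0, 0, 1, 1, 1])
print(bg.ptr)    # tensor([0, 4, 7])
# The original graphs can be recovered from the batch if needed
graphs = bg.to_data_list()
print(len(graphs), graphs[0].feat.shape, graphs[1].feat.shape)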
Creating a simple layer¶
Here, we will show how to create a GNN layer that simply does a mean aggregation on the neighbouring features.
First, for the layer to be fully compatible with the flexible architecture provided by FeedForwardPyg, it needs to inherit from the class BaseGraphStructure. This base layer has multiple virtual methods that must be implemented in any class that inherits from it.
The virtual methods are below:
- `layer_supports_edges`: We want to return `False` since our layer doesn't support edges
- `layer_inputs_edges`: We want to return `False` since our layer doesn't input edges
- `layer_outputs_edges`: We want to return `False` since our layer doesn't output edges
- `out_dim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.
The example is given below
# inherit from message passing class from pyg
# inherit from BaseGraphStructure
# this example is also based on: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/simple_conv.html#SimpleConv.forward
class SimpleMeanLayer(MessagePassing, BaseGraphStructure):
def __init__(self,
in_dim: int,
out_dim: int,
activation: Union[Callable, str] = "relu",
dropout: float = 0.0,
normalization: Union[str, Callable] = "none",
aggr: str = "mean",
):
r"""
Add documentation in the format shown here for this layer to be automatically added to the graphium API reference.
The type information will also be shown automatically in the docs.
Parameters:
in_dim:
Input feature dimensions of the layer
out_dim:
Output feature dimensions of the layer
activation:
activation function to use in the layer
dropout:
The ratio of units to dropout. Must be between 0 and 1
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization
- `Callable`: Any callable function
aggr:
what aggregation to use ("add", "mean" or "max")
"""
# Initialize the parent class
MessagePassing.__init__(self, node_dim=0, aggr=aggr)
BaseGraphStructure.__init__(self,
in_dim=in_dim,
out_dim=out_dim,
activation=activation,
dropout=dropout,
normalization=normalization)
self.aggr = aggr
self._initialize_activation_dropout_norm()
# Create the fully-connected transformation layer (FCLayer)
# https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.MLP
self.transform = FCLayer(in_dim=in_dim, out_dim=out_dim)
# define the forward function
def forward(self,
batch: Union[Data, Batch],
) -> Union[Data, Batch]:
r"""
Similarly, add documentation to the functions in the class.
Parameters:
batch: pyg Batch graphs to pass through the layer
Returns:
batch: pyg Batch graphs
"""
x = batch.feat
edge_index = batch.edge_index
if isinstance(x, Tensor):
x: OptPairTensor = (x, x)
# propagate_type: (x: OptPairTensor)
out = self.propagate(edge_index, x=x)
out = self.transform(out)
out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)
batch.feat = out
return batch
def message(self, x_j):
return x_j
# Finally, we define all the virtual properties according to how the class works with proper documentation of course
@classproperty
def layer_supports_edges(cls) -> bool:
r"""
Return a boolean specifying if the layer type supports edges or not.
Returns:
bool:
False for the current class
"""
return False
@property
def layer_inputs_edges(self) -> bool:
r"""
Returns:
bool:
Returns False
"""
return False
@property
def layer_outputs_edges(self) -> bool:
r"""
Returns:
bool:
Always ``False`` for the current class
"""
return False
@property
def out_dim_factor(self) -> int:
r"""
Get the factor by which the output dimension is multiplied for
the next layer.
For standard layers, this will return ``1``.
Returns:
int:
Always ``1`` for the current class
"""
return 1
Test our simple layer¶
Now, we are ready to test the SimpleMeanLayer on our constructed graphs. Note that in this example, we ignore the edge features since they are not supported.
graph = deepcopy(bg)
print(graph.feat.shape)
layer = SimpleMeanLayer(
in_dim=in_dim,
out_dim=out_dim,
activation="relu",
dropout=.3,
normalization="batch_norm")
graph = layer(graph)
print(graph.feat.shape)
torch.Size([7, 5])
torch.Size([7, 11])
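Since the synthetic graph has a known shape, this check can also be wrapped into a small unit test, as mentioned earlier. Below is a minimal sketch using plain asserts; in practice you might use pytest and move the graph construction into a fixture:
def test_simple_mean_layer():
    # Use a fresh copy of the synthetic batched graph and a fresh layer
    test_graph = deepcopy(bg)
    test_layer = SimpleMeanLayer(in_dim=in_dim, out_dim=out_dim)
    test_graph = test_layer(test_graph)
    # The number of nodes is unchanged, only the feature dimension changes
    assert test_graph.feat.shape == (bg.feat.shape[0], out_dim)

test_simple_mean_layer()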
Creating a complex layer with edges¶
Here, we will show how to create a GNN layer that does a mean aggregation on the neighbouring node features concatenated with their corresponding edge features. In this case, only the node features will change; the network will not update the edge features.
The virtual methods will have different outputs:
- `layer_supports_edges`: We want to return `True` since our layer does support edges
- `layer_inputs_edges`: We want to return `True` since our layer does input edges
- `layer_outputs_edges`: We want to return `False` since our layer will not output new edges
- `out_dim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.
The example is given below
# inherit from message passing class from pyg
# inherit from BaseGraphStructure
# adapting example from graphium/nn/pyg_layers/png_pyg.py
class ComplexMeanLayer(MessagePassing, BaseGraphStructure):
def __init__(self,
in_dim: int,
out_dim: int,
in_dim_edges: int,
activation: Union[Callable, str] = "relu",
dropout: float = 0.0,
normalization: Union[str, Callable] = "none",
aggr: str = "mean",
):
r"""
Add documentation in the format shown here for this layer to be automatically added to the graphium API reference.
The type information will also be shown automatically in the docs.
Parameters:
in_dim:
Input feature dimensions of the layer
out_dim:
Output feature dimensions of the layer
in_dim_edges:
Input edge feature dimension of the layer
activation:
activation function to use in the layer
dropout:
The ratio of units to dropout. Must be between 0 and 1
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization
- `Callable`: Any callable function
aggr:
what aggregation to use ("add", "mean" or "max")
"""
# Initialize the parent class
MessagePassing.__init__(self, node_dim=0, aggr=aggr)
BaseGraphStructure.__init__(self,
in_dim=in_dim,
out_dim=out_dim,
activation=activation,
dropout=dropout,
normalization=normalization)
self.aggr = aggr
self._initialize_activation_dropout_norm()
# Create the fully-connected transformation layer (FCLayer)
# https://graphium-docs.datamol.io/stable/api/graphium.nn/graphium.nn.html#graphium.nn.base_layers.MLP
self.transform = FCLayer(in_dim=(in_dim + in_dim_edges), out_dim=out_dim)
# define the forward function
def forward(self,
batch: Union[Data, Batch],
) -> Union[Data, Batch]:
r"""
Similarly, add documentation to the functions in the class.
Parameters:
batch: pyg Batch graphs to pass through the layer
Returns:
batch: pyg Batch graphs
"""
x = batch.feat
edge_index = batch.edge_index
edge_feat = batch.edge_feat
if isinstance(x, Tensor):
x: OptPairTensor = (x, x)
# propagate_type: (x: OptPairTensor, edge_feat: OptTensor)
out = self.propagate(edge_index, x=x, edge_feat=edge_feat, size=None)
out = self.transform(out)
out = self.apply_norm_activation_dropout(out, batch_idx=batch.batch)
batch.feat = out
return batch
def message(self, x_j: Tensor, edge_feat: OptTensor) -> Tensor:
r"""
message function
Parameters:
x_j: neighbour node features
edge_feat: edge features
Returns:
feat: the message
"""
feat = torch.cat([x_j, edge_feat], dim=-1)
return feat
def aggregate(
self,
inputs: Tensor,
index: Tensor,
edge_index: Tensor,
dim_size: Optional[int] = None,
) -> Tensor:
r"""
aggregate function
Parameters:
inputs: input features
index: index of the nodes
edge_index: edge index
dim_size: dimension size
Returns:
out: aggregated features
"""
out = scatter(inputs, index, 0, None, dim_size, reduce=self.aggr)
return out
# Finally, we define all the virtual properties according to how the class works with proper documentation of course
@classproperty
def layer_supports_edges(cls) -> bool:
r"""
Return a boolean specifying if the layer type supports edges or not.
Returns:
bool:
True for the current class
"""
return True
@property
def layer_inputs_edges(self) -> bool:
r"""
Returns:
bool:
Returns True
"""
return True
@property
def layer_outputs_edges(self) -> bool:
r"""
Returns:
bool:
Always ``False`` for the current class, since this layer does not update the edge features
"""
return False
@property
def out_dim_factor(self) -> int:
r"""
Get the factor by which the output dimension is multiplied for
the next layer.
For standard layers, this will return ``1``.
Returns:
int:
Always ``1`` for the current class
"""
return 1
Test our complex layer¶
Now, we are ready to test the ComplexMeanLayer on our constructed graphs. Note that in this example, we use both the node features and the edge features.
graph = deepcopy(bg)
print(graph.feat.shape)
print(graph.edge_feat.shape)
layer = ComplexMeanLayer(
in_dim=in_dim,
out_dim=out_dim,
in_dim_edges=in_dim_edges,
activation="relu",
dropout=.3,
normalization="batch_norm")
graph = layer(graph)
print(graph.feat.shape)
torch.Size([7, 5])
torch.Size([7, 13])
torch.Size([7, 11])
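As a final sanity check (again just a sketch, not part of the library), we can confirm that the layer behaves as its virtual properties advertise: the node features get the new output dimension, while the edge features are passed through unchanged since the layer does not output new edges.
# Node features now have the desired output dimension
assert graph.feat.shape == (bg.feat.shape[0], out_dim)
# Edge features are left untouched, consistent with `layer_outputs_edges`
# returning False for this layer
assert torch.allclose(graph.edge_feat, bg.edge_feat)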