
graphium.nn

Base structures for neural networks in the library (to be inherited by other modules)

Base Graph Layer


graphium.nn.base_graph_layer

BaseGraphModule

Bases: BaseGraphStructure, nn.Module

__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)

Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by making the expected inputs and outputs explicit.

Parameters:

Name Type Description Default
in_dim int

Input feature dimensions of the layer

required
out_dim int

Output feature dimensions of the layer

required
activation Union[str, Callable]

activation function to use in the layer

'relu'
dropout float

The ratio of units to dropout. Must be between 0 and 1

0.0
normalization Union[str, Callable]

Normalization to use. Choices:

  • "none" or None: No normalization
  • "batch_norm": Batch normalization
  • "layer_norm": Layer normalization
  • Callable: Any callable function
'none'
layer_idx Optional[int]

The index of the current layer

None
layer_depth Optional[int]

The total depth (number of layers) associated to this specific layer

None
droppath_rate float

stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382

0.0

BaseGraphStructure

layer_inputs_edges: bool abstractmethod property

Abstract method. Return a boolean specifying whether the layer type uses edges as input or not. It is different from layer_supports_edges, since a layer that supports edges can decide not to use them.

Returns:

Name Type Description
bool bool

Whether the layer uses input edges in the forward pass

layer_outputs_edges: bool abstractmethod property

Abstract method. Return a boolean specifying whether the layer type outputs edges or not. It is different from layer_supports_edges, since a layer that supports edges can decide not to output them.

Returns:

Name Type Description
bool bool

Whether the layer outputs edges in the forward pass

max_num_edges_per_graph: Optional[int] property writable

Get the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)

max_num_nodes_per_graph: Optional[int] property writable

Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)

out_dim_factor: int abstractmethod property

Abstract method. Get the factor by which the output dimension is multiplied for the next layer.

For standard layers, this will return 1.

But for others, such as GatLayer, the output is the concatenation of the outputs from each head, so the out_dim gets multiplied by the number of heads, and this function should return the number of heads.

Returns:

Name Type Description
int int

The factor that multiplies the output dimensions

__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)

Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by making the expected inputs and outputs explicit.

Parameters:

Name Type Description Default
in_dim int

Input feature dimensions of the layer

required
out_dim int

Output feature dimensions of the layer

required
activation Union[str, Callable]

activation function to use in the layer

'relu'
dropout float

The ratio of units to dropout. Must be between 0 and 1

0.0
normalization Union[str, Callable]

Normalization to use. Choices:

  • "none" or None: No normalization
  • "batch_norm": Batch normalization
  • "layer_norm": Layer normalization
  • Callable: Any callable function
'none'
layer_idx Optional[int]

The index of the current layer

None
layer_depth Optional[int]

The total depth (number of layers) associated to this specific layer

None
droppath_rate float

stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382

0.0
__repr__()

Controls how the class is printed

apply_norm_activation_dropout(feat, normalization=True, activation=True, dropout=True, droppath=True, batch_idx=None, batch_size=None)

Apply the normalization, activation, dropout, and DropPath layers to the output features.

Parameters:

Name Type Description Default
feat Tensor

Feature tensor, to be normalized

required
normalization bool

Whether to apply the normalization

True
activation bool

Whether to apply the activation layer

True
dropout bool

Whether to apply the dropout layer

True
droppath bool

Whether to apply the DropPath layer

True

Returns:

Name Type Description
feat

Normalized and dropped-out features

layer_supports_edges()

Abstract method. Return a boolean specifying whether the layer type supports edges or not.

Returns:

Name Type Description
bool bool

Whether the layer supports the use of edges

check_intpus_allow_int(obj, edge_index, size)

Overwrite the check_input to allow for int32 and int16. TODO: Remove when PyG and PyTorch support int32.
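To make the contract above concrete, below is a minimal sketch of a hypothetical subclass. It assumes the abstract members are exactly those documented in this section; the class name, the plain tensor-based forward signature, and the concrete dimensions are illustrative, not part of the library.

```python
# Hedged sketch of a hypothetical BaseGraphModule subclass.
# Assumption: the abstract members are the ones documented above
# (layer_inputs_edges, layer_outputs_edges, out_dim_factor, layer_supports_edges);
# the tensor-only forward signature is purely illustrative.
import torch
import torch.nn as nn
from graphium.nn.base_graph_layer import BaseGraphModule


class PlainDenseGraphLayer(BaseGraphModule):  # hypothetical name
    def __init__(self, in_dim: int, out_dim: int, **kwargs):
        super().__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, feat: torch.Tensor) -> torch.Tensor:
        feat = self.linear(feat)
        # Reuse the shared normalization / activation / dropout / droppath logic
        return self.apply_norm_activation_dropout(feat)

    @property
    def layer_inputs_edges(self) -> bool:
        return False  # this toy layer ignores edge features

    @property
    def layer_outputs_edges(self) -> bool:
        return False  # and does not output edge features

    @property
    def out_dim_factor(self) -> int:
        return 1  # no multi-head concatenation, so out_dim is unchanged

    def layer_supports_edges(self) -> bool:
        return False
```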

Base Layers


graphium.nn.base_layers

DropPath

Bases: nn.Module

__init__(drop_rate)

DropPath class for stochastic depth, from "Deep Networks with Stochastic Depth", Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger, https://arxiv.org/abs/1603.09382

Parameters:

Name Type Description Default
drop_rate float

Drop out probability

required
forward(input, batch_idx, batch_size=None)

Parameters:

Name Type Description Default
input Tensor

torch.Tensor[total_num_nodes, hidden]

required
batch_idx

The batch attribute of the batch object (batch.batch), assigning each node to its graph

required
batch_size Optional[int]

The batch size. Must be provided when working on IPU

None

Returns:

Type Description
Tensor

torch.Tensor: torch.Tensor[total_num_nodes, hidden]

get_stochastic_drop_rate(drop_rate, layer_idx=None, layer_depth=None) staticmethod

Get the stochastic drop rate from the nominal drop rate, the layer index, and the layer depth.

return drop_rate * (layer_idx / (layer_depth - 1))

Parameters:

Name Type Description Default
drop_rate float

Drop out nominal probability

required
layer_idx Optional[int]

The index of the current layer

None
layer_depth Optional[int]

The total depth (number of layers) associated to this specific layer

None
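A minimal sketch of the schedule above, plus a call to DropPath on node features; the shapes follow the forward documentation and the concrete numbers are illustrative.

```python
# Hedged sketch: stochastic-depth rate schedule and DropPath usage.
import torch
from graphium.nn.base_layers import DropPath

# drop_rate * (layer_idx / (layer_depth - 1)) -> 0.2 * (3 / 4) = 0.15
rate = DropPath.get_stochastic_drop_rate(drop_rate=0.2, layer_idx=3, layer_depth=5)

droppath = DropPath(drop_rate=rate)
feat = torch.randn(6, 16)                      # [total_num_nodes, hidden]
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1])   # node-to-graph assignment (batch.batch)
out = droppath(feat, batch_idx)                # same shape as the input
```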

FCLayer

Bases: nn.Module

in_channels: int property

Get the input channel size. For compatibility with PyG.

out_channels: int property

Get the output channel size. For compatibility with PyG.

__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', bias=True, init_fn=None, is_readout_layer=False, droppath_rate=0.0)

A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module. The order in which transformations are applied is:

  • Dense Layer
  • Activation
  • Dropout (if applicable)
  • Batch Normalization (if applicable)

Parameters:

Name Type Description Default
in_dim int

Input dimension of the layer (the torch.nn.Linear)

required
out_dim int

Output dimension of the layer.

required
dropout float

The ratio of units to dropout. No dropout by default.

0.0
activation Union[str, Callable]

Activation function to use.

'relu'
normalization Union[str, Callable]

Normalization to use. Choices:

  • "none" or None: No normalization
  • "batch_norm": Batch normalization
  • "layer_norm": Layer normalization
  • Callable: Any callable function
'none'
bias bool

Whether to enable bias in for the linear layer.

True
init_fn Optional[Callable]

Initialization function to use for the weight of the layer. Default is \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) with \(k = \frac{1}{\text{in\_dim}}\)

None
is_readout_layer bool

Whether the layer should be treated as a readout layer by replacing the torch.nn.Linear with mup.MuReadout from the muTransfer method https://github.com/microsoft/mup

False
droppath_rate float

stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382

0.0

Attributes:

Name Type Description
dropout float

The ratio of units to dropout.

normalization None or Callable

Normalization layer

linear `torch.nn.Linear`

The linear layer

activation `torch.nn.Module`

The activation layer

init_fn Callable

Initialization function used for the weight of the layer

in_dim int

Input dimension of the linear layer

out_dim int

Output dimension of the linear layer

forward(h)

Apply the FC layer on the input features.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor[..., Din]: Input feature tensor, before the FC. Din is the number of input features

required

Returns:

Type Description
torch.Tensor

torch.Tensor[..., Dout]: Output feature tensor, after the FC. Dout is the number of output features

reset_parameters(init_fn=None)

Reset the parameters of the linear layer using the init_fn.
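A minimal usage sketch of FCLayer, assuming only the constructor arguments and forward shapes documented above; the dimensions are illustrative.

```python
# Hedged sketch: building and applying an FCLayer.
import torch
from graphium.nn.base_layers import FCLayer

layer = FCLayer(
    in_dim=64,
    out_dim=32,
    activation="relu",
    dropout=0.1,
    normalization="layer_norm",
)

h = torch.randn(8, 64)   # [..., Din]
out = layer(h)           # [..., Dout] -> shape (8, 32)
```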

GRU

Bases: nn.Module

__init__(in_dim, hidden_dim)

Wrapper class for the GRU used by the GNN framework; nn.GRU is used for the Gated Recurrent Unit itself

Parameters:

Name Type Description Default
in_dim int

Input dimension of the GRU layer

required
hidden_dim int

Hidden dimension of the GRU layer.

required
forward(x, y)

Parameters:

Name Type Description Default
x

torch.Tensor[B, N, Din] where Din <= in_dim (difference is padded)

required
y

torch.Tensor[B, N, Dh] where Dh <= hidden_dim (difference is padded)

required

Returns:

Type Description

torch.Tensor: torch.Tensor[B, N, Dh]
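A minimal sketch of the wrapper, following the shapes given in the docstring above; the concrete dimensions are illustrative.

```python
# Hedged sketch: GRU wrapper on batched sequences of node features.
import torch
from graphium.nn.base_layers import GRU

gru = GRU(in_dim=16, hidden_dim=16)
x = torch.randn(4, 10, 16)   # [B, N, Din]
y = torch.randn(4, 10, 16)   # [B, N, Dh]
out = gru(x, y)              # [B, N, Dh]
```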

MLP

Bases: nn.Module

__init__(in_dim, hidden_dims, out_dim, depth, activation='relu', last_activation='none', dropout=0.0, last_dropout=0.0, normalization='none', last_normalization='none', first_normalization='none', last_layer_is_readout=False, droppath_rate=0.0, constant_droppath_rate=True)

Simple multi-layer perceptron, built from a series of FCLayers

Parameters:

Name Type Description Default
in_dim int

Input dimension of the MLP

required
hidden_dims Union[Iterable[int], int]

Either an integer specifying all the hidden dimensions, or a list of dimensions in the hidden layers.

required
out_dim int

Output dimension of the MLP.

required
depth int

If hidden_dims is an integer, depth is 1 + the number of hidden layers to use. If hidden_dims is a list, then depth must be None or equal to len(hidden_dims) + 1

required
activation Union[str, Callable]

Activation function to use in all the layers except the last. If the MLP has only one layer, this parameter is ignored.

'relu'
last_activation Union[str, Callable]

Activation function to use in the last layer.

'none'
dropout float

The ratio of units to dropout. Must be between 0 and 1

0.0
normalization Union[Type[None], str, Callable]

Normalization to use. Choices:

  • "none" or None: No normalization
  • "batch_norm": Batch normalization
  • "layer_norm": Layer normalization in the hidden layers.
  • Callable: Any callable function

If the MLP has only one layer, this parameter is ignored.

'none'
last_normalization Union[Type[None], str, Callable]

Normalization to use after the last layer. Same options as normalization.

'none'
first_normalization Union[Type[None], str, Callable]

Normalization to use before the first layer. Same options as normalization.

'none'
last_dropout float

The ratio of units to dropout at the last layer.

0.0
last_layer_is_readout bool

Whether the last layer should be treated as a readout layer. This allows using mup.MuReadout from the muTransfer method https://github.com/microsoft/mup

False
droppath_rate float

stochastic depth drop rate, between 0 and 1. See https://arxiv.org/abs/1603.09382

0.0
constant_droppath_rate bool

If True, the drop rate is constant across layers. Otherwise, it is scaled with the layer index; see DropPath.get_stochastic_drop_rate

True
__repr__()

Controls how the class is printed

forward(h)

Apply the MLP on the input features.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor[..., Din]: Input feature tensor, before the MLP. Din is the number of input features

required

Returns:

Type Description
torch.Tensor

torch.Tensor[..., Dout]: Output feature tensor, after the MLP. Dout is the number of output features
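A minimal sketch of building and calling an MLP with the arguments documented above; the dimensions are illustrative.

```python
# Hedged sketch: an MLP with an integer hidden_dims, where depth sets the layer count.
import torch
from graphium.nn.base_layers import MLP

mlp = MLP(
    in_dim=32,
    hidden_dims=64,              # every hidden layer has 64 units
    out_dim=8,
    depth=3,                     # 1 + the number of hidden layers (here, 2 hidden layers)
    activation="relu",
    last_activation="none",
    normalization="layer_norm",
)

h = torch.randn(16, 32)          # [..., Din]
out = mlp(h)                     # [..., Dout] -> shape (16, 8)
```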

MuReadoutGraphium

Bases: MuReadout

PopTorch-compatible replacement for mup.MuReadout

Not quite a drop-in replacement for mup.MuReadout - you need to specify base_width.

Set base_width to the width of the base model passed to mup.set_base_shapes to get the same results on IPU and CPU. It should still "work" with any other value, but won't give the same results as on CPU.

MultiheadAttentionMup

Bases: nn.MultiheadAttention

Modifies the MultiheadAttention to work with the muTransfer paradigm. The layers are initialized using the mup package. The _scaled_dot_product_attention normalizes the attention matrix with 1/d instead of 1/sqrt(d). The biased self-attention option is added to support a 3D attention bias.

TransformerEncoderLayerMup

Bases: nn.TransformerEncoderLayer

Modified version of torch.nn.TransformerEncoderLayer that uses \(1/n\)-scaled attention for compatibility with muP (as opposed to the original \(1/\sqrt{n}\) scaling factor). Arguments are the same as torch.nn.TransformerEncoderLayer.

get_activation(activation)

returns the activation function represented by the input string

Parameters:

Name Type Description Default
activation Union[type(None), str, Callable]

Callable, None, or string with value: "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "GELU", "LeakyReLU", "Softplus"

required

Returns:

Type Description
Optional[Callable]

Callable or None: The activation function
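A minimal sketch of resolving an activation by name; returning None for "none" is inferred from the Optional[Callable] return type and is an assumption.

```python
# Hedged sketch: resolving activation functions by name.
from graphium.nn.base_layers import get_activation

relu = get_activation("relu")       # an activation callable/module
none_act = get_activation("none")   # expected to be None (Optional[Callable])
```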

get_activation_str(activation)

returns the string related to the activation function

Parameters:

Name Type Description Default
activation Union[type(None), str, Callable]

Callable, None, or string with value: "none", "ReLU", "Sigmoid", "Tanh", "ELU", "SELU", "GLU", "LeakyReLU", "Softplus"

required

Returns:

Type Description
str

The name of the activation function

get_norm(normalization, dim=None)

returns the normalization function represented by the input string

Parameters:

Name Type Description Default
normalization Union[Type[None], str, Callable]

Callable, None, or string with value: "none", "batch_norm", "layer_norm"

required
dim Optional[int]

Dimension where to apply the norm. Mandatory for 'batch_norm' and 'layer_norm'

None

Returns:

Type Description

Callable or None: The normalization function
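A minimal sketch, assuming the signature documented above; dim is mandatory for "batch_norm" and "layer_norm".

```python
# Hedged sketch: resolving normalization layers by name.
from graphium.nn.base_layers import get_norm

batch_norm = get_norm("batch_norm", dim=64)   # dim is mandatory here
layer_norm = get_norm("layer_norm", dim=64)
no_norm = get_norm("none")                    # expected to be None
```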

Residual Connections


graphium.nn.residual_connections

Different types of residual connections, including None, Simple (ResNet-like), Concat and DenseNet

ResidualConnectionBase

Bases: nn.Module

__init__(skip_steps=1)

Abstract class for the residual connections. Using this class, we implement different types of residual connections, such as ResNet, weighted-ResNet, skip-concat and DenseNet.

The following methods must be implemented in a child class:

  • h_dim_increase_type()
  • has_weights()

Parameters:

Name Type Description Default
skip_steps int

int The number of steps to skip between the residual connections. If 1, all the layers are connected. If 2, half of the layers are connected.

1
__repr__()

Controls how the class is printed

get_true_out_dims(out_dims)

find the true output dimensions

Parameters:

Name Type Description Default
out_dims List

List

required

Returns:

Name Type Description
true_out_dims List

List

h_dim_increase_type() abstractmethod

How does the dimension of the output features increases after each layer?

Returns:

Name Type Description
h_dim_increase_type

None or str

  • None: The dimension of the output features does not change at each layer, e.g. ResNet.

  • "previous": The dimension of the output features is the concatenation of the previous layer with the new layer.

  • "cumulative": The dimension of the output features is the concatenation of all previous layers.

has_weights() abstractmethod

Returns:

Name Type Description
has_weights

bool Whether the residual connection uses weights

ResidualConnectionConcat

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the simple residual connections, as in ResNet, but where the skip connection features are concatenated to the current layer features.

Parameters:

Name Type Description Default
skip_steps int

int The number of steps to skip between the residual connections. If 1, all the layers are connected. If 2, half of the layers are connected.

1
forward(h, h_prev, step_idx)

Concatenate h with the skip-connection features h_prev from a previous layer.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor(..., m) The current layer features

required
h_prev torch.Tensor

torch.Tensor(..., n), None The features from the previous layer with a skip connection. Usually, we have n equal to m. At step_idx==0, h_prev can be set to None.

required
step_idx int

int Current layer index or step index in the forward loop of the architecture.

required

Returns:

Name Type Description
h

torch.Tensor(..., m) or torch.Tensor(..., m + n) Either return h unchanged, or the concatenation with h_prev, depending on the step_idx and self.skip_steps.

h_prev

torch.Tensor(..., m) or torch.Tensor(..., m + n) Either return h_prev unchanged, or the same value as h, depending on the step_idx and self.skip_steps.

h_dim_increase_type()

Returns:

Type Description

"previous": The dimension of the output layer is the concatenation with the previous layer.

has_weights()

Returns:

Type Description

False The current class does not use weights

ResidualConnectionDenseNet

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the residual connections proposed by DenseNet, where all previous skip connection features are concatenated to the current layer features.

Parameters:

Name Type Description Default
skip_steps int

int The number of steps to skip between the residual connections. If 1, all the layers are connected. If 2, half of the layers are connected.

1
forward(h, h_prev, step_idx)

Concatenate h with the skip-connection features h_prev accumulated from all previous layers.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor(..., m) The current layer features

required
h_prev torch.Tensor

torch.Tensor(..., n), None The features from the previous layers. n = ((step_idx // self.skip_steps) + 1) * m

At step_idx==0, h_prev can be set to None.

required
step_idx int

int Current layer index or step index in the forward loop of the architecture.

required

Returns:

Name Type Description
h

torch.Tensor(..., m) or torch.Tensor(..., m + n) Either return h unchanged, or the concatenation with h_prev, depending on the step_idx and self.skip_steps.

h_prev

torch.Tensor(..., m) or torch.Tensor(..., m + n) Either return h_prev unchanged, or the same value as h, depending on the step_idx and self.skip_steps.

h_dim_increase_type()

Returns:

Type Description

"cumulative": The dimension of the output layer is the concatenation of all the previous layer.

has_weights()

Returns:

Type Description

False The current class does not use weights

ResidualConnectionNone

Bases: ResidualConnectionBase

No residual connection. This class is only used for simpler code compatibility

__repr__()

Controls how the class is printed

forward(h, h_prev, step_idx)

Ignore the skip connection.

Returns:

Name Type Description
h

torch.Tensor(..., m) Return same as input.

h_prev

torch.Tensor(..., m) Return same as input.

h_dim_increase_type()

Returns:

Name Type Description
None

The dimension of the output features does not change at each layer.

has_weights()

Returns:

Type Description

False The current class does not use weights

ResidualConnectionRandom

Bases: ResidualConnectionBase

__init__(skip_steps=1, out_dims=None, num_layers=None)

Class for the random residual connection, where each layer is connected to each following layer with a random weight between 0 and 1.

Parameters:

Name Type Description Default
skip_steps

Parameter only there for compatibility with other classes of the same parent.

1
out_dims List[int]

The list of output dimensions. Only required to get the number of layers. Must be provided if num_layers is None.

None
num_layers int

The number of layers. Must be provided if out_dims is None.

None
forward(h, h_prev, step_idx)

Add h to the skip-connection features h_prev from a previous layer, similar to ResNet.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor(..., m) The current layer features

required
h_prev torch.Tensor

torch.Tensor(..., m), None The features from the previous layer with a skip connection. At step_idx==0, h_prev can be set to None.

required
step_idx int

int Current layer index or step index in the forward loop of the architecture.

required

Returns:

Name Type Description
h

torch.Tensor(..., m) Either return h unchanged, or the sum with h_prev, depending on the step_idx and self.skip_steps.

h_prev

torch.Tensor(..., m) Either return h_prev unchanged, or the same value as h, depending on the step_idx and self.skip_steps.

h_dim_increase_type()

Returns:

Name Type Description
None

The dimension of the output features does not change at each layer.

has_weights()

Returns:

Type Description

False The current class does not use weights

ResidualConnectionSimple

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the simple residual connections proposed by ResNet, where the current layer output is summed to a previous layer output.

Parameters:

Name Type Description Default
skip_steps int

int The number of steps to skip between the residual connections. If 1, all the layers are connected. If 2, half of the layers are connected.

1
forward(h, h_prev, step_idx)

Add h to the skip-connection features h_prev from a previous layer, similar to ResNet.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor(..., m) The current layer features

required
h_prev torch.Tensor

torch.Tensor(..., m), None The features from the previous layer with a skip connection. At step_idx==0, h_prev can be set to None.

required
step_idx int

int Current layer index or step index in the forward loop of the architecture.

required

Returns:

Name Type Description
h

torch.Tensor(..., m) Either return h unchanged, or the sum with h_prev, depending on the step_idx and self.skip_steps.

h_prev

torch.Tensor(..., m) Either return h_prev unchanged, or the same value as h, depending on the step_idx and self.skip_steps.

h_dim_increase_type()

Returns:

Name Type Description
None

The dimension of the output features does not change at each layer.

has_weights()

Returns:

Type Description

False The current class does not use weights
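A minimal sketch of the forward-loop contract shared by the residual-connection classes, shown here with ResidualConnectionSimple; the surrounding FCLayer stack and dimensions are illustrative.

```python
# Hedged sketch: applying a residual connection inside a layer loop.
import torch
import torch.nn as nn
from graphium.nn.base_layers import FCLayer
from graphium.nn.residual_connections import ResidualConnectionSimple

dim = 32
layers = nn.ModuleList([FCLayer(dim, dim) for _ in range(4)])
residual = ResidualConnectionSimple(skip_steps=1)

h = torch.randn(8, dim)
h_prev = None  # at step_idx == 0, h_prev can be None
for step_idx, layer in enumerate(layers):
    h = layer(h)
    # Returns the (possibly summed) features and the updated skip state
    h, h_prev = residual.forward(h, h_prev, step_idx)
```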

ResidualConnectionWeighted

Bases: ResidualConnectionBase

__init__(out_dims, skip_steps=1, dropout=0.0, activation='none', normalization='none', bias=False)

Class for the simple residual connections proposed by ResNet, with an added layer in the residual connection itself. The layer output is summed with a non-linear transformation of a previous layer output.

Parameters:

Name Type Description Default
skip_steps int

int The number of steps to skip between the residual connections. If 1, all the layers are connected. If 2, half of the layers are connected.

1
out_dims

list(int) list of all output dimensions for the network that will use this residual connection. E.g. out_dims = [4, 8, 8, 8, 2].

required
dropout

float value between 0 and 1.0 representing the percentage of dropout to use in the weights

0.0
activation Union[str, Callable]

str, Callable The activation function to use after the skip weights

'none'
normalization

Normalization to use. Choices:

  • "none" or None: No normalization
  • "batch_norm": Batch normalization
  • "layer_norm": Layer normalization in the hidden layers.
  • Callable: Any callable function
'none'
bias

bool Whether to add a bias after the weights

False
forward(h, h_prev, step_idx)

Add h to the skip-connection features h_prev from a previous layer, after passing h_prev through a feed-forward layer.

Parameters:

Name Type Description Default
h torch.Tensor

torch.Tensor(..., m) The current layer features

required
h_prev torch.Tensor

torch.Tensor(..., m), None The features from the previous layer with a skip connection. At step_idx==0, h_prev can be set to None.

required
step_idx int

int Current layer index or step index in the forward loop of the architecture.

required

Returns:

Name Type Description
h

torch.Tensor(..., m) Either return h unchanged, or the sum with the output of an NN layer applied to h_prev, depending on the step_idx and self.skip_steps.

h_prev

torch.Tensor(..., m) Either return h_prev unchanged, or the same value as h, depending on the step_idx and self.skip_steps.

h_dim_increase_type()

Returns:

Name Type Description
None

The dimension of the output features does not change at each layer.

has_weights()

Returns:

Type Description

True The current class uses weights