graphium.nn
Base structures for neural networks in the library (to be inherited by other modules).
Base Graph Layer
graphium.nn.base_graph_layer

BaseGraphModule

Bases: BaseGraphStructure, nn.Module

__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)

Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by better understanding the expected inputs and outputs.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input feature dimensions of the layer | required |
out_dim | int | Output feature dimensions of the layer | required |
activation | Union[str, Callable] | Activation function to use in the layer | 'relu' |
dropout | float | The ratio of units to dropout. Must be between 0 and 1 | 0.0 |
normalization | Union[str, Callable] | Normalization to use. Choices: | 'none' |
layer_idx | Optional[int] | The index of the current layer | None |
layer_depth | Optional[int] | The total depth (number of layers) associated to this specific layer | None |
droppath_rate | float | Stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 | 0.0 |
BaseGraphStructure

layer_inputs_edges: bool
abstractmethod property

Abstract method. Return a boolean specifying if the layer type uses edges as input or not. It is different from layer_supports_input_edges, since a layer that supports edges can decide not to use them.

Returns:

Name | Type | Description |
---|---|---|
bool | bool | Whether the layer uses input edges in the forward pass |
layer_outputs_edges: bool
abstractmethod property

Abstract method. Return a boolean specifying if the layer type outputs edges or not. It is different from layer_supports_output_edges, since a layer that supports edges can decide not to output them.

Returns:

Name | Type | Description |
---|---|---|
bool | bool | Whether the layer outputs edges in the forward pass |
max_num_edges_per_graph: Optional[int]
property writable

Get the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)

max_num_nodes_per_graph: Optional[int]
property writable

Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)
out_dim_factor: int
abstractmethod property

Abstract method. Get the factor by which the output dimension is multiplied for the next layer. For standard layers, this will return 1. But for others, such as GatLayer, the output is the concatenation of the outputs from each head, so the out_dim gets multiplied by the number of heads, and this function should return the number of heads.

Returns:

Name | Type | Description |
---|---|---|
int | int | The factor that multiplies the output dimensions |
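As a hedged illustration of how out_dim_factor is typically consumed when wiring layers together (the variable names below are assumptions for illustration, not library attributes):

```python
# Hypothetical sketch: the effective width fed to the next layer is
# out_dim scaled by out_dim_factor (e.g. the number of concatenated attention heads).
layer_out_dim = 64
out_dim_factor = 4                                   # e.g. a GAT-style layer with 4 heads
next_layer_in_dim = layer_out_dim * out_dim_factor   # 256 features enter the next layer
```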
__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)

Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by better understanding the expected inputs and outputs.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input feature dimensions of the layer | required |
out_dim | int | Output feature dimensions of the layer | required |
activation | Union[str, Callable] | Activation function to use in the layer | 'relu' |
dropout | float | The ratio of units to dropout. Must be between 0 and 1 | 0.0 |
normalization | Union[str, Callable] | Normalization to use. Choices: | 'none' |
layer_idx | Optional[int] | The index of the current layer | None |
layer_depth | Optional[int] | The total depth (number of layers) associated to this specific layer | None |
droppath_rate | float | Stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 | 0.0 |
__repr__()

Controls how the class is printed.

apply_norm_activation_dropout(feat, normalization=True, activation=True, dropout=True, droppath=True, batch_idx=None, batch_size=None)

Apply the different normalization and the dropout to the output layer.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
feat | Tensor | Feature tensor, to be normalized | required |
normalization | bool | Whether to apply the normalization | True |
activation | bool | Whether to apply the activation layer | True |
dropout | bool | Whether to apply the dropout layer | True |
droppath | bool | Whether to apply the DropPath layer | True |

Returns:

Name | Description |
---|---|
feat | Normalized and dropped-out features |
layer_supports_edges()

Abstract method. Return a boolean specifying if the layer type supports edges or not.

Returns:

Name | Type | Description |
---|---|---|
bool | bool | Whether the layer supports the use of edges |

check_intpus_allow_int(obj, edge_index, size)

Overwrite the check_input to allow for int32 and int16. TODO: Remove when PyG and PyTorch support int32.
Base Layers
graphium.nn.base_layers

DropPath

Bases: nn.Module

__init__(drop_rate)

DropPath class for stochastic depth, from "Deep Networks with Stochastic Depth", Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger, https://arxiv.org/abs/1603.09382

Parameters:

Name | Type | Description | Default |
---|---|---|---|
drop_rate | float | Drop out probability | required |
forward(input, batch_idx, batch_size=None)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
input | Tensor | | required |
batch_idx | | batch attribute of the batch object, batch.batch | required |
batch_size | Optional[int] | The batch size. Must be provided when working on IPU | None |

Returns:

Type | Description |
---|---|
Tensor | torch.Tensor: |
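A minimal usage sketch, assuming a PyG-style batch index tensor that maps each node to its graph; the shapes and the dropping behaviour described in the comments are assumptions based on the stochastic-depth reference above, not guarantees about the implementation:

```python
import torch
from graphium.nn.base_layers import DropPath

drop_path = DropPath(drop_rate=0.1)

h = torch.randn(7, 16)                            # 7 nodes, 16 features
batch_idx = torch.tensor([0, 0, 0, 1, 1, 2, 2])   # node-to-graph assignment (batch.batch in PyG)

# Assumed behaviour: during training, the branch is randomly dropped with
# probability drop_rate; at evaluation time the input passes through unchanged.
out = drop_path(h, batch_idx=batch_idx)
```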
get_stochastic_drop_rate(drop_rate, layer_idx=None, layer_depth=None)
staticmethod

Get the stochastic drop rate from the nominal drop rate, the layer index, and the layer depth:

return drop_rate * (layer_idx / (layer_depth - 1))

Parameters:

Name | Type | Description | Default |
---|---|---|---|
drop_rate | float | Drop out nominal probability | required |
layer_idx | Optional[int] | The index of the current layer | None |
layer_depth | Optional[int] | The total depth (number of layers) associated to this specific layer | None |
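To make the scaling concrete, a small arithmetic illustration of the formula above (plain Python, not a call into the library):

```python
# The nominal drop rate is scaled linearly with depth:
# the first layer never drops, the last layer drops at the full nominal rate.
drop_rate = 0.2
layer_depth = 5

for layer_idx in range(layer_depth):
    stochastic_rate = drop_rate * (layer_idx / (layer_depth - 1))
    print(f"layer {layer_idx}: drop rate = {stochastic_rate:.2f}")
# layer 0: 0.00, layer 1: 0.05, layer 2: 0.10, layer 3: 0.15, layer 4: 0.20
```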
FCLayer

Bases: nn.Module

in_channels: int
property

Get the input channel size. For compatibility with PyG.

out_channels: int
property

Get the output channel size. For compatibility with PyG.

__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', bias=True, init_fn=None, is_readout_layer=False, droppath_rate=0.0)

A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
The order in which transformations are applied is:
- Dense Layer
- Activation
- Dropout (if applicable)
- Batch Normalization (if applicable)
Parameters:

Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input dimension of the layer (the | required |
out_dim | int | Output dimension of the layer. | required |
dropout | float | The ratio of units to dropout. No dropout by default. | 0.0 |
activation | Union[str, Callable] | Activation function to use. | 'relu' |
normalization | Union[str, Callable] | Normalization to use. Choices: | 'none' |
bias | bool | Whether to enable bias in the linear layer. | True |
init_fn | Optional[Callable] | Initialization function to use for the weight of the layer. Default is \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) with \(k=\frac{1}{\text{in\_dim}}\) | None |
is_readout_layer | bool | Whether the layer should be treated as a readout layer by replacing of | False |
droppath_rate | float | Stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 | 0.0 |
Attributes:

Name | Type | Description |
---|---|---|
dropout | int | The ratio of units to dropout. |
normalization | None or Callable | Normalization layer |
linear | `torch.nn.Linear` | The linear layer |
activation | `torch.nn.Module` | The activation layer |
init_fn | Callable | Initialization function used for the weight of the layer |
in_dim | int | Input dimension of the linear layer |
out_dim | int | Output dimension of the linear layer |
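A hedged usage sketch based on the signature above; the tensor shapes are illustrative assumptions, not values from the library:

```python
import torch
from graphium.nn.base_layers import FCLayer

# Linear -> activation -> dropout, per the transformation order described above
layer = FCLayer(in_dim=16, out_dim=32, activation='relu', dropout=0.1)

h = torch.randn(8, 16)   # 8 samples, 16 input features
out = layer(h)           # expected shape: (8, 32)
```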
forward(h)

Apply the FC layer on the input features.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | | required |

Returns:

Type | Description |
---|---|
torch.Tensor | |

reset_parameters(init_fn=None)

Reset the parameters of the linear layer using the init_fn.
GRU

Bases: nn.Module

__init__(in_dim, hidden_dim)

Wrapper class for the GRU used by the GNN framework; nn.GRU is used for the Gated Recurrent Unit itself.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input dimension of the GRU layer | required |
hidden_dim | int | Hidden dimension of the GRU layer. | required |
forward(x, y)

Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | | | required |
y | | | required |

Returns:

Type | Description |
---|---|
torch.Tensor: | |
MLP

Bases: nn.Module

__init__(in_dim, hidden_dims, out_dim, depth, activation='relu', last_activation='none', dropout=0.0, last_dropout=0.0, normalization='none', last_normalization='none', first_normalization='none', last_layer_is_readout=False, droppath_rate=0.0, constant_droppath_rate=True)

Simple multi-layer perceptron, built of a series of FCLayers.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
in_dim | int | Input dimension of the MLP | required |
hidden_dims | Union[Iterable[int], int] | Either an integer specifying all the hidden dimensions, or a list of dimensions in the hidden layers. | required |
out_dim | int | Output dimension of the MLP. | required |
depth | int | If | required |
activation | Union[str, Callable] | Activation function to use in all the layers except the last. if | 'relu' |
last_activation | Union[str, Callable] | Activation function to use in the last layer. | 'none' |
dropout | float | The ratio of units to dropout. Must be between 0 and 1 | 0.0 |
normalization | Union[Type[None], str, Callable] | Normalization to use. Choices: if | 'none' |
last_normalization | Union[Type[None], str, Callable] | Normalization to use after the last layer. Same options as | 'none' |
first_normalization | Union[Type[None], str, Callable] | Normalization to use before the first layer. Same options as | 'none' |
last_dropout | float | The ratio of units to dropout at the last layer. | 0.0 |
last_layer_is_readout | bool | Whether the last layer should be treated as a readout layer. Allows to use the | False |
droppath_rate | float | Stochastic depth drop rate, between 0 and 1. See https://arxiv.org/abs/1603.09382 | 0.0 |
constant_droppath_rate | bool | If | True |
__repr__()

Controls how the class is printed.

forward(h)

Apply the MLP on the input features.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | | required |

Returns:

Type | Description |
---|---|
torch.Tensor | |
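A hedged usage sketch based on the signature above; the dimensions, and the assumption that depth counts the total number of FCLayers, are illustrative rather than taken from the library:

```python
import torch
from graphium.nn.base_layers import MLP

# Assumed: depth=3 stacks three FCLayers, with hidden width 64 and ReLU between them
mlp = MLP(in_dim=16, hidden_dims=64, out_dim=8, depth=3,
          activation='relu', last_activation='none', dropout=0.1)

h = torch.randn(32, 16)   # 32 samples, 16 input features
out = mlp(h)              # expected shape: (32, 8)
```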
MuReadoutGraphium

Bases: MuReadout

PopTorch-compatible replacement for mup.MuReadout.
Not quite a drop-in replacement for mup.MuReadout - you need to specify base_width.
Set base_width to the width of the base model passed to mup.set_base_shapes to get the same results on IPU and CPU. Should still "work" with any other value, but won't give the same results as CPU.

MultiheadAttentionMup

Bases: nn.MultiheadAttention

Modifies the MultiheadAttention to work with the muTransfer paradigm. The layers are initialized using the mup package.
The _scaled_dot_product_attention normalizes the attention matrix with 1/d instead of 1/sqrt(d).
The biased self-attention option is added to have 3D attention bias.

TransformerEncoderLayerMup

Bases: nn.TransformerEncoderLayer

Modified version of torch.nn.TransformerEncoderLayer that uses \(1/n\)-scaled attention for compatibility with muP (as opposed to the original \(1/\sqrt{n}\) scaling factor).
Arguments are the same as torch.nn.TransformerEncoderLayer.
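Concretely, with \(d\) the per-head dimension, the attention described above is computed as

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{d}\right) V
\]

instead of the standard \(1/\sqrt{d}\)-scaled softmax used by torch.nn.MultiheadAttention.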
get_activation(activation)

Returns the activation function represented by the input string.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
activation | Union[type(None), str, Callable] | Callable, | required |

Returns:

Type | Description |
---|---|
Optional[Callable] | Callable or None: The activation function |
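A hedged example of how this helper is typically called; the exact object returned for each input is an assumption based on the docstring above:

```python
from graphium.nn.base_layers import get_activation

act = get_activation('relu')     # expected: a callable activation, e.g. an instance of torch.nn.ReLU
no_act = get_activation(None)    # expected: None, i.e. no activation applied
```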
get_activation_str(activation)

Returns the string related to the activation function.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
activation | Union[type(None), str, Callable] | Callable, | required |

Returns:

Type | Description |
---|---|
str | The name of the activation function |
get_norm(normalization, dim=None)

Returns the normalization function represented by the input string.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
normalization | Union[Type[None], str, Callable] | Callable, | required |
dim | Optional[int] | Dimension where to apply the norm. Mandatory for 'batch_norm' and 'layer_norm' | None |

Returns:

Type | Description |
---|---|
Callable or None: The normalization function |
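A hedged example based on the signature above; which module is returned for each string is an assumption:

```python
from graphium.nn.base_layers import get_norm

norm = get_norm('batch_norm', dim=64)   # dim is mandatory for 'batch_norm' and 'layer_norm'
no_norm = get_norm('none')              # expected: None, i.e. no normalization applied
```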
Residual Connections
graphium.nn.residual_connections

Different types of residual connections, including None, Simple (ResNet-like), Concat and DenseNet.

ResidualConnectionBase

Bases: nn.Module

__init__(skip_steps=1)

Abstract class for the residual connections. Using this class, we implement different types of residual connections, such as the ResNet, weighted-ResNet, skip-concat and DenseNet.
The following methods must be implemented in a child class:
- h_dim_increase_type()
- has_weights()

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | int | The number of steps to skip between the residual connections. If | 1 |
__repr__()

Controls how the class is printed.

get_true_out_dims(out_dims)

Find the true output dimensions.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
out_dims | List | List | required |

Returns:

Name | Type | Description |
---|---|---|
true_out_dims | List | List |
h_dim_increase_type()
abstractmethod

How does the dimension of the output features increase after each layer?

Returns:

Name | Type | Description |
---|---|---|
h_dim_increase_type | None or str | |

has_weights()
abstractmethod

Returns:

Name | Type | Description |
---|---|---|
has_weights | bool | Whether the residual connection uses weights |
ResidualConnectionConcat

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the simple residual connections, but where the skip connection features are concatenated to the current layer features.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | int | The number of steps to skip between the residual connections. If | 1 |
forward(h, h_prev, step_idx)

Concatenate h with the previous layers with skip connection h_prev.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | torch.Tensor(..., m) The current layer features | required |
h_prev | torch.Tensor | torch.Tensor(..., n), None The features from the previous layer with a skip connection. Usually, we have | required |
step_idx | int | Current layer index or step index in the forward loop of the architecture. | required |

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) or torch.Tensor(..., m + n) | Either return |
h_prev | torch.Tensor(..., m) or torch.Tensor(..., m + n) | Either return |
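A hedged sketch of the shape bookkeeping for the concatenating connection; the tensor dimensions, the step at which concatenation triggers, and the direct forward call are assumptions for illustration:

```python
import torch
from graphium.nn.residual_connections import ResidualConnectionConcat

res = ResidualConnectionConcat(skip_steps=1)

h = torch.randn(10, 32)       # current layer features, m = 32
h_prev = torch.randn(10, 32)  # skipped features from an earlier layer, n = 32

h, h_prev = res.forward(h, h_prev, step_idx=1)
# On steps where the skip connection applies, h becomes the concatenation (..., m + n) = (10, 64).
```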
h_dim_increase_type()

Returns:

Type | Description |
---|---|
"previous" | The dimension of the output layer is the concatenation with the previous layer. |

has_weights()

Returns:

Type | Description |
---|---|
False | The current class does not use weights |
ResidualConnectionDenseNet

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the residual connections proposed by DenseNet, where all previous skip connection features are concatenated to the current layer features.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | int | The number of steps to skip between the residual connections. If | 1 |
forward(h, h_prev, step_idx)

Concatenate h with all the previous layers with skip connection h_prev.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | torch.Tensor(..., m) The current layer features | required |
h_prev | torch.Tensor | torch.Tensor(..., n), None The features from the previous layers. n = ((step_idx // self.skip_steps) + 1) * m At | required |
step_idx | int | Current layer index or step index in the forward loop of the architecture. | required |

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) or torch.Tensor(..., m + n) | Either return |
h_prev | torch.Tensor(..., m) or torch.Tensor(..., m + n) | Either return |
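A small arithmetic illustration of the cumulative width formula quoted above, n = ((step_idx // skip_steps) + 1) * m (plain Python, not a library call):

```python
# With m = 32 features per layer and skip_steps = 1, the accumulated
# DenseNet-style skip features grow by m at every step.
m, skip_steps = 32, 1
for step_idx in range(1, 5):
    n = ((step_idx // skip_steps) + 1) * m
    print(f"step {step_idx}: h_prev holds {n} features")
# step 1: 64, step 2: 96, step 3: 128, step 4: 160
```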
h_dim_increase_type()

Returns:

Type | Description |
---|---|
"cumulative" | The dimension of the output layer is the concatenation of all the previous layers. |

has_weights()

Returns:

Type | Description |
---|---|
False | The current class does not use weights |
ResidualConnectionNone

Bases: ResidualConnectionBase

No residual connection. This class is only used for simpler code compatibility.

__repr__()

Controls how the class is printed.

forward(h, h_prev, step_idx)

Ignore the skip connection.

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) | Return same as input. |
h_prev | torch.Tensor(..., m) | Return same as input. |
h_dim_increase_type()

Returns:

Name | Type | Description |
---|---|---|
None | | The dimension of the output features does not change at each layer. |

has_weights()

Returns:

Type | Description |
---|---|
False | The current class does not use weights |
ResidualConnectionRandom

Bases: ResidualConnectionBase

__init__(skip_steps=1, out_dims=None, num_layers=None)

Class for the random residual connection, where each layer is connected to each following layer with a random weight between 0 and 1.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | | Parameter only there for compatibility with other classes of the same parent. | 1 |
out_dims | List[int] | The list of output dimensions. Only required to get the number of layers. Must be provided if | None |
num_layers | int | The number of layers. Must be provided if | None |
forward(h, h_prev, step_idx)

Add h to the previous layers' features h_prev carried by the skip connection, similar to ResNet.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | torch.Tensor(..., m) The current layer features | required |
h_prev | torch.Tensor | torch.Tensor(..., m), None The features from the previous layer with a skip connection. At | required |
step_idx | int | Current layer index or step index in the forward loop of the architecture. | required |

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) | Either return |
h_prev | torch.Tensor(..., m) | Either return |
h_dim_increase_type()

Returns:

Name | Type | Description |
---|---|---|
None | | The dimension of the output features does not change at each layer. |

has_weights()

Returns:

Type | Description |
---|---|
False | The current class does not use weights |
ResidualConnectionSimple

Bases: ResidualConnectionBase

__init__(skip_steps=1)

Class for the simple residual connections proposed by ResNet, where the current layer output is summed to a previous layer output.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | int | The number of steps to skip between the residual connections. If | 1 |
forward(h, h_prev, step_idx)

Add h to the previous layers' features h_prev carried by the skip connection, similar to ResNet.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | torch.Tensor(..., m) The current layer features | required |
h_prev | torch.Tensor | torch.Tensor(..., m), None The features from the previous layer with a skip connection. At | required |
step_idx | int | Current layer index or step index in the forward loop of the architecture. | required |

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) | Either return |
h_prev | torch.Tensor(..., m) | Either return |
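A hedged sketch of how this ResNet-style connection could be driven from a forward loop; the stand-in layers, shapes, and the treatment of h_prev at the first step are assumptions for illustration:

```python
import torch
from graphium.nn.residual_connections import ResidualConnectionSimple

res = ResidualConnectionSimple(skip_steps=1)
layers = [torch.nn.Linear(32, 32) for _ in range(4)]  # stand-in layers of constant width

h = torch.randn(10, 32)
h_prev = None  # assumed: no skip features stored before the first step
for step_idx, layer in enumerate(layers):
    h = layer(h)
    # Sum h with the stored skip features every skip_steps steps (ResNet-style).
    h, h_prev = res.forward(h, h_prev, step_idx)
```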
h_dim_increase_type()

Returns:

Name | Type | Description |
---|---|---|
None | | The dimension of the output features does not change at each layer. |

has_weights()

Returns:

Type | Description |
---|---|
False | The current class does not use weights |
ResidualConnectionWeighted

Bases: ResidualConnectionBase

__init__(out_dims, skip_steps=1, dropout=0.0, activation='none', normalization='none', bias=False)

Class for the simple residual connections proposed by ResNet, with an added layer in the residual connection itself. The layer output is summed to a non-linear transformation of a previous layer output.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
skip_steps | int | The number of steps to skip between the residual connections. If | 1 |
out_dims | list(int) | List of all output dimensions for the network that will use this residual connection. E.g. | required |
dropout | float | Value between 0 and 1.0 representing the percentage of dropout to use in the weights | 0.0 |
activation | Union[str, Callable] | The activation function to use after the skip weights | 'none' |
normalization | | Normalization to use. Choices: | 'none' |
bias | bool | Whether to add a bias after the weights | False |
forward(h, h_prev, step_idx)

Add h to the previous layers' features h_prev carried by the skip connection, after a feed-forward layer.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
h | torch.Tensor | torch.Tensor(..., m) The current layer features | required |
h_prev | torch.Tensor | torch.Tensor(..., m), None The features from the previous layer with a skip connection. At | required |
step_idx | int | Current layer index or step index in the forward loop of the architecture. | required |

Returns:

Name | Type | Description |
---|---|---|
h | torch.Tensor(..., m) | Either return |
h_prev | torch.Tensor(..., m) | Either return |
h_dim_increase_type()

Returns:

Name | Type | Description |
---|---|---|
None | | The dimension of the output features does not change at each layer. |

has_weights()

Returns:

Type | Description |
---|---|
True | The current class uses weights |