graphium.nn¶
Base structures for neural networks in the library (to be inherited by other modules)
Base Graph Layer¶
graphium.nn.base_graph_layer
¶
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
BaseGraphModule
¶
Bases: BaseGraphStructure, Module
__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)
¶
Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by making the expected inputs and outputs explicit.
Parameters:
in_dim:
Input feature dimensions of the layer
out_dim:
Output feature dimensions of the layer
activation:
activation function to use in the layer
dropout:
The ratio of units to dropout. Must be between 0 and 1
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization
- `Callable`: Any callable function
layer_idx:
The index of the current layer
layer_depth:
The total depth (number of layers) associated to this specific layer
droppath_rate:
stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
BaseGraphStructure
¶
layer_inputs_edges: bool
abstractmethod
property
¶
Abstract method. Return a boolean specifying whether the layer type uses edges as input.
It is different from layer_supports_input_edges since a layer that supports edges can decide not to use them.
Returns:
bool:
Whether the layer uses input edges in the forward pass
layer_outputs_edges: bool
abstractmethod
property
¶
Abstract method. Return a boolean specifying whether the layer type outputs edges.
It is different from layer_supports_output_edges since a layer that supports edges can decide not to output them.
Returns:
bool:
Whether the layer outputs edges in the forward pass
max_num_edges_per_graph: Optional[int]
property
writable
¶
Get the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)
max_num_nodes_per_graph: Optional[int]
property
writable
¶
Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)
out_dim_factor: int
abstractmethod
property
¶
Abstract method. Get the factor by which the output dimension is multiplied for the next layer.
For standard layers, this will return 1. But for others, such as GatLayer, the output is the concatenation of the outputs from each head, so the out_dim gets multiplied by the number of heads, and this function should return the number of heads.
Returns:
int:
The factor that multiplies the output dimensions
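To make these abstract members concrete, here is a minimal sketch of a multi-head layer subclassing `BaseGraphModule`. The class name, the use of `GATConv`, and the constructor kwargs are illustrative assumptions, not part of the documented API; a matching `forward` that calls `apply_norm_activation_dropout` is sketched further below.

```python
from torch_geometric.nn import GATConv

from graphium.nn.base_graph_layer import BaseGraphModule


class MyGATLayer(BaseGraphModule):
    """Hypothetical multi-head layer, used only to illustrate the abstract properties."""

    def __init__(self, in_dim: int, out_dim: int, heads: int = 4, **kwargs):
        super().__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
        self.heads = heads
        self.conv = GATConv(in_dim, out_dim, heads=heads)  # concatenates the heads

    @property
    def layer_inputs_edges(self) -> bool:
        return False  # this layer ignores edge features

    @property
    def layer_outputs_edges(self) -> bool:
        return False  # and does not produce any

    @property
    def out_dim_factor(self) -> int:
        return self.heads  # heads are concatenated, so out_dim is multiplied

    def layer_supports_edges(self) -> bool:
        return False
```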
__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)
¶
Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by making the expected inputs and outputs explicit.
Parameters:
in_dim:
Input feature dimensions of the layer
out_dim:
Output feature dimensions of the layer
activation:
activation function to use in the layer
dropout:
The ratio of units to dropout. Must be between 0 and 1
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization
- `Callable`: Any callable function
layer_idx:
The index of the current layer
layer_depth:
The total depth (number of layers) associated to this specific layer
droppath_rate:
stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
__repr__()
¶
Controls how the class is printed
apply_norm_activation_dropout(feat, normalization=True, activation=True, dropout=True, droppath=True, batch_idx=None, batch_size=None)
¶
Apply the normalization, activation, dropout, and droppath layers to the output features.
Parameters:
feat:
Feature tensor, to be normalized
normalization:
Whether to apply the normalization
activation:
Whether to apply the activation layer
dropout:
Whether to apply the dropout layer
droppath:
Whether to apply the DropPath layer
batch_idx:
The batch index of each node (the `batch.batch` attribute)
batch_size:
The batch size. Must be provided when working on IPU
Returns:
feat:
Normalized and dropped-out features
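A hedged sketch of how a subclass's `forward` might use this helper; the `message_passing` call and the `batch.feat`/`batch.edge_index` attributes are assumptions for illustration.

```python
def forward(self, batch):
    # Hypothetical message-passing step producing the raw node features
    feat = self.message_passing(batch.feat, batch.edge_index)

    # Apply the configured normalization, activation, dropout and droppath;
    # each can be skipped via the corresponding boolean flag.
    feat = self.apply_norm_activation_dropout(
        feat,
        batch_idx=batch.batch,  # graph index of each node
        batch_size=None,        # must be provided when running on IPU
    )
    batch.feat = feat
    return batch
```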
layer_supports_edges()
¶
Abstract method. Return a boolean specifying if the layer type supports edges or not.
Returns:
bool:
Whether the layer supports the use of edges
check_intpus_allow_int(obj, edge_index, size)
¶
Overwrite `check_input` to allow for int32 and int16. TODO: Remove when PyG and PyTorch support int32.
Base Layers¶
graphium.nn.base_layers
¶
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
DropPath
¶
Bases: Module
__init__(drop_rate)
¶
DropPath class for stochastic depth, from "Deep Networks with Stochastic Depth", Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger, https://arxiv.org/abs/1603.09382
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| drop_rate | float | Dropout probability | required |
forward(input, batch_idx, batch_size=None)
¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor | Input feature tensor | required |
| batch_idx | | The batch attribute of the batch object, `batch.batch` | required |
| batch_size | Optional[int] | The batch size. Must be provided when working on IPU | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor |
get_stochastic_drop_rate(drop_rate, layer_idx=None, layer_depth=None)
staticmethod
¶
Get the stochastic drop rate from the nominal drop rate, the layer index, and the layer depth.
`return drop_rate * (layer_idx / (layer_depth - 1))`

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| drop_rate | float | Nominal dropout probability | required |
| layer_idx | Optional[int] | The index of the current layer | None |
| layer_depth | Optional[int] | The total depth (number of layers) associated to this specific layer | None |
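A small sketch of the two entry points documented above: the static helper that scales the nominal drop rate with depth, and the module itself applied per graph. The tensor shapes and values are illustrative.

```python
import torch

from graphium.nn.base_layers import DropPath

layer_depth, nominal_rate = 8, 0.2
for layer_idx in range(layer_depth):
    # drop_rate * layer_idx / (layer_depth - 1): 0 for the first layer,
    # growing linearly up to the nominal rate at the last layer
    rate = DropPath.get_stochastic_drop_rate(nominal_rate, layer_idx, layer_depth)
    print(f"layer {layer_idx}: droppath rate = {rate:.3f}")

droppath = DropPath(drop_rate=0.1)
feat = torch.randn(6, 16)                     # 6 nodes, 16 features
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs of 3 nodes each
out = droppath(feat, batch_idx)               # same shape as `feat`
```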
FCLayer
¶
Bases: Module
in_channels: int
property
¶
Get the input channel size. For compatibility with PyG.
out_channels: int
property
¶
Get the output channel size. For compatibility with PyG.
__init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', bias=True, init_fn=None, is_readout_layer=False, droppath_rate=0.0)
¶
A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
The order in which transformations are applied is:
- Dense Layer
- Activation
- Dropout (if applicable)
- Batch Normalization (if applicable)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension of the layer | required |
| out_dim | int | Output dimension of the layer | required |
| dropout | float | The ratio of units to dropout. No dropout by default. | 0.0 |
| activation | Union[str, Callable] | Activation function to use. | 'relu' |
| normalization | Union[str, Callable] | Normalization to use. Choices: "none" or `None` (no normalization), "batch_norm", "layer_norm", or any `Callable` | 'none' |
| bias | bool | Whether to enable bias in the linear layer | True |
| init_fn | Optional[Callable] | Initialization function to use for the weight of the layer. Default is \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) with \(k=\frac{1}{\text{in\_dim}}\) | None |
| is_readout_layer | bool | Whether the layer should be treated as a readout layer | False |
| droppath_rate | float | Stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 | 0.0 |
Attributes:
dropout (float):
The ratio of units to dropout.
normalization (None or Callable):
Normalization layer
linear (torch.nn.Linear):
The linear layer
activation (torch.nn.Module):
The activation layer
init_fn (Callable):
Initialization function used for the weight of the layer
in_dim (int):
Input dimension of the linear layer
out_dim (int):
Output dimension of the linear layer
forward(h)
¶
Apply the FC layer on the input features.
Parameters:
h: `torch.Tensor[..., Din]`:
Input feature tensor, before the FC.
`Din` is the number of input features
Returns:
`torch.Tensor[..., Dout]`:
Output feature tensor, after the FC.
`Dout` is the number of output features
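A hedged usage sketch of `FCLayer`, based on the constructor arguments and shapes documented above; the dimensions are illustrative.

```python
import torch

from graphium.nn.base_layers import FCLayer

layer = FCLayer(
    in_dim=16,
    out_dim=32,
    activation="relu",
    dropout=0.1,
    normalization="layer_norm",
)

h = torch.randn(10, 16)  # torch.Tensor[..., Din]
out = layer(h)           # torch.Tensor[..., Dout]
print(out.shape)         # torch.Size([10, 32])
```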
reset_parameters(init_fn=None)
¶
Reset the parameters of the linear layer using the `init_fn`.
GRU
¶
Bases: Module
__init__(in_dim, hidden_dim)
¶
Wrapper class for the GRU used by the GNN framework; nn.GRU is used for the Gated Recurrent Unit itself.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension of the GRU layer | required |
| hidden_dim | int | Hidden dimension of the GRU layer | required |
forward(x, y)
¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | | | required |
| y | | | required |

Returns:

| Type | Description |
|---|---|
| torch.Tensor | |
MLP
¶
Bases: Module
__init__(in_dim, hidden_dims, out_dim, depth=None, activation='relu', last_activation='none', dropout=0.0, last_dropout=0.0, normalization='none', last_normalization='none', first_normalization='none', last_layer_is_readout=False, droppath_rate=0.0, constant_droppath_rate=True, fc_layer=FCLayer, fc_layer_kwargs=None)
¶
Simple multi-layer perceptron, built of a series of FCLayers
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension of the MLP | required |
| hidden_dims | Union[Iterable[int], int] | Either an integer specifying all the hidden dimensions, or a list of dimensions in the hidden layers. | required |
| out_dim | int | Output dimension of the MLP | required |
| depth | Optional[int] | The depth of the MLP; used when `hidden_dims` is given as a single integer | None |
| activation | Union[str, Callable] | Activation function to use in all the layers except the last | 'relu' |
| last_activation | Union[str, Callable] | Activation function to use in the last layer | 'none' |
| dropout | float | The ratio of units to dropout. Must be between 0 and 1 | 0.0 |
| normalization | Union[Type[None], str, Callable] | Normalization to use. Choices: "none" or `None` (no normalization), "batch_norm", "layer_norm", or any `Callable` | 'none' |
| last_normalization | Union[Type[None], str, Callable] | Normalization to use after the last layer. Same options as `normalization` | 'none' |
| first_normalization | Union[Type[None], str, Callable] | Normalization to use before the first layer. Same options as `normalization` | 'none' |
| last_dropout | float | The ratio of units to dropout at the last layer | 0.0 |
| last_layer_is_readout | bool | Whether the last layer should be treated as a readout layer | False |
| droppath_rate | float | Stochastic depth drop rate, between 0 and 1. See https://arxiv.org/abs/1603.09382 | 0.0 |
| constant_droppath_rate | bool | Whether to use the same droppath rate for every layer, rather than scaling it with the layer index (see `DropPath.get_stochastic_drop_rate`) | True |
| fc_layer | FCLayer | The fully connected layer to use. Must inherit from `FCLayer` | FCLayer |
| fc_layer_kwargs | Optional[dict] | Keyword arguments to pass to the fully connected layer | None |
__repr__()
¶
Controls how the class is printed
forward(h)
¶
Apply the MLP on the input features.
Parameters:
h: `torch.Tensor[..., Din]`:
Input feature tensor, before the MLP.
`Din` is the number of input features
Returns:
`torch.Tensor[..., Dout]`:
Output feature tensor, after the MLP.
`Dout` is the number of output features
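A hedged usage sketch of `MLP`, following the constructor arguments documented above; the widths, depth, and normalization choices are illustrative.

```python
import torch

from graphium.nn.base_layers import MLP

mlp = MLP(
    in_dim=16,
    hidden_dims=64,   # a single integer: all hidden layers share this width
    out_dim=2,
    depth=3,
    activation="relu",
    last_activation="none",
    normalization="batch_norm",
    dropout=0.1,
)

h = torch.randn(32, 16)  # torch.Tensor[..., Din]
out = mlp(h)             # torch.Tensor[..., Dout]
print(out.shape)         # torch.Size([32, 2])
```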
MuReadoutGraphium
¶
Bases: MuReadout
PopTorch-compatible replacement for mup.MuReadout.
Not quite a drop-in replacement for mup.MuReadout - you need to specify base_width.
Set base_width to the width of the base model passed to mup.set_base_shapes to get the same results on IPU and CPU. Should still "work" with any other value, but won't give the same results as on CPU.
MultiheadAttentionMup
¶
Bases: MultiheadAttention
Modifying the MultiheadAttention to work with the muTransfer paradigm.
The layers are initialized using the mup package.
The `_scaled_dot_product_attention` normalizes the attention matrix with `1/d` instead of `1/sqrt(d)`.
The biased self-attention option is added to have 3D attention bias.
TransformerEncoderLayerMup
¶
Bases: TransformerEncoderLayer
Modified version of torch.nn.TransformerEncoderLayer that uses \(1/n\)-scaled attention for compatibility with muP (as opposed to the original \(1/\sqrt{n}\) scaling factor).
Arguments are the same as torch.nn.TransformerEncoderLayer.
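For intuition only, the snippet below contrasts the standard \(1/\sqrt{d}\) attention scaling with the \(1/d\) scaling used for muP compatibility; it is not the library's implementation.

```python
import math

import torch

q = torch.randn(4, 8, 64)  # (batch, sequence, d)
k = torch.randn(4, 8, 64)
d = q.size(-1)

scores_standard = q @ k.transpose(-2, -1) / math.sqrt(d)  # standard 1/sqrt(d)
scores_mup = q @ k.transpose(-2, -1) / d                  # 1/d, muP-compatible

attn_standard = scores_standard.softmax(dim=-1)
attn_mup = scores_mup.softmax(dim=-1)
```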
get_activation(activation)
¶
returns the activation function represented by the input string
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| activation | Union[type(None), str, Callable] | The activation, given as a Callable, a string name, or `None` | required |

Returns:

| Type | Description |
|---|---|
| Optional[Callable] | Callable or None: The activation function |
get_activation_str(activation)
¶
returns the string related to the activation function
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| activation | Union[type(None), str, Callable] | The activation, given as a Callable, a string name, or `None` | required |

Returns:

| Type | Description |
|---|---|
| str | The name of the activation function |
get_norm(normalization, dim=None)
¶
returns the normalization function represented by the input string
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| normalization | Union[Type[None], str, Callable] | The normalization, given as a Callable, a string name, or `None` | required |
| dim | Optional[int] | Dimension where to apply the norm. Mandatory for 'batch_norm' and 'layer_norm' | None |

Returns:

| Type | Description |
|---|---|
| Callable or None | The normalization function |
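A hedged sketch of these helpers, assuming the signatures documented above; the exact return types (module versus plain callable) may differ.

```python
from graphium.nn.base_layers import get_activation, get_activation_str, get_norm

act = get_activation("relu")         # a callable activation (None for "none")
act_name = get_activation_str(act)   # back to the string name, e.g. "relu"

norm = get_norm("layer_norm", dim=64)  # dim is mandatory for batch_norm/layer_norm
no_norm = get_norm("none")             # returns None
```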
Residual Connections¶
graphium.nn.residual_connections
¶
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
ResidualConnectionBase
¶
Bases: Module
__init__(skip_steps=1)
¶
Abstract class for the residual connections. Using this class, we implement different types of residual connections, such as the ResNet, weighted-ResNet, skip-concat and DenseNet.
The following methods must be implemented in a child class:
h_dim_increase_type()
has_weights()
Parameters:
skip_steps: int
The number of steps to skip between the residual connections.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
__repr__()
¶
Controls how the class is printed
get_true_out_dims(out_dims)
¶
Find the true output dimensions.
Parameters:
out_dims: List
Returns:
true_out_dims: List
h_dim_increase_type()
abstractmethod
¶
How does the dimension of the output features increase after each layer?
Returns:
h_dim_increase_type: None or str
- ``None``: The dimension of the output features does not change at each layer.
E.g. ResNet.
- "previous": The dimension of the output features is the concatenation of
the previous layer with the new layer.
- "cumulative": The dimension of the output features is the concatenation
of all previous layers.
has_weights()
abstractmethod
¶
Returns:
has_weights: bool
Whether the residual connection uses weights
ResidualConnectionConcat
¶
Bases: ResidualConnectionBase
__init__(skip_steps=1)
¶
Class for the simple residual connections proposed by ResNet, but where the skip connection features are concatenated to the current layer features.
Parameters:
skip_steps: int
The number of steps to skip between the residual connections.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
forward(h, h_prev, step_idx)
¶
Concatenate ``h`` with the skip-connection features ``h_prev`` from the previous layer.
Parameters:
h: torch.Tensor(..., m)
The current layer features
h_prev: torch.Tensor(..., n), None
The features from the previous layer with a skip connection.
Usually, we have ``n`` equal to ``m``.
At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
Either return ``h`` unchanged, or the concatenation
with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
Either return ``h_prev`` unchanged, or the same value as ``h``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_dim_increase_type()
¶
Returns:
"previous":
The dimension of the output layer is the concatenation with the previous layer.
has_weights()
¶
Returns:
False
The current class does not use weights
ResidualConnectionDenseNet
¶
Bases: ResidualConnectionBase
__init__(skip_steps=1)
¶
Class for the residual connections proposed by DenseNet, where all previous skip connection features are concatenated to the current layer features.
Parameters:
skip_steps: int
The number of steps to skip between the residual connections.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
forward(h, h_prev, step_idx)
¶
Concatenate ``h`` with the skip-connection features ``h_prev`` from all the previous layers.
Parameters:
h: torch.Tensor(..., m)
The current layer features
h_prev: torch.Tensor(..., n), None
The features from the previous layers.
n = ((step_idx // self.skip_steps) + 1) * m
At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
Either return ``h`` unchanged, or the concatenation
with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
Either return ``h_prev`` unchanged, or the same value as ``h``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_dim_increase_type()
¶
Returns:
"cumulative":
The dimension of the output layer is the concatenation of all the previous layers.
has_weights()
¶
Returns:
False
The current class does not use weights
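The width of ``h_prev`` follows the formula above, ``n = ((step_idx // self.skip_steps) + 1) * m`` (at step 0, ``h_prev`` may simply be ``None``). A tiny sketch of how that concatenated width grows, with illustrative values:

```python
m = 64          # per-layer feature width
skip_steps = 2  # concatenate every second layer

for step_idx in range(6):
    if step_idx % skip_steps == 0:  # steps where the skip connection is applied
        n = ((step_idx // skip_steps) + 1) * m
        print(f"step {step_idx}: h_prev width = {n}")
# step 0: h_prev width = 64
# step 2: h_prev width = 128
# step 4: h_prev width = 192
```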
ResidualConnectionNone
¶
Bases: ResidualConnectionBase
No residual connection. This class is only used to simplify code compatibility.
__repr__()
¶
Controls how the class is printed
forward(h, h_prev, step_idx)
¶
Ignore the skip connection.
Returns:
h: torch.Tensor(..., m)
Return same as input.
h_prev: torch.Tensor(..., m)
Return same as input.
h_dim_increase_type()
¶
Returns:
None:
The dimension of the output features does not change at each layer.
has_weights()
¶
Returns:
False
The current class does not use weights
ResidualConnectionRandom
¶
Bases: ResidualConnectionBase
__init__(skip_steps=1, out_dims=None, num_layers=None)
¶
Class for the random residual connection, where each layer is connected
to each following layer with a random weight between 0 and 1.
Parameters:
skip_steps:
Parameter only there for compatibility with other classes of the same parent.
out_dims:
The list of output dimensions. Only required to get the number
of layers. Must be provided if `num_layers` is None.
num_layers:
The number of layers. Must be provided if `out_dims` is None.
forward(h, h_prev, step_idx)
¶
Add ``h`` to the skip-connection features ``h_prev`` from the previous layer, similar to ResNet.
Parameters:
h: torch.Tensor(..., m)
The current layer features
h_prev: torch.Tensor(..., m), None
The features from the previous layer with a skip connection.
At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m)
Either return ``h`` unchanged, or the sum with ``h_prev``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m)
Either return ``h_prev`` unchanged, or the same value as ``h``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_dim_increase_type()
¶
Returns:
None:
The dimension of the output features does not change at each layer.
has_weights()
¶
Returns:
False
The current class does not use weights
ResidualConnectionSimple
¶
Bases: ResidualConnectionBase
__init__(skip_steps=1)
¶
Class for the simple residual connections proposed by ResNet, where the current layer output is summed to a previous layer output.
Parameters:
skip_steps: int
The number of steps to skip between the residual connections.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
forward(h, h_prev, step_idx)
¶
Add ``h`` to the skip-connection features ``h_prev`` from the previous layer, similar to ResNet.
Parameters:
h: torch.Tensor(..., m)
The current layer features
h_prev: torch.Tensor(..., m), None
The features from the previous layer with a skip connection.
At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m)
Either return ``h`` unchanged, or the sum with ``h_prev``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m)
Either return ``h_prev`` unchanged, or the same value as ``h``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_dim_increase_type()
¶
Returns:
None:
The dimension of the output features does not change at each layer.
has_weights()
¶
Returns:
False
The current class does not use weights
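A hedged sketch of how this residual connection could be used in a layer loop, following the ``forward(h, h_prev, step_idx)`` signature above; the ``nn.Linear`` stack is a stand-in for real graph layers.

```python
import torch
import torch.nn as nn

from graphium.nn.residual_connections import ResidualConnectionSimple

dim = 64
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(4)])  # stand-in layers
residual = ResidualConnectionSimple(skip_steps=1)

h = torch.randn(10, dim)
h_prev = None  # at step_idx == 0, h_prev may be None
for step_idx, layer in enumerate(layers):
    h = layer(h)
    # h is returned unchanged or summed with h_prev, depending on
    # step_idx and skip_steps; h_prev is updated for the next step.
    h, h_prev = residual.forward(h, h_prev, step_idx)
```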
ResidualConnectionWeighted
¶
Bases: ResidualConnectionBase
__init__(out_dims, skip_steps=1, dropout=0.0, activation='none', normalization='none', bias=False)
¶
Class for the simple residual connections proposed by ResNet, with an added layer in the residual connection itself. The layer output is summed to a non-linear transformation of a previous layer output.
Parameters:
skip_steps: int
The number of steps to skip between the residual connections.
If `1`, all the layers are connected. If `2`, half of the
layers are connected.
out_dims: list(int)
list of all output dimensions for the network
that will use this residual connection.
E.g. ``out_dims = [4, 8, 8, 8, 2]``.
dropout: float
value between 0 and 1.0 representing the percentage of dropout
to use in the weights
activation: str, Callable
The activation function to use after the skip weights
normalization:
Normalization to use. Choices:
- "none" or `None`: No normalization
- "batch_norm": Batch normalization
- "layer_norm": Layer normalization in the hidden layers.
- `Callable`: Any callable function
bias: bool
Whether to add a bias after the weights
forward(h, h_prev, step_idx)
¶
Add ``h`` to the skip-connection features ``h_prev`` from the previous layer, after a feed-forward layer.
Parameters:
h: torch.Tensor(..., m)
The current layer features
h_prev: torch.Tensor(..., m), None
The features from the previous layer with a skip connection.
At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m)
Either return ``h`` unchanged, or the sum with the output of a NN layer
on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m)
Either return ``h_prev`` unchanged, or the same value as ``h``,
depending on the ``step_idx`` and ``self.skip_steps``.
h_dim_increase_type()
¶
Returns:
None:
The dimension of the output features does not change at each layer.
has_weights()
¶
Returns:
True
The current class uses weights
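A hedged construction sketch based on the constructor arguments documented above; the ``out_dims`` values are illustrative and should list the output width of every layer in the network that uses this residual.

```python
from graphium.nn.residual_connections import ResidualConnectionWeighted

residual = ResidualConnectionWeighted(
    out_dims=[4, 8, 8, 8, 2],  # one entry per layer of the surrounding network
    skip_steps=1,
    dropout=0.1,
    activation="relu",
    normalization="batch_norm",
    bias=False,
)
# Used like the other residual classes inside the layer loop:
#   h, h_prev = residual.forward(h, h_prev, step_idx)
```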