graphium.nn¶
Base structures for neural networks in the library (to be inherited by other modules)
Base Graph Layer¶
          graphium.nn.base_graph_layer
¶
  Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
          BaseGraphModule
¶
  
            Bases: BaseGraphStructure, Module
          __init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)
¶
  Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by standardizing the expected inputs and outputs.
Parameters:
in_dim:
    Input feature dimensions of the layer
out_dim:
    Output feature dimensions of the layer
activation:
    activation function to use in the layer
dropout:
    The ratio of units to dropout. Must be between 0 and 1
normalization:
    Normalization to use. Choices:
    - "none" or `None`: No normalization
    - "batch_norm": Batch normalization
    - "layer_norm": Layer normalization
    - `Callable`: Any callable function
layer_idx:
    The index of the current layer
layer_depth:
    The total depth (number of layers) associated to this specific layer
droppath_rate:
    stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
          BaseGraphStructure
¶
  
          layer_inputs_edges: bool
  
  
      abstractmethod
      property
  
¶
  Abstract method. Return a boolean specifying if the layer type
uses edges as input or not.
It is different from layer_supports_input_edges since a layer that
supports edges can decide to not use them.
Returns:
bool:
    Whether the layer uses input edges in the forward pass
          layer_outputs_edges: bool
  
  
      abstractmethod
      property
  
¶
  Abstract method. Return a boolean specifying if the layer type
outputs edges or not.
It is different from layer_supports_output_edges since a layer that
supports edges can decide not to output them.
Returns:
bool:
    Whether the layer outputs edges in the forward pass
          max_num_edges_per_graph: Optional[int]
  
  
      property
      writable
  
¶
  Get the maximum number of edges per graph. Useful for reshaping a compiled model (IPU)
          max_num_nodes_per_graph: Optional[int]
  
  
      property
      writable
  
¶
  Get the maximum number of nodes per graph. Useful for reshaping a compiled model (IPU)
          out_dim_factor: int
  
  
      abstractmethod
      property
  
¶
  Abstract method. Get the factor by which the output dimension is multiplied for the next layer.
For standard layers, this will return 1.
But for others, such as GatLayer, the output is the concatenation
of the outputs from each head, so the out_dim gets multiplied by
the number of heads, and this function should return the number
of heads.
Returns:
int:
    The factor that multiplies the output dimensions
          __init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', layer_idx=None, layer_depth=None, droppath_rate=0.0)
¶
  Abstract class used to standardize the implementation of PyG layers in the current library. It allows a network to seamlessly swap between different GNN layers by standardizing the expected inputs and outputs.
Parameters:
in_dim:
    Input feature dimensions of the layer
out_dim:
    Output feature dimensions of the layer
activation:
    activation function to use in the layer
dropout:
    The ratio of units to dropout. Must be between 0 and 1
normalization:
    Normalization to use. Choices:
    - "none" or `None`: No normalization
    - "batch_norm": Batch normalization
    - "layer_norm": Layer normalization
    - `Callable`: Any callable function
layer_idx:
    The index of the current layer
layer_depth:
    The total depth (number of layers) associated to this specific layer
droppath_rate:
    stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382
          __repr__()
¶
  Controls how the class is printed
          apply_norm_activation_dropout(feat, normalization=True, activation=True, dropout=True, droppath=True, batch_idx=None, batch_size=None)
¶
  Apply the normalization, activation, dropout, and DropPath layers to the output features.
Parameters:
feat:
    Feature tensor, to be normalized
normalization:
    Whether to apply the normalization
activation:
    Whether to apply the activation layer
dropout:
    Whether to apply the dropout layer
droppath:
    Whether to apply the DropPath layer
batch_idx:
    Batch index of each node (e.g. `batch.batch`), used by the DropPath layer
Returns:
feat:
    Normalized and dropped-out features
          layer_supports_edges()
¶
  Abstract method. Return a boolean specifying if the layer type supports edges or not.
Returns:
bool:
    Whether the layer supports the use of edges
          check_intpus_allow_int(obj, edge_index, size)
¶
  Overwrite `check_input` to allow for int32 and int16. TODO: remove when PyG and PyTorch support int32.
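The abstract interface above can be made concrete with a small subclass. The sketch below is illustrative only: the class name, the `batch.feat` feature key, and the plain linear transform are assumptions, not part of the library.

```python
import torch
from graphium.nn.base_graph_layer import BaseGraphModule


class MyNodeLayer(BaseGraphModule):
    """Hypothetical layer that transforms node features and ignores edges."""

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__(in_dim=in_dim, out_dim=out_dim, **kwargs)
        self.linear = torch.nn.Linear(in_dim, out_dim)

    def forward(self, batch):
        feat = self.linear(batch.feat)  # assumes node features are stored in `batch.feat`
        # Apply the normalization / activation / dropout / droppath configured in __init__
        batch.feat = self.apply_norm_activation_dropout(feat, batch_idx=batch.batch)
        return batch

    @property
    def layer_inputs_edges(self) -> bool:
        return False  # this layer does not read edge features

    @property
    def layer_outputs_edges(self) -> bool:
        return False  # this layer does not produce edge features

    @property
    def out_dim_factor(self) -> int:
        return 1  # no multi-head concatenation, so out_dim is unchanged

    def layer_supports_edges(self) -> bool:
        return False  # edge features are not supported at all
```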
Base Layers¶
          graphium.nn.base_layers
¶
  Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
          DropPath
¶
  
            Bases: Module
          __init__(drop_rate)
¶
  DropPath class for stochastic depth, from "Deep Networks with Stochastic Depth", Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra and Kilian Weinberger, https://arxiv.org/abs/1603.09382
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| drop_rate | float | Drop out probability | required | 
          forward(input, batch_idx, batch_size=None)
¶
  Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor | The input feature tensor | required |
| batch_idx | Tensor | The batch attribute of the batch object, `batch.batch` | required |
| batch_size | Optional[int] | The batch size. Must be provided when working on IPU | None |
Returns:
| Type | Description |
|---|---|
| Tensor | The feature tensor after stochastic depth is applied |
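A minimal usage sketch, assuming node features stored in a PyG `Batch` (the graph sizes and feature dimension are placeholders):

```python
import torch
from torch_geometric.data import Batch, Data
from graphium.nn.base_layers import DropPath

# Build a small batch of graphs; `batch.batch` maps each node to its graph index.
graphs = [Data(x=torch.randn(5, 16)) for _ in range(8)]
batch = Batch.from_data_list(graphs)

droppath = DropPath(drop_rate=0.1)
out = droppath(batch.x, batch_idx=batch.batch)  # same shape as the input features
```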
          get_stochastic_drop_rate(drop_rate, layer_idx=None, layer_depth=None)
  
  
      staticmethod
  
¶
  Get the stochastic drop rate from the nominal drop rate, the layer index, and the layer depth.
Computed as `drop_rate * (layer_idx / (layer_depth - 1))`.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| drop_rate | float | Nominal dropout probability | required |
| layer_idx | Optional[int] | The index of the current layer | None | 
| layer_depth | Optional[int] | The total depth (number of layers) associated to this specific layer | None | 
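For example, with a nominal rate of 0.2 and a depth of 5, the schedule ramps linearly from 0 at the first layer to the nominal rate at the last:

```python
from graphium.nn.base_layers import DropPath

for layer_idx in range(5):
    rate = DropPath.get_stochastic_drop_rate(0.2, layer_idx=layer_idx, layer_depth=5)
    print(layer_idx, rate)  # 0.0, 0.05, 0.1, 0.15, 0.2
```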
          FCLayer
¶
  
            Bases: Module
          in_channels: int
  
  
      property
  
¶
  Get the input channel size. For compatibility with PyG.
          out_channels: int
  
  
      property
  
¶
  Get the output channel size. For compatibility with PyG.
          __init__(in_dim, out_dim, activation='relu', dropout=0.0, normalization='none', bias=True, init_fn=None, is_readout_layer=False, droppath_rate=0.0)
¶
  A simple fully connected and customizable layer. This layer is centered around a torch.nn.Linear module.
The order in which transformations are applied is:
- Dense Layer
- Activation
- Dropout (if applicable)
- Batch Normalization (if applicable)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension of the layer (the input size of the inner `torch.nn.Linear`) | required |
| out_dim | int | Output dimension of the layer. | required |
| dropout | float | The ratio of units to dropout. No dropout by default. | 0.0 |
| activation | Union[str, Callable] | Activation function to use. | 'relu' |
| normalization | Union[str, Callable] | Normalization to use. Choices: "none" or `None` (no normalization), "batch_norm", "layer_norm", or any `Callable`. | 'none' |
| bias | bool | Whether to enable bias in the linear layer. | True |
| init_fn | Optional[Callable] | Initialization function to use for the weight of the layer. Default is \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) with \(k=\frac{1}{\text{in\_dim}}\) | None |
| is_readout_layer | bool | Whether the layer should be treated as a readout layer, e.g. by replacing the linear layer with a muP-compatible readout. | False |
| droppath_rate | float | Stochastic depth drop rate, between 0 and 1, see https://arxiv.org/abs/1603.09382 | 0.0 |
Attributes:
    dropout (float):
        The ratio of units to dropout.
    normalization (None or Callable):
        Normalization layer
    linear (torch.nn.Linear):
        The linear layer
    activation (torch.nn.Module):
        The activation layer
    init_fn (Callable):
        Initialization function used for the weight of the layer
    in_dim (int):
        Input dimension of the linear layer
    out_dim (int):
        Output dimension of the linear layer
          forward(h)
¶
  Apply the FC layer on the input features.
Parameters:
h: `torch.Tensor[..., Din]`:
    Input feature tensor, before the FC.
    `Din` is the number of input features
Returns:
`torch.Tensor[..., Dout]`:
    Output feature tensor, after the FC.
    `Dout` is the number of output features
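A minimal usage sketch (the dimensions and hyper-parameters below are arbitrary):

```python
import torch
from graphium.nn.base_layers import FCLayer

layer = FCLayer(in_dim=32, out_dim=64, activation="relu",
                normalization="layer_norm", dropout=0.1)

h = torch.randn(16, 32)   # [..., Din]
out = layer(h)            # [..., Dout] -> torch.Size([16, 64])
```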
          reset_parameters(init_fn=None)
¶
  Reset the parameters of the linear layer using the init_fn.
          GRU
¶
  
            Bases: Module
          __init__(in_dim, hidden_dim)
¶
  Wrapper class for the GRU used by the GNN framework; `torch.nn.GRU` is used for the Gated Recurrent Unit itself.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| in_dim | int | Input dimension of the GRU layer | required | 
| hidden_dim | int | Hidden dimension of the GRU layer. | required | 
          forward(x, y)
¶
  Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x |  |  | required |
| y |  |  | required |
Returns:
| Type | Description |
|---|---|
| Tensor |  |
          MLP
¶
  
            Bases: Module
          __init__(in_dim, hidden_dims, out_dim, depth=None, activation='relu', last_activation='none', dropout=0.0, last_dropout=0.0, normalization='none', last_normalization='none', first_normalization='none', last_layer_is_readout=False, droppath_rate=0.0, constant_droppath_rate=True, fc_layer=FCLayer, fc_layer_kwargs=None)
¶
  Simple multi-layer perceptron, built from a series of FCLayers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension of the MLP | required |
| hidden_dims | Union[Iterable[int], int] | Either an integer specifying all the hidden dimensions, or a list of dimensions in the hidden layers. | required |
| out_dim | int | Output dimension of the MLP. | required |
| depth | Optional[int] | If `hidden_dims` is an integer, the number of hidden layers to use; ignored if `hidden_dims` is a list. | None |
| activation | Union[str, Callable] | Activation function to use in all the layers except the last. | 'relu' |
| last_activation | Union[str, Callable] | Activation function to use in the last layer. | 'none' |
| dropout | float | The ratio of units to dropout. Must be between 0 and 1 | 0.0 |
| normalization | Union[Type[None], str, Callable] | Normalization to use. Choices: "none" or `None` (no normalization), "batch_norm", "layer_norm", or any `Callable`. | 'none' |
| last_normalization | Union[Type[None], str, Callable] | Normalization to use after the last layer. Same options as `normalization`. | 'none' |
| first_normalization | Union[Type[None], str, Callable] | Normalization to use before the first layer. Same options as `normalization`. | 'none' |
| last_dropout | float | The ratio of units to dropout at the last layer. | 0.0 |
| last_layer_is_readout | bool | Whether the last layer should be treated as a readout layer, allowing the use of a muP-compatible readout. | False |
| droppath_rate | float | Stochastic depth drop rate, between 0 and 1. See https://arxiv.org/abs/1603.09382 | 0.0 |
| constant_droppath_rate | bool | If `True`, use a constant droppath rate across layers; otherwise scale it with the layer index (see `DropPath.get_stochastic_drop_rate`). | True |
| fc_layer | FCLayer | The fully connected layer to use. Must inherit from `FCLayer`. | FCLayer |
| fc_layer_kwargs | Optional[dict] | Keyword arguments to pass to the fully connected layer. | None |
          __repr__()
¶
  Controls how the class is printed
          forward(h)
¶
  Apply the MLP on the input features.
Parameters:
h: `torch.Tensor[..., Din]`:
    Input feature tensor, before the MLP.
    `Din` is the number of input features
Returns:
`torch.Tensor[..., Dout]`:
    Output feature tensor, after the MLP.
    `Dout` is the number of output features
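A minimal usage sketch (dimensions and hyper-parameters are arbitrary; `hidden_dims` is passed as an explicit list here to avoid relying on the `depth` convention):

```python
import torch
from graphium.nn.base_layers import MLP

mlp = MLP(in_dim=32, hidden_dims=[64, 64], out_dim=8,
          activation="relu", last_activation="none",
          normalization="batch_norm", dropout=0.1)

h = torch.randn(16, 32)
out = mlp(h)              # torch.Size([16, 8])
```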
          MuReadoutGraphium
¶
  
            Bases: MuReadout
PopTorch-compatible replacement for `mup.MuReadout`.
Not quite a drop-in replacement for `mup.MuReadout`: you need to specify `base_width`.
Set `base_width` to the width of the base model passed to `mup.set_base_shapes`
to get the same results on IPU and CPU. It should still "work" with any other
value, but won't give the same results as on CPU.
          MultiheadAttentionMup
¶
  
            Bases: MultiheadAttention
Modifying the MultiheadAttention to work with the muTransfer paradigm.
The layers are initialized using the mup package.
The _scaled_dot_product_attention normalizes the attention matrix with 1/d instead of 1/sqrt(d)
The biased self-attention option is added to have 3D attention bias.
          TransformerEncoderLayerMup
¶
  
            Bases: TransformerEncoderLayer
Modified version of torch.nn.TransformerEncoderLayer that uses \(1/n\)-scaled attention
for compatibility with muP (as opposed to the original \(1/\sqrt{n}\) scaling factor).
Arguments are the same as torch.nn.TransformerEncoderLayer.
          get_activation(activation)
¶
  returns the activation function represented by the input string
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| activation | Union[type(None), str, Callable] | The activation: a Callable, a string representing the activation, or `None` | required |
Returns:
| Type | Description | 
|---|---|
| Optional[Callable] | Callable or None: The activation function | 
          get_activation_str(activation)
¶
  returns the string related to the activation function
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| activation | Union[type(None), str, Callable] | The activation: a Callable, a string representing the activation, or `None` | required |
Returns:
| Type | Description | 
|---|---|
| str | The name of the activation function | 
          get_norm(normalization, dim=None)
¶
  returns the normalization function represented by the input string
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| normalization | Union[Type[None], str, Callable] | The normalization: a Callable, a string ("none", "batch_norm", "layer_norm"), or `None` | required |
| dim | Optional[int] | Dimension where to apply the norm. Mandatory for 'batch_norm' and 'layer_norm' | None |
Returns:
| Type | Description |
|---|---|
| Optional[Callable] | The normalization function, or None |
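A short sketch of how these helpers fit together; the comments describe the behaviour stated in the docstrings above, not additional guarantees:

```python
from graphium.nn.base_layers import get_activation, get_activation_str, get_norm

act = get_activation("relu")           # an activation callable (e.g. a ReLU module)
name = get_activation_str(act)         # the string name of the activation, e.g. "relu"
norm = get_norm("batch_norm", dim=64)  # a batch-norm layer over 64 features
no_norm = get_norm("none")             # None
```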
Residual Connections¶
          graphium.nn.residual_connections
¶
  Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals. Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
          ResidualConnectionBase
¶
  
            Bases: Module
          __init__(skip_steps=1)
¶
  Abstract class for the residual connections. Using this class, we implement different types of residual connections, such as the ResNet, weighted-ResNet, skip-concat and DenseNet.
The following methods must be implemented in a child class
- h_dim_increase_type()
- has_weights()
Parameters:
skip_steps: int
    The number of steps to skip between the residual connections.
    If `1`, all the layers are connected. If `2`, half of the
    layers are connected.
          __repr__()
¶
  Controls how the class is printed
          get_true_out_dims(out_dims)
¶
  Find the true output dimensions.
Parameters:
out_dims: List
Returns:
true_out_dims: List
          h_dim_increase_type()
  
  
      abstractmethod
  
¶
  How does the dimension of the output features increase after each layer?
Returns:
h_dim_increase_type: None or str
    - ``None``: The dimension of the output features does not change at each layer.
    E.g. ResNet.
    - "previous": The dimension of the output features is the concatenation of
    the previous layer with the new layer.
    - "cumulative": The dimension of the output features is the concatenation
    of all previous layers.
          has_weights()
  
  
      abstractmethod
  
¶
  Returns:
has_weights: bool
    Whether the residual connection uses weights
          ResidualConnectionConcat
¶
  
            Bases: ResidualConnectionBase
          __init__(skip_steps=1)
¶
  Class for simple residual connections where the skip-connection features are concatenated to the current layer features (rather than summed).
Parameters:
skip_steps: int
    The number of steps to skip between the residual connections.
    If `1`, all the layers are connected. If `2`, half of the
    layers are connected.
          forward(h, h_prev, step_idx)
¶
  Concatenate `h` with the skip-connection features `h_prev` from the previous layer.
Parameters:
h: torch.Tensor(..., m)
    The current layer features
h_prev: torch.Tensor(..., n), None
    The features from the previous layer with a skip connection.
    Usually, we have ``n`` equal to ``m``.
    At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
    Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
    Either return ``h`` unchanged, or the concatenation
    with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
    Either return ``h_prev`` unchanged, or the same value as ``h``,
    depending on the ``step_idx`` and ``self.skip_steps``.
          h_dim_increase_type()
¶
  Returns:
"previous":
    The dimension of the output layer is the concatenation with the previous layer.
          has_weights()
¶
  Returns:
False
    The current class does not use weights
          ResidualConnectionDenseNet
¶
  
            Bases: ResidualConnectionBase
          __init__(skip_steps=1)
¶
  Class for the residual connections proposed by DenseNet, where all previous skip connection features are concatenated to the current layer features.
Parameters:
skip_steps: int
    The number of steps to skip between the residual connections.
    If `1`, all the layers are connected. If `2`, half of the
    layers are connected.
          forward(h, h_prev, step_idx)
¶
  Concatenate `h` with the skip-connection features `h_prev` accumulated from all previous layers.
Parameters:
h: torch.Tensor(..., m)
    The current layer features
h_prev: torch.Tensor(..., n), None
    The features from the previous layers.
    n = ((step_idx // self.skip_steps) + 1) * m
    At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
    Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m) or torch.Tensor(..., m + n)
    Either return ``h`` unchanged, or the concatenation
    with ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m) or torch.Tensor(..., m + n)
    Either return ``h_prev`` unchanged, or the same value as ``h``,
    depending on the ``step_idx`` and ``self.skip_steps``.
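As a quick check of the dimension growth described above (the numbers are only an example):

```python
m, skip_steps = 64, 1
for step_idx in (1, 2, 3):
    n = ((step_idx // skip_steps) + 1) * m  # formula from the docstring above
    print(step_idx, n)                      # 1 -> 128, 2 -> 192, 3 -> 256
```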
          h_dim_increase_type()
¶
  Returns:
"cumulative":
    The dimension of the output layer is the concatenation of all the previous layers.
          has_weights()
¶
  Returns:
False
    The current class does not use weights
          ResidualConnectionNone
¶
  
            Bases: ResidualConnectionBase
No residual connection. This class is only used to keep the code interface consistent.
          __repr__()
¶
  Controls how the class is printed
          forward(h, h_prev, step_idx)
¶
  Ignore the skip connection.
Returns:
h: torch.Tensor(..., m)
    Return same as input.
h_prev: torch.Tensor(..., m)
    Return same as input.
          h_dim_increase_type()
¶
  Returns:
None:
    The dimension of the output features does not change at each layer.
          has_weights()
¶
  Returns:
False
    The current class does not use weights
          ResidualConnectionRandom
¶
  
            Bases: ResidualConnectionBase
          __init__(skip_steps=1, out_dims=None, num_layers=None)
¶
  Class for the random residual connection, where each layer is connected
to each following layer with a random weight between 0 and 1.
Parameters:
    skip_steps:
        Parameter only there for compatibility with other classes of the same parent.
    out_dims:
        The list of output dimensions. Only required to get the number
        of layers. Must be provided if num_layers is None.
    num_layers:
        The number of layers. Must be provided if out_dims is None.
          forward(h, h_prev, step_idx)
¶
  Add `h` to the previous layer's skip-connection features `h_prev`,
similar to ResNet.
Parameters:
    h: torch.Tensor(..., m)
        The current layer features
    h_prev: torch.Tensor(..., m), None
        The features from the previous layer with a skip connection.
        At step_idx==0, h_prev can be set to None.
    step_idx: int
        Current layer index or step index in the forward loop of the architecture.
Returns:
    h: torch.Tensor(..., m)
        Either return h unchanged, or the sum with h_prev,
        depending on the step_idx and self.skip_steps.
    h_prev: torch.Tensor(..., m)
        Either return h_prev unchanged, or the same value as h,
        depending on the step_idx and self.skip_steps.
          h_dim_increase_type()
¶
  Returns:
None:
    The dimension of the output features does not change at each layer.
          has_weights()
¶
  Returns:
False
    The current class does not use weights
          ResidualConnectionSimple
¶
  
            Bases: ResidualConnectionBase
          __init__(skip_steps=1)
¶
  Class for the simple residual connections proposed by ResNet, where the current layer output is summed to a previous layer output.
Parameters:
skip_steps: int
    The number of steps to skip between the residual connections.
    If `1`, all the layers are connected. If `2`, half of the
    layers are connected.
          forward(h, h_prev, step_idx)
¶
  Add `h` to the previous layer's skip-connection features `h_prev`,
similar to ResNet.
Parameters:
h: torch.Tensor(..., m)
    The current layer features
h_prev: torch.Tensor(..., m), None
    The features from the previous layer with a skip connection.
    At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
    Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m)
    Either return ``h`` unchanged, or the sum with
    ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m)
    Either return ``h_prev`` unchanged, or the same value as ``h``,
    depending on the ``step_idx`` and ``self.skip_steps``.
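A minimal sketch of how a network's forward loop might use this connection; the layer stack is a placeholder, not how Graphium builds its networks:

```python
import torch
from graphium.nn.residual_connections import ResidualConnectionSimple

residual = ResidualConnectionSimple(skip_steps=1)
layers = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])

h, h_prev = torch.randn(10, 64), None
for step_idx, layer in enumerate(layers):
    h = torch.relu(layer(h))
    h, h_prev = residual(h, h_prev, step_idx)  # sums with the skipped features when due
```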
          h_dim_increase_type()
¶
  Returns:
None:
    The dimension of the output features does not change at each layer.
          has_weights()
¶
  Returns:
False
    The current class does not use weights
          ResidualConnectionWeighted
¶
  
            Bases: ResidualConnectionBase
          __init__(out_dims, skip_steps=1, dropout=0.0, activation='none', normalization='none', bias=False)
¶
  Class for the simple residual connections proposed by ResNet, with an added layer in the residual connection itself. The layer output is summed to a non-linear transformation of a previous layer output.
Parameters:
skip_steps: int
    The number of steps to skip between the residual connections.
    If `1`, all the layers are connected. If `2`, half of the
    layers are connected.
out_dims: list(int)
    list of all output dimensions for the network
    that will use this residual connection.
    E.g. ``out_dims = [4, 8, 8, 8, 2]``.
dropout: float
    value between 0 and 1.0 representing the percentage of dropout
    to use in the weights
activation: str, Callable
    The activation function to use after the skip weights
normalization:
    Normalization to use. Choices:
    - "none" or `None`: No normalization
    - "batch_norm": Batch normalization
    - "layer_norm": Layer normalization in the hidden layers.
    - `Callable`: Any callable function
bias: bool
    Whether to add a bias after the weights
          forward(h, h_prev, step_idx)
¶
  Add `h` to the previous layer's skip-connection features `h_prev`, after
passing them through a feed-forward layer.
Parameters:
h: torch.Tensor(..., m)
    The current layer features
h_prev: torch.Tensor(..., m), None
    The features from the previous layer with a skip connection.
    At ``step_idx==0``, ``h_prev`` can be set to ``None``.
step_idx: int
    Current layer index or step index in the forward loop of the architecture.
Returns:
h: torch.Tensor(..., m)
    Either return ``h`` unchanged, or the sum with the output of a NN layer
    on ``h_prev``, depending on the ``step_idx`` and ``self.skip_steps``.
h_prev: torch.Tensor(..., m)
    Either return ``h_prev`` unchanged, or the same value as ``h``,
    depending on the ``step_idx`` and ``self.skip_steps``.
          h_dim_increase_type()
¶
  Returns:
None:
    The dimension of the output features does not change at each layer.
          has_weights()
¶
  Returns:
True
    The current class uses weights