
graphium.finetuning

Module for finetuning models and doing linear probing (fingerprinting).

graphium.finetuning.finetuning.GraphFinetuning

Bases: BaseFinetuning

__init__(finetuning_module, added_depth=0, unfreeze_pretrained_depth=None, epoch_unfreeze_all=0, train_bn=False)

Finetuning training callback that (un)freezes modules as specified in the configuration file. By default, the modified layers of the finetuning module and the finetuning head are unfrozen.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `finetuning_module` | `str` | Module to finetune from | required |
| `added_depth` | `int` | Number of layers of the finetuning module that have been modified relative to the pretrained model | `0` |
| `unfreeze_pretrained_depth` | `Optional[int]` | Number of additional pretrained layers to unfreeze before the modified layers | `None` |
| `epoch_unfreeze_all` | `int` | Epoch at which to unfreeze the entire model | `0` |
| `train_bn` | `bool` | Whether batchnorm layers stay in training mode | `False` |
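To make the callback's role concrete, here is a minimal sketch of attaching it to a Lightning `Trainer`. The import path follows the module layout above; the module name `"gnn"` and the epoch values are illustrative assumptions, not defaults.

```python
import lightning.pytorch as pl  # or `pytorch_lightning`, depending on your setup
from graphium.finetuning.finetuning import GraphFinetuning

# Illustrative values: unfreeze the modified layer of "gnn" plus one extra
# pretrained layer at the start, then unfreeze everything at epoch 10.
finetuning_callback = GraphFinetuning(
    finetuning_module="gnn",
    added_depth=1,
    unfreeze_pretrained_depth=1,
    epoch_unfreeze_all=10,
    train_bn=False,
)

trainer = pl.Trainer(max_epochs=20, callbacks=[finetuning_callback])
# trainer.fit(predictor, datamodule=datamodule)  # `predictor`/`datamodule` assumed to exist
```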

finetune_function(pl_module, epoch, optimizer)

Unfreeze the entire model at the specified epoch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pl_module` | `LightningModule` | PredictorModule used for finetuning | required |
| `epoch` | `int` | Current training epoch | required |
| `optimizer` | `Optimizer` | Optimizer used for finetuning | required |

freeze_before_training(pl_module)

Freeze everything up to the finetuning module (and parts of the finetuning module itself).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pl_module` | `LightningModule` | PredictorModule used for finetuning | required |

freeze_module(pl_module, module_name, module_map)

Freeze specific modules

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `module_name` | `str` | Name of the module to (partially) freeze | required |
| `module_map` | `Dict[str, Union[ModuleList, Any]]` | Dictionary mapping from `module_name` to the corresponding module(s) | required |

graphium.finetuning.finetuning_architecture.FullGraphFinetuningNetwork

Bases: Module, MupMixin

__init__(pretrained_model, pretrained_model_kwargs={}, pretrained_overwriting_kwargs={}, finetuning_head_kwargs=None, num_inference_to_average=1, last_layer_is_readout=False, name='FullFinetuningGNN')

Flexible class for implementing an end-to-end graph finetuning network architecture, supporting flexible pretrained models and finetuning heads. The network decomposes into two parts: a PretrainedModel and an optional FinetuningHead. The PretrainedModel class supports basic finetuning, such as finetuning from a specified module of the pretrained model and dropping/adding layers in that module. The optional FinetuningHead class allows more flexible finetuning with a custom network applied after the pretrained model. If no head is specified, we fall back to the basic finetuning integrated in PretrainedModel.

Parameters:

pretrained_model:
    A PretrainedModel or an identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or a valid .ckpt checkpoint path

pretrained_model_kwargs:
    Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)

pretrained_overwriting_kwargs:
    Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork

finetuning_head_kwargs:
    Key-word arguments to use for the finetuning head.
    It must respect the following criteria:
    - pretrained_model_kwargs[last_used_module]["out_level"] must be equal to finetuning_head_kwargs["in_level"]
    - pretrained_model_kwargs[last_used_module]["out_dim"] must be equal to finetuning_head_kwargs["in_dim"]

    Here, [last_used_module] represents the module that is finetuned from,
    e.g., gnn, graph_output or (one of the) task_heads

num_inference_to_average:
    Number of inferences to average at val/test time. This is used to avoid the noise introduced
    by positional encodings with sign-flips. In case no such encoding is given,
    this parameter is ignored.
    NOTE: The inference time will be slowed down proportionally to this parameter.

last_layer_is_readout: Whether the last layer should be treated as a readout layer.
    Allows using `mup.MuReadout` from the muTransfer method: https://github.com/microsoft/mup

name:
    Name attributed to the current network, for display and printing
    purposes.
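As a hedged sketch of the `in_level`/`in_dim` constraint above, consider the following construction; all values are illustrative, and the real kwargs schemas come from the finetuning config.

```python
from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork

# Illustrative sketch only: suppose the last used module ends at graph level
# with out_dim=256; the head's in_level/in_dim must then match.
finetuning_head_kwargs = {
    "in_level": "graph",  # == pretrained_model_kwargs[last_used_module]["out_level"]
    "in_dim": 256,        # == pretrained_model_kwargs[last_used_module]["out_dim"]
    "out_dim": 1,         # hypothetical task output dimension
}

net = FullGraphFinetuningNetwork(
    pretrained_model="path/to/model.ckpt",  # or an identifier in GRAPHIUM_PRETRAINED_MODELS_DICT
    finetuning_head_kwargs=finetuning_head_kwargs,
)
```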

forward(g)

Apply the pre-processing neural network, the graph neural network, and the post-processing neural network on the graph features.

Parameters:

g:
    pyg Batch graph on which the convolution is done.
    Must contain the following elements:

    - Node key `"feat"`: `torch.Tensor[..., N, Din]`.
      Input node feature tensor, before the network.
      `N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim``

    - Edge key `"edge_feat"`: `torch.Tensor[..., E, Ein]` **Optional**.
      The edge features to use. It will be ignored if the
      model doesn't support edge features or if
      `self.in_dim_edges==0`. `E` is the number of edges, `Ein` is the input edge-feature dimension.

    - Other keys related to positional encodings `"pos_enc_feats_sign_flip"`,
      `"pos_enc_feats_no_flip"`.

Returns:

`torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`:
    Node or graph feature tensor, after the network.
    `N` is the number of nodes, `M` is the number of graphs, and
    `Dout` is the output dimension ``self.graph_output_nn.out_dim``.
    If `self.gnn.pooling` is `[None]`, node features are returned (first dimension `N`);
    otherwise, graph features are returned (first dimension `M`).

make_mup_base_kwargs(divide_factor=2.0)

Create a 'base' model to be used by the mup or muTransfer scaling of the model. The base model is usually identical to the regular model, but with the layer widths divided by a given factor (2 by default).

Parameters:

divide_factor: Factor by which to divide the width.

Returns:

`Dict[str, Any]`: Dictionary with the kwargs to create the base model.
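For context, the returned kwargs plug into the standard muP workflow roughly as follows (a sketch, assuming the kwargs match the constructor signature):

```python
from mup import set_base_shapes

# Build a width-halved "base" model and register the base shapes on the real model.
base_kwargs = net.make_mup_base_kwargs(divide_factor=2.0)
base_model = FullGraphFinetuningNetwork(**base_kwargs)
net = set_base_shapes(net, base_model)
```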

set_max_num_nodes_edges_per_graph(max_nodes, max_edges)

Set the maximum number of nodes and edges for all GNN layers and encoder layers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_nodes` | `Optional[int]` | Maximum number of nodes in the dataset. Useful for certain architectures, ignored by others. | required |
| `max_edges` | `Optional[int]` | Maximum number of edges in the dataset. Useful for certain architectures, ignored by others. | required |

graphium.finetuning.finetuning_architecture.PretrainedModel

Bases: Module, MupMixin

__init__(pretrained_model, pretrained_model_kwargs, pretrained_overwriting_kwargs)

Flexible class for finetuning pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT or from a checkpoint path. Can be any model that inherits from nn.Module and MupMixin and comes with a module map (e.g., FullGraphMultitaskNetwork).

Parameters:

pretrained_model:
    Identifier of pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT or from a checkpoint path

pretrained_model_kwargs:
    Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultitaskNetwork)

pretrained_overwriting_kwargs:
    Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork
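A hedged instantiation sketch; in practice all three arguments are produced by Graphium's finetuning config loader rather than written by hand, and `model_kwargs`/`overwriting_kwargs` are hypothetical placeholders.

```python
from graphium.finetuning.finetuning_architecture import PretrainedModel

backbone = PretrainedModel(
    pretrained_model="path/to/model.ckpt",  # or an identifier in GRAPHIUM_PRETRAINED_MODELS_DICT
    pretrained_model_kwargs=model_kwargs,   # assumed: kwargs rebuilding the pretrained class
    pretrained_overwriting_kwargs=overwriting_kwargs,  # assumed: which loaded params are shared
)
```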

make_mup_base_kwargs(divide_factor=2.0)

Create a 'base' model to be used by the mup or muTransfer scaling of the model. The base model is usually identical to the regular model, but with the layer widths divided by a given factor (2 by default).

Parameters:

divide_factor: Factor by which to divide the width.
factor_in_dim: Whether to factor the input dimension.

Returns:

`Dict[str, Any]`: Dictionary with the kwargs to create the base model.

overwrite_with_pretrained(pretrained_model, finetuning_module, added_depth=0, sub_module_from_pretrained=None)

Overwrite parameters shared between the loaded and the modified pretrained model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pretrained_model` |  | Model from GRAPHIUM_PRETRAINED_MODELS_DICT | required |
| `finetuning_module` | `str` | Module to finetune from | required |
| `added_depth` | `int` | Number of modified layers at the end of the finetuning module | `0` |
| `sub_module_from_pretrained` | `str` | Optional submodule to finetune from | `None` |

graphium.finetuning.finetuning_architecture.FinetuningHead

Bases: Module, MupMixin

__init__(finetuning_head_kwargs)

Flexible class for using a custom finetuning head on top of the pretrained model. Can be any model that inherits from nn.Module and MupMixin.

Parameters:

finetuning_head_kwargs: Key-word arguments needed to instantiate a custom (or existing) finetuning head from FINETUNING_HEADS_DICT

make_mup_base_kwargs(divide_factor=2.0, factor_in_dim=False)

Create a 'base' model to be used by the mup or muTransfer scaling of the model. The base model is usually identical to the regular model, but with the layer widths divided by a given factor (2 by default).

Parameters:

divide_factor: Factor by which to divide the width.
factor_in_dim: Whether to factor the input dimension.

Returns:

`Dict[str, Any]`: Dictionary with the kwargs to create the base model.

graphium.finetuning.fingerprinting.Fingerprinter

Extract fingerprints from a FullGraphMultiTaskNetwork

Fingerprints are hidden representations of a pre-trained network. They can be used as an additional representation for task-specific ML models in downstream tasks.

Connection to linear probing.

This two-stage process is similar in concept to linear probing, but pre-computing the fingerprints further decouples the two stages and allows for more flexibility.

CLI support

You can extract fingerprints easily with the CLI. For more info, see:

```
graphium finetune fingerprint --help
```

To return intermediate representations, the network is altered to save the readouts of the specified layers. This class is designed to be used as a context manager that automatically reverts the network to its original state.

Examples:

Basic usage:
```python
# Return layer 1 of the gcn submodule.
# For a full list of options, see the `_module_map` attribute of the FullGraphMultiTaskNetwork.
fp_spec = "gcn:1"

# Create the object
fingerprinter = Fingerprinter(predictor, fp_spec)

# Setup, run and teardown the fingerprinting process
fingerprinter.setup()
fp = fingerprinter.get_fingerprints_for_batch(batch)
fingerprinter.teardown()
```

As context manager:
```python
with Fingerprinter(predictor, "gcn:1") as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```

Concatenating multiple hidden representations:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```

For an entire dataset (expects a PyG dataloader):
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```

Returning numpy arrays instead of torch tensors:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```
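Tying back to the linear-probing connection above, the downstream stage can then be as simple as fitting a lightweight model on the precomputed fingerprints. A sketch, assuming `fps` was extracted with `out_type="numpy"` and `y` holds the task labels:

```python
from sklearn.linear_model import Ridge

# Fit a simple linear probe on the precomputed fingerprints.
probe = Ridge(alpha=1.0).fit(fps, y)
predictions = probe.predict(fps)
```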

__enter__()

Setup the network for fingerprinting

__exit__(exc_type, exc_val, exc_tb)

Revert the network to its original state

__init__(model, fingerprint_spec, out_type='torch')

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[FullGraphMultiTaskNetwork, PredictorModule]` | Either a FullGraphMultiTaskNetwork or a PredictorModule that contains one. The benefit of using a PredictorModule is that it will automatically convert the features to the correct dtype. | required |
| `fingerprint_spec` | `Union[str, List[str]]` | The fingerprint specification, in the format `module:layer`. If specified as a list, the fingerprints from all specified layers are concatenated. | required |
| `out_type` | `Literal['torch', 'numpy']` | The output type of the fingerprints, either torch or numpy. | `'torch'` |

get_fingerprints_for_batch(batch)

Get the fingerprints for a single batch

get_fingerprints_for_dataset(dataloader)

Return the fingerprints for an entire dataset

setup()

Prepare the network for fingerprinting

teardown()

Restore the network to its original state