graphium.finetuning¶
Module for finetuning models and doing linear probing (fingerprinting).
graphium.finetuning.finetuning.GraphFinetuning
¶
Bases: BaseFinetuning
__init__(finetuning_module, added_depth=0, unfreeze_pretrained_depth=None, epoch_unfreeze_all=0, train_bn=False)
¶
Finetuning training callback that (un)freezes modules as specified in the configuration file. By default, the modified layers of the finetuning module and the finetuning head are unfrozen.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `finetuning_module` | `str` | Module to finetune from | *required* |
| `added_depth` | `int` | Number of layers of the finetuning module that have been modified relative to the pretrained model | `0` |
| `unfreeze_pretrained_depth` | `Optional[int]` | Number of additional layers to unfreeze before the layers modified relative to the pretrained model | `None` |
| `epoch_unfreeze_all` | `int` | Epoch at which to unfreeze the entire model | `0` |
| `train_bn` | `bool` | Whether batchnorm layers stay in training mode | `False` |
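A minimal usage sketch (not part of the original docs): the callback is passed to a Lightning `Trainer` with illustrative parameter values; the predictor and datamodule are assumed to be set up elsewhere.

```python
import lightning.pytorch as pl  # assumption: pytorch_lightning works equivalently
from graphium.finetuning import GraphFinetuning

# Unfreeze the modified layer of the "gnn" module plus one extra pretrained
# layer from the start; unfreeze the whole model at epoch 5.
finetuning_callback = GraphFinetuning(
    finetuning_module="gnn",
    added_depth=1,
    unfreeze_pretrained_depth=1,
    epoch_unfreeze_all=5,
    train_bn=False,
)
trainer = pl.Trainer(max_epochs=10, callbacks=[finetuning_callback])
# trainer.fit(predictor, datamodule=datamodule)  # hypothetical training call
```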
finetune_function(pl_module, epoch, optimizer)
¶
Function unfreezing the entire model at the specified epoch
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pl_module` | `LightningModule` | PredictorModule used for finetuning | *required* |
| `epoch` | `int` | Current training epoch | *required* |
| `optimizer` | `Optimizer` | Optimizer used for finetuning | *required* |
freeze_before_training(pl_module)
¶
Freeze everything up to the finetuning module (and parts of the finetuning module)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pl_module` | `LightningModule` | PredictorModule used for finetuning | *required* |
freeze_module(pl_module, module_name, module_map)
¶
Freeze specific modules
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pl_module` | `LightningModule` | PredictorModule used for finetuning | *required* |
| `module_name` | `str` | Name of the module to (partially) freeze | *required* |
| `module_map` | `Dict[str, Union[ModuleList, Any]]` | Dictionary mapping from module_name to the corresponding module(s) | *required* |
graphium.finetuning.finetuning_architecture.FullGraphFinetuningNetwork
¶
Bases: Module, MupMixin
__init__(pretrained_model, pretrained_model_kwargs={}, pretrained_overwriting_kwargs={}, finetuning_head_kwargs=None, num_inference_to_average=1, last_layer_is_readout=False, name='FullFinetuningGNN')
¶
Flexible class that allows implementing an end-to-end graph finetuning network architecture, supporting flexible pretrained models and finetuning heads. The network decomposes into two parts, of classes PretrainedModel and FinetuningHead. The PretrainedModel class allows basic finetuning, such as finetuning from a specified module of the pretrained model and dropping/adding layers in this module. The (optional) FinetuningHead class allows more flexible finetuning with a custom network applied after the pretrained model. If not specified, we fall back to the basic finetuning integrated in PretrainedModel.
Parameters:
pretrained_model:
A PretrainedModel, an identifier of a pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT, or a valid .ckpt checkpoint path
pretrained_model_kwargs:
Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultiTaskNetwork)
pretrained_overwriting_kwargs:
Key-word arguments indicating which parameters of the loaded model are shared with the pretrained part of FullGraphFinetuningNetwork
finetuning_head_kwargs:
Key-word arguments to use for the finetuning head.
It must respect the following criteria:
- pretrained_model_kwargs[last_used_module]["out_level"] must be equal to finetuning_head_kwargs["in_level"]
- pretrained_model_kwargs[last_used_module]["out_dim"] must be equal to finetuning_head_kwargs["in_dim"]
Here, [last_used_module] represents the module that is finetuned from,
e.g., gnn, graph_output or (one of the) task_heads
num_inference_to_average:
Number of inferences to average at val/test time. This is used to avoid the noise introduced
by positional encodings with sign-flips. In case no such encoding is given,
this parameter is ignored.
NOTE: The inference time will be slowed down proportionally to this parameter.
last_layer_is_readout: Whether the last layer should be treated as a readout layer.
Allows using `mup.MuReadout` from the muTransfer method https://github.com/microsoft/mup
name:
Name attributed to the current network, for display and printing
purposes.
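As a hedged construction sketch: the kwargs dictionaries below are placeholders that would normally be derived from a finetuning config, and the checkpoint path is hypothetical.

```python
from graphium.finetuning.finetuning_architecture import FullGraphFinetuningNetwork

# Placeholders; in practice these are derived from the finetuning config.
# Remember the criteria above: the head's "in_level"/"in_dim" must match the
# "out_level"/"out_dim" of the module finetuned from.
pretrained_model_kwargs = {}
pretrained_overwriting_kwargs = {}
finetuning_head_kwargs = None  # None falls back to basic finetuning in PretrainedModel

network = FullGraphFinetuningNetwork(
    pretrained_model="path/to/model.ckpt",  # hypothetical checkpoint path
    pretrained_model_kwargs=pretrained_model_kwargs,
    pretrained_overwriting_kwargs=pretrained_overwriting_kwargs,
    finetuning_head_kwargs=finetuning_head_kwargs,
)
```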
forward(g)
¶
Apply the pre-processing neural network, the graph neural network, and the post-processing neural network on the graph features.
Parameters:
g:
pyg Batch graph on which the convolution is done.
Must contain the following elements:
- Node key `"feat"`: `torch.Tensor[..., N, Din]`.
Input node feature tensor, before the network.
`N` is the number of nodes, `Din` is the input features dimension ``self.pre_nn.in_dim``
- Edge key `"edge_feat"`: `torch.Tensor[..., N, Ein]` **Optional**.
The edge features to use. It will be ignored if the
model doesn't support edge features or if
`self.in_dim_edges==0`.
- Other keys related to positional encodings `"pos_enc_feats_sign_flip"`,
`"pos_enc_feats_no_flip"`.
Returns:
`torch.Tensor[..., M, Dout]` or `torch.Tensor[..., N, Dout]`:
Node or graph feature tensor, after the network.
`N` is the number of nodes, `M` is the number of graphs,
`Dout` is the output dimension ``self.graph_output_nn.out_dim``
If `self.gnn.pooling` is `[None]`, then it returns node features and the output dimension is `N`;
otherwise it returns graph features and the output dimension is `M`.
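A hedged sketch of a forward call, reusing the `network` from the construction sketch above with hypothetical toy dimensions:

```python
import torch
from torch_geometric.data import Batch, Data

# Toy graph: 4 nodes, 12-dimensional node features (hypothetical dimensions;
# they must match the network's expected input dims).
g = Data(
    feat=torch.randn(4, 12),
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
)
batch = Batch.from_data_list([g])
out = network(batch)  # node- or graph-level tensor, depending on pooling
```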
make_mup_base_kwargs(divide_factor=2.0)
¶
Create a 'base' model to be used by the mup or muTransfer scaling of the model.
The base model is usually identical to the regular model, but with the
layers' width divided by a given factor (2 by default).
Parameters:
divide_factor: Factor by which to divide the width.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with the kwargs to create the base model. |
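A hedged sketch of the usual muP workflow around this method, assuming the returned kwargs map one-to-one onto the constructor of the same class:

```python
import mup

# Build the width-divided "base" model, then tie the real model's shapes to it.
base_kwargs = network.make_mup_base_kwargs(divide_factor=2.0)
base_network = FullGraphFinetuningNetwork(**base_kwargs)
mup.set_base_shapes(network, base_network)
```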
set_max_num_nodes_edges_per_graph(max_nodes, max_edges)
¶
Set the maximum number of nodes and edges for all gnn layers and encoder layers
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `max_nodes` | `Optional[int]` | Maximum number of nodes in the dataset. This will be useful for certain architectures, but ignored by others. | *required* |
| `max_edges` | `Optional[int]` | Maximum number of edges in the dataset. This will be useful for certain architectures, but ignored by others. | *required* |
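For example, with illustrative values assumed to come from dataset statistics:

```python
# Layers that need fixed maxima can now pre-allocate; others ignore the call.
network.set_max_num_nodes_edges_per_graph(max_nodes=128, max_edges=256)
```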
graphium.finetuning.finetuning_architecture.PretrainedModel
¶
Bases: Module, MupMixin
__init__(pretrained_model, pretrained_model_kwargs, pretrained_overwriting_kwargs)
¶
Flexible class allowing finetuning of pretrained models from GRAPHIUM_PRETRAINED_MODELS_DICT or from a checkpoint path. Can be any model that inherits from nn.Module and MupMixin and comes with a module map (e.g., FullGraphMultiTaskNetwork)
Parameters:
pretrained_model:
Identifier of a pretrained model within GRAPHIUM_PRETRAINED_MODELS_DICT, or a checkpoint path
pretrained_model_kwargs:
Key-word arguments to instantiate a model of the same class as the pretrained model (e.g., FullGraphMultiTaskNetwork)
pretrained_overwriting_kwargs:
Key-word arguments indicating which parameters of loaded model are shared with the pretrained part of FullGraphFinetuningNetwork
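A minimal construction sketch; the checkpoint path and both kwargs dicts are placeholders for values that normally come from the pretrained-model registry and the finetuning config.

```python
from graphium.finetuning.finetuning_architecture import PretrainedModel

pretrained_model_kwargs = {}        # placeholder: kwargs of the pretrained class
pretrained_overwriting_kwargs = {}  # placeholder: which parameters are shared

backbone = PretrainedModel(
    pretrained_model="path/to/model.ckpt",  # hypothetical; or a GRAPHIUM_PRETRAINED_MODELS_DICT key
    pretrained_model_kwargs=pretrained_model_kwargs,
    pretrained_overwriting_kwargs=pretrained_overwriting_kwargs,
)
```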
make_mup_base_kwargs(divide_factor=2.0)
¶
Create a 'base' model to be used by the mup or muTransfer scaling of the model.
The base model is usually identical to the regular model, but with the
layers' width divided by a given factor (2 by default).
Parameters:
divide_factor: Factor by which to divide the width.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with the kwargs to create the base model. |
overwrite_with_pretrained(pretrained_model, finetuning_module, added_depth=0, sub_module_from_pretrained=None)
¶
Overwrite the parameters shared between the loaded and the modified pretrained model
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pretrained_model` | | Model from GRAPHIUM_PRETRAINED_MODELS_DICT | *required* |
| `finetuning_module` | `str` | Module to finetune from | *required* |
| `added_depth` | `int` | Number of modified layers at the end of the finetuning module | `0` |
| `sub_module_from_pretrained` | `str` | Optional submodule to finetune from | `None` |
graphium.finetuning.finetuning_architecture.FinetuningHead
¶
Bases: Module, MupMixin
__init__(finetuning_head_kwargs)
¶
Flexible class allowing the use of a custom finetuning head on top of the pretrained model. Can be any model that inherits from nn.Module and MupMixin.
Parameters:
finetuning_head_kwargs: Key-word arguments needed to instantiate a custom (or existing) finetuning head from FINETUNING_HEADS_DICT
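A hedged sketch; every key below is an assumption about what a head registered in FINETUNING_HEADS_DICT might expect, not a documented schema.

```python
from graphium.finetuning.finetuning_architecture import FinetuningHead

finetuning_head_kwargs = {
    "model_type": "mlp",  # assumption: identifier of a head in FINETUNING_HEADS_DICT
    "in_dim": 256,        # should equal the out_dim of the module finetuned from
    "out_dim": 1,
    "in_level": "graph",  # should equal the out_level of the module finetuned from
}
head = FinetuningHead(finetuning_head_kwargs)
```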
make_mup_base_kwargs(divide_factor=2.0, factor_in_dim=False)
¶
Create a 'base' model to be used by the mup or muTransfer scaling of the model.
The base model is usually identical to the regular model, but with the
layers' width divided by a given factor (2 by default).
Parameters:
divide_factor: Factor by which to divide the width.
factor_in_dim: Whether to factor the input dimension.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with the kwargs to create the base model. |
graphium.finetuning.fingerprinting.Fingerprinter
¶
Extract fingerprints from a FullGraphMultiTaskNetwork
Fingerprints are hidden representations of a pre-trained network. They can be used as an additional representation for task-specific ML models in downstream tasks.
Connection to linear probing.
This two-stage process (extracting fingerprints, then training a downstream model on them) is similar in concept to linear probing, but pre-computing the fingerprints further decouples the two stages and allows for more flexibility.
CLI support
You can extract fingerprints easily with the CLI. For more info, see
`graphium finetune fingerprint --help`
To return intermediate representations, the network will be altered to save the readouts of the specified layers. This class is designed to be used as a context manager that automatically reverts the network to its original state.
Examples:
Basic usage:
```python
# Return layer 1 of the gcn submodule.
# For a full list of options, see the `_module_map` attribute of the FullGraphMultiTaskNetwork.
fp_spec = "gcn:1"

# Create the object
fingerprinter = Fingerprinter(predictor, fp_spec)

# Setup, run and teardown the fingerprinting process
fingerprinter.setup()
fp = fingerprinter.get_fingerprints_for_batch(batch)
fingerprinter.teardown()
```
As context manager:
```python
with Fingerprinter(predictor, "gcn:1") as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```
Concatenating multiple hidden representations:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```
For an entire dataset (expects a PyG dataloader):
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```
Returning numpy arrays instead of torch tensors:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```
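Building on the examples above, a hedged sketch of the downstream probing step; `y` is a hypothetical label array aligned with `dataloader`:

```python
from sklearn.linear_model import LogisticRegression

with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)

# Fit a simple probe on the frozen representations.
probe = LogisticRegression(max_iter=1000).fit(fps, y)
```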
__enter__()
¶
Setup the network for fingerprinting
__exit__(exc_type, exc_val, exc_tb)
¶
Revert the network to its original state
__init__(model, fingerprint_spec, out_type='torch')
¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Union[FullGraphMultiTaskNetwork, PredictorModule]` | Either a FullGraphMultiTaskNetwork or a PredictorModule that contains one | *required* |
| `fingerprint_spec` | `Union[str, List[str]]` | The fingerprint specification, of the format `module:layer` (e.g., `"gcn:1"`); a list of such specifications concatenates multiple hidden representations | *required* |
| `out_type` | `Literal['torch', 'numpy']` | The output type of the fingerprints. Either `'torch'` or `'numpy'` | `'torch'` |
get_fingerprints_for_batch(batch)
¶
Get the fingerprints for a single batch
get_fingerprints_for_dataset(dataloader)
¶
Return the fingerprints for an entire dataset
setup()
¶
Prepare the network for fingerprinting
teardown()
¶
Restore the network to its original state