graphium.finetuning
graphium.finetuning.fingerprinting.Fingerprinter
Fingerprinter(
model: Union[FullGraphMultiTaskNetwork, PredictorModule],
fingerprint_spec: Union[str, List[str]],
out_type: Literal["torch", "numpy"] = "torch",
)
Extract fingerprints from a FullGraphMultiTaskNetwork.
Fingerprints are hidden representations of a pre-trained network. They can be used as an additional representation for task-specific ML models in downstream tasks.
Connection to linear probing.
This two-stage process is conceptually similar to linear probing, but pre-computing the fingerprints further decouples the two stages: the downstream model does not have to be linear, and it can be trained without running the pre-trained network again.
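To make the second stage concrete, here is a minimal sketch of a linear probe fitted on pre-computed fingerprints. It is illustrative only: the fingerprints and targets are synthetic NumPy arrays standing in for real output, and `fit_ridge` is a hypothetical helper, not part of graphium.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for pre-computed fingerprints: 200 samples, 16 features each.
# In practice these would come from a Fingerprinter with out_type="numpy".
fps = rng.normal(size=(200, 16))

# Synthetic regression target that depends linearly on the fingerprints.
true_w = rng.normal(size=16)
y = fps @ true_w + 0.01 * rng.normal(size=200)

train, test = slice(0, 150), slice(150, 200)


def fit_ridge(X, y, alpha=1e-3):
    """Hypothetical helper: closed-form ridge regression weights."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(d), X.T @ y)


# Stage 2: fit the probe on the stored fingerprints only.
w = fit_ridge(fps[train], y[train])
pred = fps[test] @ w
mse = float(np.mean((pred - y[test]) ** 2))
print(f"held-out MSE: {mse:.5f}")
```

Because the probe only sees arrays, it could be swapped for any other downstream model without touching the pre-trained network.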
CLI support
You can extract fingerprints easily with the CLI. For more info, see
graphium finetune fingerprint --help
To return intermediate representations, the network will be altered to save the readouts of the specified layers. This class is designed to be used as a context manager that automatically reverts the network to its original state.
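The alter-then-revert behaviour described above can be pictured with a plain-Python sketch. This is not graphium's implementation: `RecordingWrapper`, the `layers` dictionary, and `run` are hypothetical names chosen for illustration, and the "layers" are simple functions rather than network modules.

```python
class RecordingWrapper:
    """Hypothetical helper: temporarily record the outputs of named layers."""

    def __init__(self, layers, names):
        self.layers = layers        # dict mapping layer name -> callable
        self.names = list(names)    # layers whose outputs should be saved
        self.readouts = {}
        self._originals = {}

    def _wrap(self, name, fn):
        def recorded(x):
            out = fn(x)
            self.readouts[name] = out  # save the intermediate output
            return out
        return recorded

    def setup(self):
        # Alter the "network": replace each selected layer by a recording one.
        for name in self.names:
            self._originals[name] = self.layers[name]
            self.layers[name] = self._wrap(name, self.layers[name])

    def teardown(self):
        # Revert the "network" to its original state.
        self.layers.update(self._originals)
        self._originals.clear()

    def __enter__(self):
        self.setup()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.teardown()
        return False  # do not swallow exceptions


layers = {"double": lambda x: 2 * x, "inc": lambda x: x + 1}
original_double = layers["double"]


def run(x):
    # Toy forward pass: apply "double", then "inc".
    return layers["inc"](layers["double"](x))


with RecordingWrapper(layers, ["double"]) as rec:
    result = run(3)
    captured = dict(rec.readouts)

print(result, captured)  # 7 {'double': 6}
```

Putting the teardown in `__exit__` means the original layers are restored even if the forward pass raises, which is the main reason to prefer the context-manager form over manual `setup()` and `teardown()` calls.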
Examples:
Basic usage:
```python
from graphium.finetuning.fingerprinting import Fingerprinter

# `predictor` (the pre-trained model) and `batch` are assumed to exist already.
# Return layer 1 of the gcn submodule.
# For a full list of options, see the `_module_map` attribute of the FullGraphMultiTaskNetwork.
fp_spec = "gcn:1"
# Create the object
fingerprinter = Fingerprinter(predictor, fp_spec)
# Setup, run and teardown the fingerprinting process
fingerprinter.setup()
fp = fingerprinter.get_fingerprints_for_batch(batch)
fingerprinter.teardown()
```
As context manager:
```python
with Fingerprinter(predictor, "gcn:1") as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```
Concatenating multiple hidden representations:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```
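As a rough picture of what concatenating representations means for the output shape, the sketch below joins two stand-in arrays. The widths (8 and 16) are made up, and joining along the last (feature) axis is an assumption; the text above does not state which axis is used.

```python
import numpy as np

# Stand-ins for two hidden representations of the same 4 inputs.
# Hypothetical widths: real ones depend on the network's layers.
h0 = np.zeros((4, 8))
h1 = np.ones((4, 16))

# Assumption: representations are joined along the last (feature) axis,
# so each input ends up with one longer feature vector.
fp_concat = np.concatenate([h0, h1], axis=-1)
print(fp_concat.shape)  # (4, 24)
```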
For an entire dataset (expects a PyG dataloader):
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```
Returning numpy arrays instead of torch tensors:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
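model | `Union[FullGraphMultiTaskNetwork, PredictorModule]` | Either a `FullGraphMultiTaskNetwork` or a `PredictorModule`. | required |
fingerprint_spec | `Union[str, List[str]]` | The fingerprint specification: a string such as `gcn:1` (layer 1 of the `gcn` submodule), or a list of such strings to concatenate several representations. | required |
out_type | `Literal['torch', 'numpy']` | The output type of the fingerprints. Either `'torch'` or `'numpy'`. | `'torch'` |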