graphium.finetuning

graphium.finetuning.fingerprinting.Fingerprinter

Fingerprinter(
    model: Union[FullGraphMultiTaskNetwork, PredictorModule],
    fingerprint_spec: Union[str, List[str]],
    out_type: Literal["torch", "numpy"] = "torch",
)

Extract fingerprints from a FullGraphMultiTaskNetwork

Fingerprints are hidden representations of a pre-trained network. They can be used as an additional representation for task-specific ML models in downstream tasks.

Connection to linear probing.

Extracting fingerprints and then training a downstream model on them is a two-stage process, similar in concept to linear probing. Pre-computing the fingerprints decouples the two stages further and allows for more flexibility.
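The two stages can be illustrated with a toy example. This is only a sketch under stated assumptions: `frozen_encoder` is a hypothetical stand-in for a pre-trained network, and the perceptron is just one possible linear downstream model; neither is part of Graphium.

```python
from typing import List, Sequence, Tuple


def frozen_encoder(x: float) -> List[float]:
    """Hypothetical stand-in for the hidden representation of a pre-trained network."""
    return [x, x * x]


def fit_perceptron(
    features: Sequence[Sequence[float]], labels: Sequence[int], epochs: int = 20
) -> Tuple[List[float], float]:
    """Fit a linear classifier (perceptron rule) on pre-computed features."""
    weights = [0.0] * len(features[0])
    bias = 0.0
    for _ in range(epochs):
        mistakes = 0
        for feats, label in zip(features, labels):
            score = sum(w * f for w, f in zip(weights, feats)) + bias
            error = label - (1 if score > 0 else 0)
            if error != 0:
                mistakes += 1
                weights = [w + error * f for w, f in zip(weights, feats)]
                bias += error
        if mistakes == 0:
            break  # Every training sample is classified correctly.
    return weights, bias


def predict(weights: Sequence[float], bias: float, feats: Sequence[float]) -> int:
    score = sum(w * f for w, f in zip(weights, feats)) + bias
    return 1 if score > 0 else 0


# Stage 1: compute the fingerprints once and keep them around.
inputs = [-2.0, -1.0, 1.0, 2.0]
labels = [0, 0, 1, 1]
cached_fingerprints = [frozen_encoder(x) for x in inputs]

# Stage 2: train a downstream model on the cached fingerprints only.
weights, bias = fit_perceptron(cached_fingerprints, labels)
predictions = [predict(weights, bias, fp) for fp in cached_fingerprints]
```

Because stage 2 only sees `cached_fingerprints`, the downstream model can be swapped or retrained without running the encoder again.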

CLI support

You can also extract fingerprints with the CLI. For more information, run:

graphium finetune fingerprint --help

To return intermediate representations, the network will be altered to save the readouts of the specified layers. This class is designed to be used as a context manager that automatically reverts the network to its original state.
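The alter-then-revert pattern described above can be sketched with a toy network in plain Python. This is not Graphium's implementation: `ToyNetwork` and `ToyRecorder` are hypothetical names, and the sketch only assumes that layers are wrapped on setup and restored on teardown.

```python
from typing import Callable, Dict, List, Optional


class ToyNetwork:
    """Stand-in for a multi-layer network: callables applied in order."""

    def __init__(self, layers: List[Callable[[float], float]]):
        self.layers = layers

    def forward(self, x: float) -> float:
        for layer in self.layers:
            x = layer(x)
        return x


class ToyRecorder:
    """Hypothetical recorder that saves the outputs of selected layers."""

    def __init__(self, network: ToyNetwork, layer_indices: List[int]):
        self.network = network
        self.layer_indices = layer_indices
        self.readouts: Dict[int, float] = {}
        self._original_layers: Optional[List[Callable[[float], float]]] = None

    def _wrap(self, idx: int, layer: Callable[[float], float]) -> Callable[[float], float]:
        def recording_layer(x: float) -> float:
            out = layer(x)
            self.readouts[idx] = out  # Save the intermediate output.
            return out

        return recording_layer

    def setup(self) -> None:
        # Keep the untouched layers so teardown can restore them.
        self._original_layers = list(self.network.layers)
        for idx in self.layer_indices:
            self.network.layers[idx] = self._wrap(idx, self._original_layers[idx])

    def teardown(self) -> None:
        if self._original_layers is not None:
            self.network.layers[:] = self._original_layers
            self._original_layers = None

    def __enter__(self) -> "ToyRecorder":
        self.setup()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        # Runs even if the body raises, so the network is always restored.
        self.teardown()
        return False


network = ToyNetwork([lambda x: x + 1, lambda x: x * 2, lambda x: x - 3])
original_layers = list(network.layers)

with ToyRecorder(network, [0, 1]) as recorder:
    output = network.forward(5)
    readouts = dict(recorder.readouts)

restored = network.layers == original_layers
```

The point of the context-manager form is that `teardown` runs in `__exit__` even when the body raises, which a manual `setup()` / `teardown()` pair only guarantees if it is wrapped in `try` / `finally`.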

Examples:

Basic usage:
```python
from graphium.finetuning.fingerprinting import Fingerprinter

# `predictor` (a pre-trained model) and `batch` are assumed to exist already.

# Return layer 1 of the gcn submodule.
# For a full list of options, see the `_module_map` attribute of the FullGraphMultiTaskNetwork.
fp_spec = "gcn:1"

# Create the object
fingerprinter = Fingerprinter(predictor, fp_spec)

# Set up, run and tear down the fingerprinting process
fingerprinter.setup()
fp = fingerprinter.get_fingerprints_for_batch(batch)
fingerprinter.teardown()
```

As context manager:
```python
with Fingerprinter(predictor, "gcn:1") as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```

Concatenating multiple hidden representations:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fp = fingerprinter.get_fingerprints_for_batch(batch)
```

For an entire dataset (expects a PyG dataloader):
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"]) as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```

Returning numpy arrays instead of torch tensors:
```python
with Fingerprinter(predictor, ["gcn:0", "gcn:1"], out_type="numpy") as fingerprinter:
    fps = fingerprinter.get_fingerprints_for_dataset(dataloader)
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | `Union[FullGraphMultiTaskNetwork, PredictorModule]` | Either a FullGraphMultiTaskNetwork or a PredictorModule that contains one. The benefit of using a PredictorModule is that it will automatically convert the features to the correct dtype. | required |
| fingerprint_spec | `Union[str, List[str]]` | The fingerprint specification, in the format `module:layer`. If specified as a list, the fingerprints from all the specified layers will be concatenated. | required |
| out_type | `Literal['torch', 'numpy']` | The output type of the fingerprints. Either `torch` or `numpy`. | `'torch'` |
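To make the `module:layer` format concrete, here is a minimal sketch of how such a specification could be parsed and how several readouts could be joined. The helpers `parse_fingerprint_spec` and `concatenate_readouts` are hypothetical and not part of Graphium; the sketch also assumes that the layer part is an integer and that concatenation happens along the feature dimension, neither of which is stated above.

```python
from typing import Dict, List, Tuple, Union

SpecKey = Tuple[str, int]


def parse_fingerprint_spec(spec: Union[str, List[str]]) -> List[SpecKey]:
    """Hypothetical helper: split "module:layer" strings into (module, layer) pairs."""
    items = [spec] if isinstance(spec, str) else list(spec)
    parsed = []
    for item in items:
        module, _, layer = item.partition(":")
        try:
            layer_index = int(layer)  # Assumption: the layer is an integer index.
        except ValueError:
            raise ValueError(f"Expected 'module:layer', got {item!r}") from None
        if not module:
            raise ValueError(f"Expected 'module:layer', got {item!r}")
        parsed.append((module, layer_index))
    return parsed


def concatenate_readouts(
    readouts: Dict[SpecKey, List[float]], spec: Union[str, List[str]]
) -> List[float]:
    """Join the per-layer feature vectors of one sample in the order given by spec."""
    fingerprint: List[float] = []
    for key in parse_fingerprint_spec(spec):
        fingerprint.extend(readouts[key])
    return fingerprint


# Toy readouts for a single sample, keyed by (module, layer).
readouts = {("gcn", 0): [0.1, 0.2], ("gcn", 1): [0.3, 0.4, 0.5]}

single = parse_fingerprint_spec("gcn:1")
combined = concatenate_readouts(readouts, ["gcn:0", "gcn:1"])
```

Under these assumptions, the length of a concatenated fingerprint is the sum of the widths of the selected layers, which is worth keeping in mind when sizing a downstream model.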

setup

setup()

Prepare the network for fingerprinting

get_fingerprints_for_batch

get_fingerprints_for_batch(batch)

Get the fingerprints for a single batch

get_fingerprints_for_dataset

get_fingerprints_for_dataset(dataloader)

Return the fingerprints for an entire dataset

teardown

teardown()

Restore the network to its original state