graphium.trainer

Code for training models

Predictor


graphium.trainer.predictor


Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.

Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.

Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.


PredictorModule

Bases: LightningModule

__init__(model_class, model_kwargs, loss_fun, task_levels, random_seed=42, featurization=None, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None, metrics_every_n_train_steps=None, replicas=1, gradient_acc=1, global_bs=1)

The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task

Parameters:

Name Type Description Default
model_class Type[Module]

The torch Module containing the main forward function

required
model_kwargs Dict[str, Any]

The arguments to initialize the model from model_class

required
loss_fun Dict[str, Union[str, Callable]]

A dict[str, fun], where str is the task name and fun the loss function

required
random_seed int

The seed for random initialization

42
optim_kwargs Optional[Dict[str, Any]]

The optimization arguments. See class OptimOptions

None
torch_scheduler_kwargs Optional[Dict[str, Any]]

The torch scheduler arguments. See class OptimOptions

None
scheduler_kwargs Optional[Dict[str, Any]]

The lightning scheduler arguments. See class OptimOptions

None
target_nan_mask Optional[Union[str, int]]

How to handle the NaNs. See MetricWrapper for options

None
metrics Dict[str, Callable]

A dict[str, fun], where str is the task name and fun the metric function

None
metrics_on_progress_bar Dict[str, List[str]]

A dict[str, list[str2]], where str is the task name and str2 the names of the metrics to include on the progress bar

[]
metrics_on_training_set Optional[Dict[str, List[str]]]

A dict[str, list[str2]], where str is the task name and str2 the names of the metrics to include on the training set

None
flag_kwargs Dict[str, Any]

Arguments related to using the FLAG adversarial augmentation

None
task_norms Optional[Dict[Callable, Any]]

The normalization for each task

None
metrics_every_n_train_steps Optional[int]

Compute and log metrics every n training steps. Set to None to never log the training metrics and statistics (the default). Set to 1 to log at every step.

None
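For illustration, a minimal instantiation sketch. The `MyModel` module, the task name "task1", and the `task_levels` mapping are assumptions for the example; consult the parameter descriptions above for the full options.

```python
import torch
from graphium.trainer import PredictorModule

class MyModel(torch.nn.Module):
    """Hypothetical model; any torch Module with a forward function works."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.lin(x)

predictor = PredictorModule(
    model_class=MyModel,
    model_kwargs={"in_dim": 16, "out_dim": 1},  # forwarded to MyModel.__init__
    loss_fun={"task1": "mse"},                  # one loss per task
    task_levels={"task1": "graph"},             # assumed task-to-level mapping
    optim_kwargs={"lr": 1e-3},
)
```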
__repr__()

Controls how the class is printed

compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None) staticmethod

Compute the loss using the specified loss function, while handling the NaNs in the targets.

Parameters:

Name Type Description Default
preds Dict[str, Tensor]

Predicted values

required
targets Dict[str, Tensor]

Target values

required
weights Optional[Tensor]

No longer supported, will raise an error.

required
target_nan_mask Optional[Union[str, int]]
  • None: Do not change behaviour if there are NaNs

  • int, float: Value used to replace NaNs. For example, if target_nan_mask==0, then all NaNs will be replaced by zeros

  • 'ignore': The NaN values will be removed from the tensor before computing the metrics. Must be coupled with multitask_handling='flatten' or multitask_handling='mean-per-label'.

None
multitask_handling Optional[str]
  • None: Do not process the tensor before passing it to the metric. The option multitask_handling=None cannot be used when target_nan_mask='ignore'; use either 'flatten' or 'mean-per-label'.

  • 'flatten': Flatten the tensor to produce the equivalent of a single task

  • 'mean-per-label': Loop over all the label columns, process each as a single task, and average the results over the tasks. This option might slow down the computation if there are too many labels.

None
loss_fun Dict[str, Callable]

Loss function to use

required

Returns:

Name Type Description
Tensor Tuple[Tensor, Dict[str, Tensor]]

weighted_loss: Resulting weighted loss
all_task_losses: Loss per task
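A hedged sketch of calling compute_loss with NaN targets ignored; the tensor shapes, task names, and the use of torch.nn.functional losses are assumptions made for the example.

```python
import torch
from graphium.trainer.predictor import PredictorModule

preds = {"task1": torch.randn(4, 2), "task2": torch.randn(4, 1)}
targets = {"task1": torch.randn(4, 2), "task2": torch.randn(4, 1)}
targets["task1"][0, 0] = float("nan")  # this entry is dropped by 'ignore'

loss_fun = {
    "task1": torch.nn.functional.mse_loss,
    "task2": torch.nn.functional.l1_loss,
}

weighted_loss, all_task_losses = PredictorModule.compute_loss(
    preds=preds,
    targets=targets,
    weights=None,              # weights are no longer supported
    loss_fun=loss_fun,
    target_nan_mask="ignore",  # 'ignore' requires a multitask_handling mode
    multitask_handling="flatten",
)
```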

flag_step(batch, step_name, to_cpu)

Perform adversarial data augmentation during one training step using FLAG. Paper: https://arxiv.org/abs/2010.09891 GitHub: https://github.com/devnkong/FLAG

forward(inputs)

Returns the result of self.model.forward(*inputs). If the output out is a dictionary with a "preds" key, it is returned directly. Otherwise, a new dictionary {"preds": out} is returned.

Returns:

Type Description
Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]

A dict with a key "preds" representing the prediction of the network.

get_num_graphs(data)

Method to compute the number of graphs in a Batch. Essential for estimating throughput in graphs/s.

list_pretrained_models() staticmethod

List available pretrained models.

load_pretrained_model(name_or_path, device=None) staticmethod

Load a pretrained model from its name.

Parameters:

Name Type Description Default
name_or_path

Name of the model to load, or a valid checkpoint path. The available names can be listed with graphium.trainer.PredictorModule.list_pretrained_models().

required
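A short usage sketch. The model name below is a placeholder; substitute one of the names returned by list_pretrained_models(), or a checkpoint path.

```python
from graphium.trainer import PredictorModule

# List the available pretrained models.
available = PredictorModule.list_pretrained_models()
print(available)

# "some-model-name" is a placeholder for a listed name or checkpoint path.
predictor = PredictorModule.load_pretrained_model("some-model-name", device="cpu")
```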

Metrics


graphium.trainer.metrics


Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.

Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.

Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.


MetricWrapper

Allows initializing a metric from a name or Callable, and initializes the Thresholder if the metric requires a threshold.

__call__(preds, target)

Compute the metric with the method self.compute

__getstate__()

Serialize the class for pickling.

__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)

Parameters:

metric:
    The metric to use. See METRICS_DICT

threshold_kwargs:
    If `None`, no threshold is applied.
    Otherwise, the `Thresholder` class is initialized with the
    provided arguments and applied before calling `compute`.

target_nan_mask:

    - None: Do not change behaviour if there are NaNs

    - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
      all NaNs will be replaced by zeros

    - 'ignore': The NaN values will be removed from the tensor before computing the metrics.
      Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.

multitask_handling:
    - None: Do not process the tensor before passing it to the metric.
      Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
      Use either 'flatten' or 'mean-per-label'.

    - 'flatten': Flatten the tensor to produce the equivalent of a single task

    - 'mean-per-label': Loop over all the label columns, process each as a single task,
      and average the results over the tasks.
      *This option might slow down the computation if there are too many labels*

squeeze_targets:
    If true, targets will be squeezed prior to computing the metric.
    Required for the hybrid classification-regression ("classifigression") task.

target_to_int:
    If true, targets will be converted to integers prior to computing the metric.

kwargs:
    Other arguments to call with the metric
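A hedged sketch of wrapping a metric so that NaN targets are dropped before scoring; the "mae" key is assumed to exist in METRICS_DICT, and the tensor shapes are illustrative.

```python
import torch
from graphium.trainer.metrics import MetricWrapper

wrapper = MetricWrapper(
    metric="mae",                        # assumed key in METRICS_DICT
    target_nan_mask="ignore",            # drop NaN entries before scoring
    multitask_handling="mean-per-label", # required when using 'ignore'
)

preds = torch.randn(8, 3)
target = torch.randn(8, 3)
target[1, 2] = float("nan")              # this entry will be ignored
score = wrapper(preds, target)
```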
__repr__()

Controls how the class is printed

__setstate__(state)

Reload the class from pickling.

compute(preds, target)

Compute the metric, apply the thresholder if provided, and manage the NaNs

Thresholder

__getstate__()

Serialize the class for pickling.

__repr__()

Controls how the class is printed

Predictor Summaries


graphium.trainer.predictor_summaries


Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.

Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.

Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.


Summary

Bases: SummaryInterface

Results
__init__(targets=None, preds=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)

This inner class is used as a container for storing the results of the summary.

Parameters:
    targets: the targets
    preds: the prediction tensor
    loss: the loss, float or tensor
    metrics: the metrics
    monitored_metric: the monitored metric
    n_epochs: the number of epochs

__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)

A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.

Parameters:

loss_fun:
Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'. Otherwise, a callable object must be provided, with a method loss_fun._get_name().

metrics:
A dictionary of metrics to compute on the prediction, other than the loss function.
These metrics will be logged into WandB or similar.

metrics_on_training_set:
The metric names from `metrics` to be computed on the training set for each iteration.
If `None`, all the metrics are computed. Using fewer metrics can significantly improve
performance, depending on the number of readouts.

metrics_on_progress_bar:
The metric names from `metrics` to also display on the progress bar of the training

monitor:
`str` metric to track (Default=`"loss/val"`)

task_name:
name of the task (Default=`None`)
get_best_results(step_name)

Retrieve the best results for a given step.

Parameters:
    step_name: which stage you are in, e.g. "train"

Returns:
    the best results for the given step

get_dict_summary()

Retrieve the full summary in a dictionary.

Returns:
    the full summary in a dictionary

get_metrics_logs()

Get the data about metrics to log.

Note: This function requires that self.update_predictor_state() be called before it.

Returns:
    A dictionary of metrics to log.

get_results(step_name)

Retrieve the results for a given step.

Parameters:
    step_name: which stage you are in, e.g. "train"

Returns:
    the results for the given step

get_results_on_progress_bar(step_name)

Retrieve the results to be displayed on the progress bar for a given step.

Parameters:
    step_name: which stage you are in, e.g. "train"

Returns:
    the results to be displayed on the progress bar for the given step

is_best_epoch(step_name, loss, metrics)

Check whether the current epoch is the best epoch based on the self.mode criteria.

Parameters:
    step_name: which stage you are in, e.g. "train"
    loss: the loss tensor
    metrics: a dictionary of metrics

set_results(metrics)

Set the results from the metrics.

Note: This function requires that self.update_predictor_state() be called before it.

Parameters:
    metrics: a dictionary of metrics

update_predictor_state(step_name, targets, preds, loss, n_epochs)

Update the state of the predictor.

Parameters:
    step_name: which stage you are in, e.g. "train"
    targets: the targets tensor
    preds: the predictions tensor
    loss: the loss tensor
    n_epochs: the number of epochs

SummaryInterface

Bases: object

An interface defining the functions to be implemented by summary classes.

TaskSummaries

Bases: SummaryInterface

__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')

A class to store the summaries of the tasks.

Parameters:
    task_loss_fun: the loss function for each task
    task_metrics: the metrics for each task
    task_metrics_on_training_set: the metrics to use on the training set
    task_metrics_on_progress_bar: the metrics to use on the progress bar
    monitor: the metric to monitor
    mode: the mode of the metric to monitor

concatenate_metrics_logs(metrics_logs)

Concatenate the metrics logs.

Parameters:
    metrics_logs: the metrics logs

Returns:
    the concatenated metrics logs

get_best_results(step_name)

Retrieve the best results.

Parameters:
    step_name: the name of the step, e.g. "train"

Returns:
    the best results

get_dict_summary()

Get the task summaries in a dictionary.

Returns:
    the task summaries

get_metrics_logs()

Get the logs for the metrics.

Returns:
    the task logs for the metrics

get_results(step_name)

Retrieve the results.

Parameters:
    step_name: the name of the step, e.g. "train"

Returns:
    the results

get_results_on_progress_bar(step_name)

Return all results from all tasks for the progress bar. Combines the dictionaries: instead of having keys as task names, all the task-specific dictionaries are merged.

Parameters:
    step_name: the name of the step

Returns:
    the results for the progress bar

metric_log_name(task_name, metric_name, step_name)

Combine the task name, metric name, and step name into a single log name.

Returns:
    the combined log name

set_results(task_metrics)

Set the results for all tasks.

Parameters:
    task_metrics: the metrics for each task

update_predictor_state(step_name, targets, preds, loss, task_losses, n_epochs)

Update the state for all predictors.

Parameters:
    step_name: the name of the step
    targets: the target tensors
    preds: the prediction tensors
    loss: the loss tensor
    task_losses: the task losses
    n_epochs: the number of epochs

Predictor Options


graphium.trainer.predictor_options


Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.

Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.

Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.


EvalOptions dataclass

This data class stores the arguments necessary to configure the evaluation (loss and metrics) for the Predictor.

Parameters:

Name Type Description Default
loss_fun Union[str, Dict, Callable]

Loss function used during training. Acceptable strings are graphium.utils.spaces.LOSS_DICT.keys(). If a dict, must contain a 'name' key with one of the acceptable loss function strings as a value. The rest of the dict will be used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method loss_fun._get_name().

required
metrics Dict[str, Callable]

A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or similar loggers.

None
metrics_on_progress_bar List[str]

The metric names from metrics to also display on the progress bar of the training

field(default_factory=List[str])
metrics_on_training_set Optional[List[str]]

The metric names from metrics to be computed on the training set for each iteration. If None, all the metrics are computed. Using fewer metrics can significantly improve performance, depending on the number of readouts.

None
check_metrics_validity()

Check that the metrics for the progress_bar and training_set are valid

parse_loss_fun(loss_fun) staticmethod

Parse the loss function from a string, a dict, or a callable

Parameters:

Name Type Description Default
loss_fun Union[str, Dict, Callable]

A callable corresponding to the loss function, a string specifying the loss function from LOSS_DICT, or a dict containing a key 'name' specifying the loss function, and the rest of the dict used as the arguments for the loss. Accepted strings are: graphium.utils.spaces.LOSS_DICT.keys().

required

Returns:

Name Type Description
Callable Callable

Function or callable to compute the loss, takes preds and targets as inputs.
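For illustration, the string and dict forms side by side. The 'mse' string is listed as acceptable in the docstrings above; the extra 'reduction' key is an assumption about the arguments accepted by the underlying loss object.

```python
from graphium.trainer.predictor_options import EvalOptions

# String form: the name is looked up in graphium.utils.spaces.LOSS_DICT.
loss_from_str = EvalOptions.parse_loss_fun("mse")

# Dict form: 'name' selects the loss; remaining keys are forwarded as
# arguments to the loss object ('reduction' is assumed here).
loss_from_dict = EvalOptions.parse_loss_fun({"name": "mse", "reduction": "mean"})
```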

FlagOptions dataclass

This data class stores the arguments necessary to configure the FLAG adversarial augmentation for the Predictor.

Parameters:

Name Type Description Default
flag_kwargs Dict[str, Any]

Keyword arguments used for FLAG, an adversarial data augmentation method for graph networks. See: https://arxiv.org/abs/2010.09891

  • n_steps: An integer that specifies the number of ascent steps when running FLAG during training. Default value of 0 trains GNNs without FLAG, and any value greater than 0 will use FLAG with that many iterations.

  • alpha: A float that specifies the ascent step size when running FLAG. Default=0.01

None
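For example, a flag_kwargs dict enabling FLAG with three ascent steps (the values are illustrative):

```python
# n_steps=0 (the default) trains without FLAG; any value > 0 enables it.
flag_kwargs = {"n_steps": 3, "alpha": 0.01}
```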

ModelOptions dataclass

This data class stores the arguments necessary to instantiate a model for the Predictor.

Parameters:

Name Type Description Default
model_class Type[Module]

PyTorch module used to create a model

required
model_kwargs Dict[str, Any]

Key-word arguments used to initialize the model from model_class.

required

OptimOptions dataclass

This data class stores the arguments necessary to configure the optimizer for the Predictor.

Parameters:

Name Type Description Default
optim_kwargs Optional[Dict[str, Any]]

Dictionary used to initialize the optimizer, with possible keys below.

  • lr float: Learning rate (Default=1e-3)
  • weight_decay float: Weight decay used to regularize the optimizer (Default=0.)
None
torch_scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the scheduling of the learning rate, with possible keys below.

  • type str: Type of the learning rate scheduler to use from PyTorch. Examples are 'ReduceLROnPlateau' (default), 'CosineAnnealingWarmRestarts', 'StepLR', etc.
  • **kwargs: Any other argument for the learning rate scheduler
None
scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the scheduling of the learning rate modification used by PyTorch Lightning

  • monitor str: metric to track (Default="loss/val")
  • interval str: Whether to look at iterations or epochs (Default="epoch")
  • strict bool: If set to True, enforces that the value specified in monitor is available when trying to call scheduler.step(), and stops training if it is not found. If False, only gives a warning and continues training without calling the scheduler. (Default=True)
  • frequency int: Number of scheduler intervals (epochs or steps, depending on interval) between calls to scheduler.step() (Default=1)
None
scheduler_class Optional[Union[str, Type]]

The class to use for the scheduler, or the str representing the scheduler.

None
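A hedged configuration sketch combining the three dictionaries above; the 'patience' key is an assumption forwarded to ReduceLROnPlateau through the scheduler's **kwargs.

```python
from graphium.trainer.predictor_options import OptimOptions

optim_options = OptimOptions(
    optim_kwargs={"lr": 1e-3, "weight_decay": 0.0},
    torch_scheduler_kwargs={
        "type": "ReduceLROnPlateau",  # default scheduler type
        "patience": 10,               # forwarded to the scheduler (**kwargs)
    },
    scheduler_kwargs={
        "monitor": "loss/val",        # metric tracked by the scheduler
        "interval": "epoch",
        "strict": True,
        "frequency": 1,
    },
)
```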