

Code for training models



Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.

Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.

Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.


Bases: LightningModule

__init__(model_class, model_kwargs, loss_fun, task_levels, random_seed=42, featurization=None, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None, metrics_every_n_train_steps=None, replicas=1, gradient_acc=1, global_bs=1)

The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task.


Name Type Description Default
model_class Type[Module]

The torch Module containing the main forward function

model_kwargs Dict[str, Any]

The arguments to initialize the model from model_class

loss_fun Dict[str, Union[str, Callable]]

A dict[str, fun], where str is the task name and fun the loss function

random_seed int

The seed for random initialization

optim_kwargs Optional[Dict[str, Any]]

The optimization arguments. See class OptimOptions

torch_scheduler_kwargs Optional[Dict[str, Any]]

The torch scheduler arguments. See class OptimOptions

scheduler_kwargs Optional[Dict[str, Any]]

The lightning scheduler arguments. See class OptimOptions

target_nan_mask Optional[Union[str, int]]

How to handle the NaNs. See MetricsWrapper for options

metrics Dict[str, Callable]

A dict[str, fun], where str is the task name and fun the metric function

metrics_on_progress_bar Dict[str, List[str]]

A dict[str, list[str2]], where str is the task name and str2 the metrics to include on the progress bar

metrics_on_training_set Optional[Dict[str, List[str]]]

A dict[str, list[str2]], where str is the task name and str2 the metrics to compute on the training set

flag_kwargs Dict[str, Any]

Arguments related to using the FLAG adversarial augmentation

task_norms Optional[Dict[Callable, Any]]

The normalization for each task

metrics_every_n_train_steps Optional[int]

Compute and log metrics every n training steps. Set to None to never log the training metrics and statistics (the default). Set to 1 to log at every step.
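The effect of metrics_every_n_train_steps can be illustrated with a small stand-alone sketch. The helper name `should_log_train_metrics` is hypothetical, and the assumption that logging happens on steps that are multiples of n is mine, not something the documentation states.

```python
from typing import Optional


def should_log_train_metrics(global_step: int, every_n: Optional[int]) -> bool:
    """Return True when training metrics would be logged at this step."""
    if every_n is None:
        # None disables logging of training metrics entirely (the default).
        return False
    # Assumption: logging happens on steps that are multiples of `every_n`.
    return global_step % every_n == 0


# Steps 0..5 with logging every 2 steps.
logged_steps = [s for s in range(6) if should_log_train_metrics(s, 2)]
```

With every_n=1 the helper returns True at every step, matching the "log at every step" behaviour described above.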


Controls how the class is printed

compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None) staticmethod

Compute the loss using the specified loss function, while handling NaNs in the targets.


Name Type Description Default
preds Dict[str, Tensor]

Predicted values

targets Dict[str, Tensor]

Target values

weights Optional[Tensor]

No longer supported, will raise an error.

target_nan_mask Optional[Union[str, int]]
  • None: Do not change behaviour if there are NaNs

  • int, float: Value used to replace NaNs. For example, if target_nan_mask==0, then all NaNs will be replaced by zeros

  • 'ignore': The NaN values will be removed from the tensor before computing the metrics. Must be coupled with the multitask_handling='flatten' or multitask_handling='mean-per-label'.

multitask_handling Optional[str]
  • None: Do not process the tensor before passing it to the metric. Cannot use the option multitask_handling=None when target_nan_mask=ignore. Use either 'flatten' or 'mean-per-label'.

  • 'flatten': Flatten the tensor to produce the equivalent of a single task

  • 'mean-per-label': Loop over all the label columns, process each as a single task, and average the results over the tasks. This option might slow down the computation if there are many labels.

loss_fun Dict[str, Callable]

Loss function to use



Name Type Description
Tensor Tuple[Tensor, Dict[str, Tensor]]

weighted_loss: Resulting weighted loss
all_task_losses: Loss per task
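The interaction between target_nan_mask and multitask_handling is easier to see on a toy example. The following is a pure-Python sketch on nested lists, not the library's tensor implementation; `nan_aware_loss` and `mse` are hypothetical names, and multitask_handling=None is treated like 'flatten' here, which is only equivalent for element-wise losses.

```python
import math
from typing import Callable, List, Optional, Union

Matrix = List[List[float]]  # rows are samples, columns are labels


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error over two equal-length flat lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def nan_aware_loss(
    preds: Matrix,
    targets: Matrix,
    loss_fn: Callable[[List[float], List[float]], float],
    target_nan_mask: Optional[Union[str, int, float]] = None,
    multitask_handling: Optional[str] = None,
) -> float:
    if target_nan_mask == "ignore" and multitask_handling is None:
        raise ValueError("'ignore' requires 'flatten' or 'mean-per-label'")
    if isinstance(target_nan_mask, (int, float)):
        # Replace NaN targets by the given value.
        targets = [[target_nan_mask if math.isnan(t) else t for t in row] for row in targets]
    if multitask_handling == "mean-per-label":
        # One group of (pred, target) pairs per label column.
        groups = [
            [(p_row[j], t_row[j]) for p_row, t_row in zip(preds, targets)]
            for j in range(len(targets[0]))
        ]
    else:
        # 'flatten' (and None in this sketch): a single group with every entry.
        groups = [[(p, t) for p_row, t_row in zip(preds, targets) for p, t in zip(p_row, t_row)]]
    losses = []
    for pairs in groups:
        if target_nan_mask == "ignore":
            # Drop the entries whose target is NaN before computing the loss.
            pairs = [(p, t) for p, t in pairs if not math.isnan(t)]
        losses.append(loss_fn([p for p, _ in pairs], [t for _, t in pairs]))
    return sum(losses) / len(losses)


preds = [[1.0, 2.0], [3.0, 4.0]]
targets = [[1.0, float("nan")], [1.0, 2.0]]
flat = nan_aware_loss(preds, targets, mse, "ignore", "flatten")
per_label = nan_aware_loss(preds, targets, mse, "ignore", "mean-per-label")
zero_filled = nan_aware_loss(preds, targets, mse, 0, "flatten")
```

On this data, 'flatten' averages over the three valid entries, while 'mean-per-label' first averages within each column and then across columns, so the two options give different values whenever columns have different numbers of valid targets.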

flag_step(batch, step_name, to_cpu)

Perform adversarial data augmentation during one training step using FLAG. Paper: Github:


Returns the result of self.model.forward(*inputs) on the inputs. If the output out = self.model.forward(*inputs) is a dictionary with a "preds" key, it is returned directly. Otherwise, a new dictionary {"preds": out} is created and returned.


Type Description
Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]

A dict with a key "preds" representing the prediction of the network.
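The output convention described above can be sketched in a few lines. The helper name `wrap_preds` is hypothetical and stands in for the wrapping step only; it does not call any model.

```python
from typing import Any, Dict


def wrap_preds(out: Any) -> Dict[str, Any]:
    """Normalize a model output to a dict that has a "preds" key."""
    if isinstance(out, dict) and "preds" in out:
        # Already in the expected form: return it unchanged.
        return out
    # Anything else (tensor, list, dict without "preds") is wrapped.
    return {"preds": out}


raw_output = [0.1, 0.9]
wrapped = wrap_preds(raw_output)
already_wrapped = {"preds": raw_output, "extra": "kept as-is"}
same = wrap_preds(already_wrapped)
```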


Method to compute number of graphs in a Batch. Essential to estimate throughput in graphs/s.

list_pretrained_models() staticmethod

List available pretrained models.

load_pretrained_model(name_or_path, device=None) staticmethod

Load a pretrained model from its name.


Name Type Description Default

name_or_path

Name of the model to load or a valid checkpoint path. The available names are listed by graphium.trainer.PredictorModule.list_pretrained_models().






Initializes a metric from a name or Callable, and initializes the Thresholder in case the metric requires a threshold.

__call__(preds, target)

Compute the metric with the method self.compute


Serialize the class for pickling.

__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)

Parameters:

metric: The metric to use. See METRICS_DICT

threshold_kwargs:
    If `None`, no threshold is applied.
    Otherwise, the class `Thresholder` is initialized with the
    provided argument, and called before the `compute`

target_nan_mask:

    - None: Do not change behaviour if there are NaNs

    - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
      all NaNs will be replaced by zeros

    - 'ignore': The NaN values will be removed from the tensor before computing the metrics.
      Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.

multitask_handling:

    - None: Do not process the tensor before passing it to the metric.
      Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
      Use either 'flatten' or 'mean-per-label'.

    - 'flatten': Flatten the tensor to produce the equivalent of a single task

    - 'mean-per-label': Loop over all the label columns, process each as a single task,
      and average the results over the tasks
      *This option might slow down the computation if there are too many labels*

squeeze_targets:
    If true, targets will be squeezed prior to computing the metric.
    Required in classifigression task.

target_to_int:
    If true, targets will be converted to integers prior to computing the metric.

kwargs:
    Other arguments to call with the metric

Control how the class is printed


Reload the class from pickling.

compute(preds, target)

Compute the metric, apply the thresholder if provided, and manage the NaNs
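A rough picture of what compute does, on flat Python lists: threshold the predictions if a threshold is given, drop entries whose target is NaN, optionally cast the targets to integers, then call the metric. This is a sketch under assumptions, not the library code: `compute_metric` and `accuracy` are hypothetical names, the threshold is a plain float compared with `>`, and only the 'ignore' NaN behaviour is shown.

```python
import math
from typing import Callable, List, Optional


def accuracy(preds: List[int], target: List[int]) -> float:
    """Fraction of positions where prediction and target agree."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def compute_metric(
    preds: List[float],
    target: List[float],
    metric: Callable[[list, list], float],
    threshold: Optional[float] = None,
    target_to_int: bool = False,
) -> float:
    if threshold is not None:
        # Assumed thresholding rule: strictly greater than the threshold maps to 1.
        preds = [int(p > threshold) for p in preds]
    # 'ignore'-style NaN handling: keep only the pairs with a valid target.
    kept = [(p, t) for p, t in zip(preds, target) if not math.isnan(t)]
    preds = [p for p, _ in kept]
    target = [int(t) if target_to_int else t for _, t in kept]
    return metric(preds, target)


score = compute_metric(
    [0.9, 0.2, 0.7, 0.4],
    [1.0, 0.0, float("nan"), 1.0],
    accuracy,
    threshold=0.5,
    target_to_int=True,
)
```

Here the third sample is discarded because its target is NaN, and two of the three remaining thresholded predictions match their targets.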



Serialize the class for pickling.


Control how the class is printed

Predictor Summaries




Bases: SummaryInterface

__init__(targets=None, preds=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)

This inner class is used as a container for storing the results of the summary.

Parameters:
  • targets: the targets
  • preds: the prediction tensor
  • loss: the loss, float or tensor
  • metrics: the metrics
  • monitored_metric: the monitored metric
  • n_epochs: the number of epochs

__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)

A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.

Parameters:
  • loss_fun: Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'. Otherwise, a callable object must be provided, with a method loss_fun._get_name().
  • metrics: A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or similar.
  • metrics_on_training_set: The metric names from `metrics` to be computed on the training set for each iteration. If `None`, all the metrics are computed. Using fewer metrics can significantly improve performance, depending on the number of readouts.
  • metrics_on_progress_bar: The metric names from `metrics` to also display on the training progress bar.
  • monitor: `str` metric to track (Default=`"loss"`)
  • task_name: name of the task (Default=`None`)

Retrieve the best results for a given step.
Parameters: step_name: which stage you are in, e.g. "train"
Returns: the best results for the given step


Retrieve the full summary in a dictionary.
Returns: the full summary in a dictionary


Get the data about metrics to log. Note: This function requires that self.update_predictor_state() be called before it.
Returns: A dictionary of metrics to log.


Retrieve the results for a given step.
Parameters: step_name: which stage you are in, e.g. "train"
Returns: the results for the given step


Retrieve the results to be displayed on the progress bar for a given step.
Parameters: step_name: which stage you are in, e.g. "train"
Returns: the results to be displayed on the progress bar for the given step

is_best_epoch(step_name, loss, metrics)

Check whether the current epoch is the best epoch, based on the self.mode criterion.

Parameters:
  • step_name: which stage you are in, e.g. "train"
  • loss: the loss tensor
  • metrics: a dictionary of metrics
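The mode criterion can be pictured as a simple comparison against the best value seen so far. The sketch below uses a hypothetical helper `is_better` on plain floats and assumes that 'min' means lower is better and 'max' means higher is better; how the real method picks the monitored value from loss and metrics is not shown.

```python
from typing import List, Optional


def is_better(current: float, best: Optional[float], mode: str = "min") -> bool:
    """Return True when `current` improves on `best` under the given mode."""
    if best is None:
        # Nothing recorded yet, so the first value is the best by definition.
        return True
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"Unknown mode: {mode!r}")


# Track which epochs set a new best validation loss.
val_losses = [0.9, 0.7, 0.8, 0.5]
best: Optional[float] = None
best_epochs: List[int] = []
for epoch, loss in enumerate(val_losses):
    if is_better(loss, best, mode="min"):
        best = loss
        best_epochs.append(epoch)
```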


Set the results from the metrics. [!] This function requires that self.update_predictor_state() be called before it.
Parameters: metrics: a dictionary of metrics

update_predictor_state(step_name, targets, preds, loss, n_epochs)

Update the state of the predictor.

Parameters:
  • step_name: which stage you are in, e.g. "train"
  • targets: the targets tensor
  • preds: the predictions tensor
  • loss: the loss tensor
  • n_epochs: the number of epochs


Bases: object

An interface to define the functions implemented by summary classes that implement SummaryInterface.


Bases: SummaryInterface

__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')

Class to store the summaries of the tasks.

Parameters:
  • task_loss_fun: the loss function for each task
  • task_metrics: the metrics for each task
  • task_metrics_on_training_set: the metrics to use on the training set
  • task_metrics_on_progress_bar: the metrics to use on the progress bar
  • monitor: the metric to monitor
  • mode: the mode of the metric to monitor


Concatenate the metrics logs.
Parameters: metrics_logs: the metrics logs
Returns: the concatenated metrics logs


Retrieve the best results.
Parameters: step_name: the name of the step, e.g. "train"
Returns: the best results


Get the task summaries in a dictionary.
Returns: the task summaries


Get the logs for the metrics.
Returns: the task logs for the metrics


Retrieve the results.
Parameters: step_name: the name of the step, e.g. "train"
Returns: the results


Return all results from all tasks for the progress bar. The task-specific dictionaries are merged into a single dictionary instead of being keyed by task name.
Parameters: step_name: the name of the step
Returns: the results for the progress bar

metric_log_name(task_name, metric_name, step_name)

print the metric name, task name and step name Returns: the metric name, task name and step name
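Merging per-task results into one flat dictionary relies on metric names that stay unique across tasks, which is what a helper like metric_log_name provides. The sketch below is illustrative only: the "task/metric/step" naming scheme is an assumption, and `merge_task_results` is a hypothetical helper, not a method of the class.

```python
from typing import Dict


def metric_log_name(task_name: str, metric_name: str, step_name: str) -> str:
    # Assumed naming scheme: "<task>/<metric>/<step>"; the real format may differ.
    return f"{task_name}/{metric_name}/{step_name}"


def merge_task_results(per_task: Dict[str, Dict[str, float]], step_name: str) -> Dict[str, float]:
    """Flatten {task: {metric: value}} into a single {log_name: value} dict."""
    merged: Dict[str, float] = {}
    for task_name, metrics in per_task.items():
        for name, value in metrics.items():
            merged[metric_log_name(task_name, name, step_name)] = value
    return merged


merged = merge_task_results(
    {"tox": {"mae": 0.3}, "sol": {"mae": 0.5, "r2": 0.8}},
    step_name="val",
)
```

Because the task name is part of each key, two tasks can report the same metric (here "mae") without overwriting each other in the merged dictionary.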


Set the results for all tasks.
Parameters: task_metrics: the metrics for each task

update_predictor_state(step_name, targets, preds, loss, task_losses, n_epochs)

Update the state for all predictors.

Parameters:
  • step_name: the name of the step
  • targets: the target tensors
  • preds: the prediction tensors
  • loss: the loss tensor
  • task_losses: the task losses
  • n_epochs: the number of epochs

Predictor Options



EvalOptions dataclass

This data class stores the arguments necessary to configure the loss function and the metrics used by the Predictor.


Name Type Description Default
loss_fun Union[str, Dict, Callable]

Loss function used during training. Acceptable strings are graphium.utils.spaces.LOSS_DICT.keys(). If a dict, must contain a 'name' key with one of the acceptable loss function strings as a value. The rest of the dict will be used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method loss_fun._get_name().

metrics Dict[str, Callable]

A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or other.

metrics_on_progress_bar List[str]

The metrics names from metrics to display also on the progress bar of the training

metrics_on_training_set Optional[List[str]]

The metric names from metrics to be computed on the training set for each iteration. If None, all the metrics are computed. Using fewer metrics can significantly improve performance, depending on the number of readouts.


Check that the metrics for the progress_bar and training_set are valid
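A validity check of this kind usually amounts to verifying that every requested name is a key of metrics. The sketch below shows one way to do that; the function name `check_metrics_validity` and the choice of raising ValueError are assumptions for illustration, not the documented behaviour.

```python
from typing import Callable, Dict, List, Optional


def check_metrics_validity(
    metrics: Dict[str, Callable],
    metrics_on_progress_bar: List[str],
    metrics_on_training_set: Optional[List[str]] = None,
) -> None:
    """Raise ValueError if a requested metric name is not defined in `metrics`."""
    available = set(metrics)
    requested = {
        "metrics_on_progress_bar": metrics_on_progress_bar,
        # None means "all metrics", so there is nothing to validate.
        "metrics_on_training_set": metrics_on_training_set or [],
    }
    for option, names in requested.items():
        unknown = sorted(set(names) - available)
        if unknown:
            raise ValueError(f"{option} contains unknown metrics: {unknown}")


metrics = {"mae": lambda p, t: 0.0, "r2": lambda p, t: 1.0}
check_metrics_validity(metrics, ["mae"], None)  # passes silently
```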

parse_loss_fun(loss_fun) staticmethod

Parse the loss function from a string or a dict


Name Type Description Default
loss_fun Union[str, Dict, Callable]

A callable corresponding to the loss function, a string specifying the loss function from LOSS_DICT, or a dict containing a key 'name' specifying the loss function, and the rest of the dict used as the arguments for the loss. Accepted strings are: graphium.utils.spaces.LOSS_DICT.keys().



Name Type Description
Callable Callable

Function or callable to compute the loss, takes preds and targets as inputs.
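The three accepted forms of loss_fun (string, dict with a 'name' key, callable) can be sketched with a stand-in registry. Everything below is hypothetical: `LOSS_DICT` here is a toy dictionary with two pure-Python losses, not graphium.utils.spaces.LOSS_DICT, and `parse_loss_fun` only mirrors the dispatch described above.

```python
from typing import Any, Callable, Dict, List, Union


class MSE:
    def __call__(self, preds: List[float], targets: List[float]) -> float:
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class Huber:
    def __init__(self, delta: float = 1.0) -> None:
        self.delta = delta

    def __call__(self, preds: List[float], targets: List[float]) -> float:
        total = 0.0
        for p, t in zip(preds, targets):
            err = abs(p - t)
            # Quadratic near zero, linear beyond `delta`.
            total += 0.5 * err**2 if err <= self.delta else self.delta * (err - 0.5 * self.delta)
        return total / len(preds)


# Toy registry standing in for the library's LOSS_DICT.
LOSS_DICT: Dict[str, type] = {"mse": MSE, "huber": Huber}


def parse_loss_fun(loss_fun: Union[str, Dict[str, Any], Callable]) -> Callable:
    if isinstance(loss_fun, str):
        return LOSS_DICT[loss_fun]()
    if isinstance(loss_fun, dict):
        kwargs = dict(loss_fun)      # copy so the caller's dict is not mutated
        name = kwargs.pop("name")    # the remaining items become constructor arguments
        return LOSS_DICT[name](**kwargs)
    if callable(loss_fun):
        return loss_fun
    raise TypeError(f"Unsupported loss_fun: {loss_fun!r}")


from_name = parse_loss_fun("mse")
from_dict = parse_loss_fun({"name": "huber", "delta": 0.5})
custom = parse_loss_fun(from_name)
```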

FlagOptions dataclass

This data class stores the arguments necessary to configure FLAG for the Predictor.


Name Type Description Default
flag_kwargs Dict[str, Any]

Keyword arguments used for FLAG, an adversarial data augmentation for graph networks. See:

  • n_steps: An integer that specifies the number of ascent steps when running FLAG during training. Default value of 0 trains GNNs without FLAG, and any value greater than 0 will use FLAG with that many iterations.

  • alpha: A float that specifies the ascent step size when running FLAG. Default=0.01
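As a configuration sketch, the two keys and their documented defaults can be written as a plain dictionary. The helpers `resolve_flag_kwargs` and `flag_enabled` are hypothetical and only illustrate the rule that FLAG is active when n_steps is greater than 0.

```python
from typing import Any, Dict, Optional

# Defaults taken from the descriptions above.
DEFAULT_FLAG_KWARGS: Dict[str, Any] = {"n_steps": 0, "alpha": 0.01}


def resolve_flag_kwargs(flag_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Fill in the documented defaults for any missing key."""
    resolved = dict(DEFAULT_FLAG_KWARGS)
    resolved.update(flag_kwargs or {})
    return resolved


def flag_enabled(flag_kwargs: Optional[Dict[str, Any]] = None) -> bool:
    """FLAG is used only when at least one ascent step is requested."""
    return resolve_flag_kwargs(flag_kwargs)["n_steps"] > 0


with_flag = resolve_flag_kwargs({"n_steps": 3})
```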


ModelOptions dataclass

This data class stores the arguments necessary to instantiate a model for the Predictor.


Name Type Description Default
model_class Type[Module]

pytorch module used to create a model

model_kwargs Dict[str, Any]

Key-word arguments used to initialize the model from model_class.


OptimOptions dataclass

This data class stores the arguments necessary to configure the optimizer for the Predictor.


Name Type Description Default
optim_kwargs Optional[Dict[str, Any]]

Dictionary used to initialize the optimizer, with possible keys below.

  • lr float: Learning rate (Default=1e-3)
  • weight_decay float: Weight decay used to regularize the optimizer (Default=0.)
torch_scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the scheduling of the learning rate, with possible keys below.

  • type str: Type of the learning rate to use from pytorch. Examples are 'ReduceLROnPlateau' (default), 'CosineAnnealingWarmRestarts', 'StepLR', etc.
  • **kwargs: Any other argument for the learning rate scheduler
scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the scheduling of the learning rate modification used by pytorch-lightning

  • monitor str: metric to track (Default="loss/val")
  • interval str: Whether to look at iterations or epochs (Default="epoch")
  • strict bool: if set to True will enforce that value specified in monitor is available while trying to call scheduler.step(), and stop training if not found. If False will only give a warning and continue training (without calling the scheduler). (Default=True)
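  • frequency int: How many intervals (epochs or steps, depending on interval) pass between calls to scheduler.step(); 1 updates the learning rate after every interval (Default=1)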
scheduler_class Optional[Union[str, Type]]

The class to use for the scheduler, or the str representing the scheduler.
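The three dictionaries can be written out with the defaults listed above. This is a configuration sketch only: the key names follow this page, `fill_defaults` is a hypothetical helper, and the `T_0` value is an illustrative argument for CosineAnnealingWarmRestarts rather than a recommended setting.

```python
from typing import Any, Dict, Optional

# Defaults as documented above.
OPTIM_DEFAULTS: Dict[str, Any] = {"lr": 1e-3, "weight_decay": 0.0}
TORCH_SCHEDULER_DEFAULTS: Dict[str, Any] = {"type": "ReduceLROnPlateau"}
SCHEDULER_DEFAULTS: Dict[str, Any] = {
    "monitor": "loss/val",
    "interval": "epoch",
    "strict": True,
    "frequency": 1,
}


def fill_defaults(user: Optional[Dict[str, Any]], defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Return the defaults overridden by any user-provided keys."""
    return {**defaults, **(user or {})}


optim_kwargs = fill_defaults({"lr": 3e-4}, OPTIM_DEFAULTS)
torch_scheduler_kwargs = fill_defaults(
    {"type": "CosineAnnealingWarmRestarts", "T_0": 10},  # extra keys go to the scheduler
    TORCH_SCHEDULER_DEFAULTS,
)
scheduler_kwargs = fill_defaults(None, SCHEDULER_DEFAULTS)
```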
