graphium.trainer

Code for training models

Predictor


graphium.trainer.predictor

PredictorModule

Bases: lightning.LightningModule

__init__(model_class, model_kwargs, loss_fun, random_seed=42, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None)

The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task.

Parameters:

Name Type Description Default
model_class Type[nn.Module]

The torch Module containing the main forward function

required
model_kwargs Dict[str, Any]

The arguments to initialize the model from model_class

required
loss_fun Dict[str, Union[str, Callable]]

A dict[str, fun], where str is the task name and fun the loss function

required
random_seed int

The seed for random initialization

42
optim_kwargs Optional[Dict[str, Any]]

The optimization arguments. See class OptimOptions

None
torch_scheduler_kwargs Optional[Dict[str, Any]]

The torch scheduler arguments. See class OptimOptions

None
scheduler_kwargs Optional[Dict[str, Any]]

The lightning scheduler arguments. See class OptimOptions

None
target_nan_mask Optional[Union[str, int]]

How to handle the NaNs. See MetricWrapper for options

None
metrics Dict[str, Callable]

A dict[str, fun], where str is the task name and fun the metric function

None
metrics_on_progress_bar Dict[str, List[str]]

A dict[str, list[str2]], where str is the task name and str2 the metrics to include on the progress bar

[]
metrics_on_training_set Optional[Dict[str, List[str]]]

A dict[str, list[str2]], where str is the task name and str2 the metrics to include on the training set

None
flag_kwargs Dict[str, Any]

Arguments related to using the FLAG adversarial augmentation

None
task_norms Optional[Dict[Callable, Any]]

The normalization for each task

None
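
For orientation, here is a minimal construction sketch. The TinyModel stand-in, the task name "task1", and the "mse" loss string are illustrative assumptions; only the keyword arguments come from the signature above.

```python
import torch
from graphium.trainer.predictor import PredictorModule

class TinyModel(torch.nn.Module):
    """Stand-in network (an assumption); any nn.Module with a forward() works."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.lin(x)

predictor = PredictorModule(
    model_class=TinyModel,                      # the class itself, not an instance
    model_kwargs={"in_dim": 16, "out_dim": 1},  # forwarded to TinyModel.__init__
    loss_fun={"task1": "mse"},                  # one loss per task; "mse" assumed in LOSS_DICT
    target_nan_mask="ignore",                   # drop NaN targets ...
    multitask_handling="flatten",               # ... which requires 'flatten' or 'mean-per-label'
    optim_kwargs={"lr": 1e-3},
)
```
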
__repr__()

Controls how the class is printed

compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None) staticmethod

Compute the loss using the specified loss function, handling the NaNs in the targets.

Parameters:

Name Type Description Default
preds Dict[str, Tensor]

Predicted values

required
targets Dict[str, Tensor]

Target values

required
weights Optional[Tensor]

No longer supported, will raise an error.

required
target_nan_mask Optional[Union[str, int]]
  • None: Do not change behaviour if there are NaNs

  • int, float: Value used to replace NaNs. For example, if target_nan_mask==0, then all NaNs will be replaced by zeros

  • 'ignore': The NaN values will be removed from the tensor before computing the metrics. Must be coupled with multitask_handling='flatten' or multitask_handling='mean-per-label'.

None
multitask_handling Optional[str]
  • None: Do not process the tensor before passing it to the metric. Cannot use the option multitask_handling=None when target_nan_mask=ignore. Use either 'flatten' or 'mean-per-label'.

  • 'flatten': Flatten the tensor to produce the equivalent of a single task

  • 'mean-per-label': Loop over all label columns, process each as a single task, and average the results over the tasks. This option might slow down the computation if there are too many labels

None
loss_fun Dict[str, Callable]

Loss function to use

required

Returns:

Name Type Description
Tensor Tuple[Tensor, Dict[str, Tensor]]

weighted_loss: Resulting weighted loss
all_task_losses: Loss per task
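
To make the interaction between target_nan_mask and multitask_handling concrete, here is a pure-PyTorch sketch of the 'ignore' + 'flatten' combination described above. It illustrates the documented semantics only; it is not Graphium's implementation.

```python
import torch

preds = torch.tensor([[0.5, 1.0], [0.2, 0.8]])
targets = torch.tensor([[0.4, float("nan")], [float("nan"), 1.0]])

# multitask_handling='flatten': treat the multi-column tensor as a single task
preds_flat, targets_flat = preds.flatten(), targets.flatten()

# target_nan_mask='ignore': drop NaN entries before computing the loss
mask = ~torch.isnan(targets_flat)
loss = torch.nn.functional.mse_loss(preds_flat[mask], targets_flat[mask])

# target_nan_mask=0 would instead replace the NaNs by zeros:
targets_zeroed = torch.nan_to_num(targets, nan=0.0)
```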

flag_step(batch, step_name, to_cpu)

Perform adversarial data augmentation during one training step using FLAG. Paper: https://arxiv.org/abs/2010.09891 GitHub: https://github.com/devnkong/FLAG

forward(inputs)

Returns the result of self.model.forward(*inputs). If the output out = self.model.forward(*inputs) is a dictionary with a "preds" key, it is returned directly. Otherwise, a new dictionary {"preds": out} is returned.

Returns:

Type Description
Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]

A dict with a key "preds" representing the prediction of the network.

get_num_graphs(data)

Compute the number of graphs in a Batch. Essential to estimate throughput in graphs/s.

list_pretrained_models() staticmethod

List available pretrained models.

load_pretrained_models(name) staticmethod

Load a pretrained model from its name.

Parameters:

Name Type Description Default
name str

Name of the model to load. List available from graphium.trainer.PredictorModule.list_pretrained_models().

required
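
A usage sketch; the "model-name" string below is a placeholder to be replaced by an entry from the printed list:

```python
from graphium.trainer.predictor import PredictorModule

# Inspect which pretrained models ship with the library
print(PredictorModule.list_pretrained_models())

# "model-name" is a placeholder; substitute an entry from the list above
predictor = PredictorModule.load_pretrained_models(name="model-name")
```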

Metrics


graphium.trainer.metrics

MetricWrapper

Initializes a metric from a name or Callable, and initializes the Thresholder in case the metric requires a threshold.

__call__(preds, target)

Compute the metric with the method self.compute

__getstate__()

Serialize the class for pickling.

__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)

Parameters

metric:
    The metric to use. See METRICS_DICT

threshold_kwargs:
    If `None`, no threshold is applied.
    Otherwise, the class `Thresholder` is initialized with the
    provided arguments and called before `compute`.

target_nan_mask:

    - None: Do not change behaviour if there are NaNs

    - int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
      all NaNs will be replaced by zeros

    - 'ignore': The NaN values will be removed from the tensor before computing the metrics.
      Must be coupled with `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.

multitask_handling:
    - None: Do not process the tensor before passing it to the metric.
      Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
      Use either 'flatten' or 'mean-per-label'.

    - 'flatten': Flatten the tensor to produce the equivalent of a single task

    - 'mean-per-label': Loop over all label columns, process each as a single task,
      and average the results over the tasks.
      *This option might slow down the computation if there are too many labels*

squeeze_targets:
    If true, targets will be squeezed prior to computing the metric.
    Required in classifigression task.

target_to_int:
    If true, targets will be converted to integers prior to computing the metric.

kwargs:
    Other arguments to call with the metric
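
A small usage sketch; the "mae" metric name is an assumption about the keys of METRICS_DICT:

```python
import torch
from graphium.trainer.metrics import MetricWrapper

# "mae" is assumed to be a key of METRICS_DICT; a Callable also works
wrapper = MetricWrapper(
    metric="mae",
    target_nan_mask="ignore",      # remove NaN targets ...
    multitask_handling="flatten",  # ... required when target_nan_mask='ignore'
)

preds = torch.randn(8, 3)
target = torch.randn(8, 3)
target[0, 0] = float("nan")        # this entry is dropped before the metric
score = wrapper(preds, target)     # calls wrapper.compute(preds, target)
```
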
__repr__()

Controls how the class is printed

__setstate__(state)

Reload the class from pickling.

compute(preds, target)

Compute the metric, apply the thresholder if provided, and manage the NaNs

Thresholder

__getstate__()

Serialize the class for pickling.

__repr__()

Controls how the class is printed

Predictor Summaries


graphium.trainer.predictor_summaries

Classes to store information about resulting evaluation metrics when using a Predictor Module.

Summary

Bases: SummaryInterface

Results
__init__(targets=None, predictions=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)

This inner class is used as a container for storing the results of the summary.

Parameters:

Name Type Description Default
targets Tensor

the targets

None
predictions Tensor

the prediction tensor

None
loss float

the loss, float or tensor

None
metrics dict

the metrics

None
monitored_metric str

the monitored metric

None
n_epochs int

the number of epochs

None
__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)

A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.

Parameters:

Name Type Description Default
loss_fun Union[str, Callable]

the loss function to use for this task

required
metrics Dict[str, Callable]

the metrics to compute for this task

required
metrics_on_training_set List[str]

the metrics to compute on the training set

[]
metrics_on_progress_bar List[str]

the metrics to display on the progress bar

[]
monitor str

the metric to monitor

'loss'
mode str

the mode of the metric to monitor

'min'
task_name Optional[str]

the name of the task

None
get_best_results(step_name)

Retrieve the best results for a given step

Parameters:

Name Type Description Default
step_name str

which stage you are in, e.g. "train"

required

Returns:

Type Description

the best results for the given step

get_dict_summary()

Retrieve the full summary in a dictionary

Returns:

Type Description
Dict[str, Any]

the full summary in a dictionary

get_metrics_logs()

Get the data about metrics to log. Note: This function requires that self.update_predictor_state() be called before it.

Returns:

Type Description
Dict[str, Any]

A dictionary of metrics to log.

get_results(step_name)

Retrieve the results for a given step

Parameters:

Name Type Description Default
step_name str

which stage you are in, e.g. "train"

required

Returns:

Type Description

the results for the given step

get_results_on_progress_bar(step_name)

Retrieve the results to be displayed on the progress bar for a given step

Parameters:

Name Type Description Default
step_name str

which stage you are in, e.g. "train"

required

Returns:

Type Description
Dict[str, Tensor]

the results to be displayed on the progress bar for the given step

is_best_epoch(step_name, loss, metrics)

Check whether the current epoch is the best epoch, based on the self.mode criterion

Parameters:

Name Type Description Default
step_name str

which stage you are in, e.g. "train"

required
loss Tensor

the loss tensor

required
metrics Dict[str, Tensor]

a dictionary of metrics

required
set_results(metrics)

Set the results from the metrics. Note: This function requires that self.update_predictor_state() be called before it.

Parameters:

Name Type Description Default
metrics Dict[str, Tensor]

a dictionary of metrics

required
update_predictor_state(step_name, targets, predictions, loss, n_epochs)

Update the state of the predictor

Parameters:

Name Type Description Default
step_name str

which stage you are in, e.g. "train"

required
targets Tensor

the targets tensor

required
predictions Tensor

the predictions tensor

required
loss Tensor

the loss tensor

required
n_epochs int

the number of epochs

required

SummaryInterface

Bases: object

An interface defining the functions to be implemented by summary classes.

TaskSummaries

Bases: SummaryInterface

__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')

A class to store the summaries of all tasks

Parameters:

Name Type Description Default
task_loss_fun Callable

the loss function for each task

required
task_metrics Dict[str, Callable]

the metrics for each task

required
task_metrics_on_training_set List[str]

the metrics to use on the training set

required
task_metrics_on_progress_bar List[str]

the metrics to use on the progress bar

required
monitor str

the metric to monitor

'loss'
mode str

the mode of the metric to monitor

'min'
concatenate_metrics_logs(metrics_logs)

Concatenate the metrics logs

Parameters:

Name Type Description Default
metrics_logs Dict[str, Dict[str, Tensor]]

the metrics logs

required

Returns:

Type Description
Dict[str, Tensor]

the concatenated metrics logs

get_best_results(step_name)

Retrieve the best results

Parameters:

Name Type Description Default
step_name str

the name of the step, e.g. "train"

required

Returns:

Type Description
Dict[str, Dict[str, Any]]

the best results

get_dict_summary()

Get the task summaries in a dictionary

Returns:

Type Description
Dict[str, Dict[str, Any]]

the task summaries

get_metrics_logs()

Get the logs for the metrics

Returns:

Type Description
Dict[str, Dict[str, Tensor]]

the task logs for the metrics

get_results(step_name)

Retrieve the results

Parameters:

Name Type Description Default
step_name str

the name of the step, e.g. "train"

required

Returns:

Type Description
Dict[str, Dict[str, Any]]

the results

get_results_on_progress_bar(step_name)

Return all results from all tasks for the progress bar. The task-specific dictionaries are merged into a single dictionary, instead of being keyed by task name.

Parameters:

Name Type Description Default
step_name str

the name of the step

required

Returns:

Type Description
Dict[str, Dict[str, Any]]

the results for the progress bar

metric_log_name(task_name, metric_name, step_name)

Build the metric log name from the task name, metric name and step name

Returns:

Type Description
str

the combined metric, task and step name

set_results(task_metrics)

Set the results for all tasks

Parameters:

Name Type Description Default
task_metrics Dict[str, Dict[str, Tensor]]

the metrics for each task

required
update_predictor_state(step_name, targets, predictions, loss, task_losses, n_epochs)

Update the state for all predictors

Parameters:

Name Type Description Default
step_name str

the name of the step

required
targets Dict[str, Tensor]

the target tensors

required
predictions Dict[str, Tensor]

the prediction tensors

required
loss Tensor

the loss tensor

required
task_losses Dict[str, Tensor]

the task losses

required
n_epochs int

the number of epochs

required

Predictor Options


graphium.trainer.predictor_options

Data classes to group together related arguments for the creation of a Predictor Module.

EvalOptions dataclass

This data class stores the arguments necessary to instantiate a model for the Predictor.

Parameters:

Name Type Description Default
loss_fun Union[str, Dict, Callable]

Loss function used during training. Acceptable strings are graphium.utils.spaces.LOSS_DICT.keys(). If a dict, must contain a 'name' key with one of the acceptable loss function strings as a value. The rest of the dict will be used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method loss_fun._get_name().

required
metrics Dict[str, Callable]

A dictionary of metrics to compute on the predictions, other than the loss function. These metrics will be logged into WandB or another logger.

None
metrics_on_progress_bar List[str]

The metric names from metrics to also display on the training progress bar

field(default_factory=List[str])
metrics_on_training_set Optional[List[str]]

The metric names from metrics to be computed on the training set for each iteration. If None, all the metrics are computed. Using fewer metrics can significantly improve performance, depending on the number of readouts.

None
check_metrics_validity()

Check that the metrics for the progress bar and the training set are valid

parse_loss_fun(loss_fun) staticmethod

Parse the loss function from a string, a dict, or a callable

Parameters:

Name Type Description Default
loss_fun Union[str, Dict, Callable]

A callable corresponding to the loss function, a string specifying the loss function from LOSS_DICT, or a dict containing a key 'name' specifying the loss function, and the rest of the dict used as the arguments for the loss. Accepted strings are: graphium.utils.spaces.LOSS_DICT.keys().

required

Returns:

Name Type Description
Callable Callable

Function or callable to compute the loss, takes preds and targets as inputs.
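
The three accepted forms, sketched below; "mse" is assumed to be a key of LOSS_DICT:

```python
import torch
from graphium.trainer.predictor_options import EvalOptions

loss_a = EvalOptions.parse_loss_fun("mse")               # string key of LOSS_DICT (assumed)
loss_b = EvalOptions.parse_loss_fun({"name": "mse"})     # 'name' key + extra kwargs for the loss
loss_c = EvalOptions.parse_loss_fun(torch.nn.MSELoss())  # callable exposing _get_name()
```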

FlagOptions dataclass

This data class stores the arguments necessary to configure the FLAG adversarial augmentation for the Predictor.

Parameters:

Name Type Description Default
flag_kwargs Dict[str, Any]

Keyword arguments used for FLAG, and adversarial data augmentation for graph networks. See: https://arxiv.org/abs/2010.09891

  • n_steps: An integer that specifies the number of ascent steps when running FLAG during training. Default value of 0 trains GNNs without FLAG, and any value greater than 0 will use FLAG with that many iterations.

  • alpha: A float that specifies the ascent step size when running FLAG. Default=0.01

None
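
A typical flag_kwargs dictionary following the keys above (values are illustrative):

```python
# Run FLAG with 3 ascent steps of size 0.01; n_steps=0 disables FLAG entirely
flag_kwargs = {"n_steps": 3, "alpha": 0.01}
```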

ModelOptions dataclass

This data class stores the arguments necessary to instantiate a model for the Predictor.

Parameters:

Name Type Description Default
model_class Type[nn.Module]

PyTorch module used to create a model

required
model_kwargs Dict[str, Any]

Key-word arguments used to initialize the model from model_class.

required

OptimOptions dataclass

This data class stores the arguments necessary to configure the optimizer for the Predictor.

Parameters:

Name Type Description Default
optim_kwargs Optional[Dict[str, Any]]

Dictionary used to initialize the optimizer, with possible keys below.

  • lr float: Learning rate (Default=1e-3)
  • weight_decay float: Weight decay used to regularize the optimizer (Default=0.)
None
torch_scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the scheduling of the learning rate, with possible keys below.

  • type str: Type of the learning rate scheduler to use from PyTorch. Examples are 'ReduceLROnPlateau' (default), 'CosineAnnealingWarmRestarts', 'StepLR', etc.
  • **kwargs: Any other argument for the learning rate scheduler
None
scheduler_kwargs Optional[Dict[str, Any]]

Dictionary for the learning-rate scheduling behaviour used by PyTorch Lightning, with possible keys below.

  • monitor str: metric to track (Default="loss/val")
  • interval str: Whether to look at iterations or epochs (Default="epoch")
  • strict bool: If set to True, enforces that the value specified in monitor is available when calling scheduler.step(), and stops training if it is not found. If False, only raises a warning and continues training (without calling the scheduler). (Default=True)
  • frequency int: How many epochs/steps should pass between calls to scheduler.step(). (Default=1)
None
scheduler_class Optional[Union[str, Type]]

The class to use for the scheduler, or the str representing the scheduler.

None
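
Putting the three dictionaries together, a sketch with illustrative values (extra keys in torch_scheduler_kwargs are forwarded to the PyTorch scheduler, as described above):

```python
optim_kwargs = {"lr": 1e-3, "weight_decay": 0.0}

torch_scheduler_kwargs = {
    "type": "ReduceLROnPlateau",  # name of a PyTorch scheduler
    "factor": 0.5,                # extra kwargs are forwarded to the scheduler
    "patience": 10,
}

scheduler_kwargs = {
    "monitor": "loss/val",  # metric tracked by the scheduler
    "interval": "epoch",    # step the scheduler per epoch rather than per iteration
    "strict": True,         # fail if `monitor` is not found
    "frequency": 1,         # intervals between scheduler.step() calls
}
```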