graphium.trainer¶
Code for training models
Predictor¶
graphium.trainer.predictor
¶
Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore Limited.
Use of this software is subject to the terms and conditions outlined in the LICENSE file. Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without warranties of any kind.
Valence Labs, Recursion Pharmaceuticals and Graphcore Limited are not liable for any damages arising from its use. Refer to the LICENSE file for the full terms and conditions.
PredictorModule
¶
Bases: LightningModule
__init__(model_class, model_kwargs, loss_fun, task_levels, random_seed=42, featurization=None, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None, metrics_every_n_train_steps=None, replicas=1, gradient_acc=1, global_bs=1)
¶
The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`model_class` | `Type[Module]` | The torch Module containing the main forward function | *required* |
`model_kwargs` | `Dict[str, Any]` | The arguments to initialize the model from | *required* |
`loss_fun` | `Dict[str, Union[str, Callable]]` | A dictionary of loss functions, one per task | *required* |
`random_seed` | `int` | The seed for random initialization | `42` |
`optim_kwargs` | `Optional[Dict[str, Any]]` | The optimization arguments. See class `OptimOptions` | `None` |
`torch_scheduler_kwargs` | `Optional[Dict[str, Any]]` | The torch scheduler arguments. See class `OptimOptions` | `None` |
`scheduler_kwargs` | `Optional[Dict[str, Any]]` | The lightning scheduler arguments. See class `OptimOptions` | `None` |
`target_nan_mask` | `Optional[Union[str, int]]` | How to handle the NaNs. See `MetricWrapper` for the accepted options | `None` |
`metrics` | `Dict[str, Callable]` | A dictionary of metrics to compute on the predictions, other than the loss function | `None` |
`metrics_on_progress_bar` | `Dict[str, List[str]]` | The metric names from `metrics` to also display on the progress bar, per task | `[]` |
`metrics_on_training_set` | `Optional[Dict[str, List[str]]]` | The metric names from `metrics` to compute on the training set, per task. If `None`, all the metrics are computed | `None` |
`flag_kwargs` | `Dict[str, Any]` | Arguments related to using the FLAG adversarial augmentation | `None` |
`task_norms` | `Optional[Dict[Callable, Any]]` | The normalization for each task | `None` |
`metrics_every_n_train_steps` | `Optional[int]` | Compute and log metrics every n training steps. Set to `None` to disable per-step metric logging | `None` |
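A minimal construction sketch (not a tested recipe): the model class, its kwargs, the per-task dictionary values, and the `optim_kwargs` key below are illustrative assumptions; only the argument names come from the signature above.

```python
import torch
from graphium.trainer.predictor import PredictorModule

class MyGNN(torch.nn.Module):
    """Hypothetical stand-in for a real graph network."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, feats):
        return self.lin(feats)

predictor = PredictorModule(
    model_class=MyGNN,                                # torch Module with the main forward
    model_kwargs={"in_dim": 32, "out_dim": 1},        # used as MyGNN(**model_kwargs)
    loss_fun={"tox21": "bce", "homo": "mse"},         # one loss per task
    task_levels={"tox21": "graph", "homo": "graph"},  # assumed per-task label levels
    optim_kwargs={"lr": 1e-3},                        # assumed key; see OptimOptions
)
```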
__repr__()
¶
Controls how the class is printed
compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None)
staticmethod
¶
Compute the loss using the specified loss function, handling NaNs in the targets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`preds` | `Dict[str, Tensor]` | Predicted values | *required* |
`targets` | `Dict[str, Tensor]` | Target values | *required* |
`weights` | `Optional[Tensor]` | No longer supported, will raise an error | *required* |
`target_nan_mask` | `Optional[Union[str, int]]` | How to handle the NaNs in the targets. See `MetricWrapper` for the accepted options | `None` |
`multitask_handling` | `Optional[str]` | How to handle the multiple tasks. See `MetricWrapper` for the accepted options | `None` |
`loss_fun` | `Dict[str, Callable]` | Loss function to use, one per task | *required* |
Returns:
Name | Type | Description |
---|---|---|
`Tensor` | `Tuple[Tensor, Dict[str, Tensor]]` | `weighted_loss`: the resulting weighted loss; `all_task_losses`: the loss per task |
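A sketch of calling the static method directly, assuming per-task tensors and callable losses; `weights` is kept at `None` since weighted losses are no longer supported.

```python
import torch
from graphium.trainer.predictor import PredictorModule

preds   = {"tox21": torch.randn(8, 3)}
targets = {"tox21": torch.randn(8, 3)}
targets["tox21"][0, 0] = float("nan")       # a NaN label to be masked out

weighted_loss, task_losses = PredictorModule.compute_loss(
    preds=preds,
    targets=targets,
    weights=None,                            # no longer supported; must stay None
    loss_fun={"tox21": torch.nn.MSELoss()},  # one callable per task
    target_nan_mask="ignore",                # remove NaN entries before the loss
    multitask_handling="flatten",            # required when target_nan_mask="ignore"
)
```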
flag_step(batch, step_name, to_cpu)
¶
Perform adversarial data augmentation during one training step using FLAG.
Paper: https://arxiv.org/abs/2010.09891
GitHub: https://github.com/devnkong/FLAG
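For intuition, here is a generic sketch of the FLAG loop from the paper, not graphium's exact `flag_step`: a random feature perturbation is refined by gradient ascent while the averaged loss trains the model.

```python
import torch

def flag_training_step(model, feats, labels, criterion, optimizer,
                       n_steps: int = 3, step_size: float = 1e-3):
    """One FLAG training step (sketch following the paper's pseudocode)."""
    optimizer.zero_grad()
    # Random initial perturbation of the input features, optimized by ascent
    pert = torch.empty_like(feats).uniform_(-step_size, step_size).requires_grad_()
    loss = criterion(model(feats + pert), labels) / n_steps
    for _ in range(n_steps - 1):
        loss.backward()  # accumulates gradients for both the model and `pert`
        # Gradient *ascent* on the perturbation; descent is left to the optimizer
        pert.data = pert.detach() + step_size * torch.sign(pert.grad.detach())
        pert.grad.zero_()
        loss = criterion(model(feats + pert), labels) / n_steps
    loss.backward()
    optimizer.step()
    return loss
```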
forward(inputs)
¶
Returns the result of `self.model.forward(*inputs)` on the inputs.
If the output `out` of `self.model.forward` is a dictionary with a `"preds"` key, it is returned directly. Otherwise, a new dictionary `{"preds": out}` is returned.
Returns:
Type | Description |
---|---|
`Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]` | A dict with a key `"preds"` containing the model predictions |
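The wrapping convention is equivalent to this small sketch:

```python
from typing import Any, Dict

def ensure_preds_dict(out: Any) -> Dict[str, Any]:
    """Mirror of the convention described above (illustrative only)."""
    if isinstance(out, dict) and "preds" in out:
        return out
    return {"preds": out}
```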
get_num_graphs(data)
¶
Method to compute the number of graphs in a Batch. Essential to estimate the throughput in graphs/s.
list_pretrained_models()
staticmethod
¶
List available pretrained models.
load_pretrained_model(name_or_path, device=None)
staticmethod
¶
Load a pretrained model from its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`name_or_path` | | Name of the model to load or a valid checkpoint path. List the available names with `list_pretrained_models()` | *required* |
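A usage sketch; the model name below is a placeholder, not a real entry:

```python
from graphium.trainer.predictor import PredictorModule

print(PredictorModule.list_pretrained_models())

# Replace "some-model-name" with a name from the list above,
# or pass a path to a checkpoint file.
predictor = PredictorModule.load_pretrained_model("some-model-name", device="cpu")
```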
Metrics¶
graphium.trainer.metrics
¶
MetricWrapper
¶
Allows initializing a metric from a name or Callable, and initializes the `Thresholder` in case the metric requires a threshold.
__call__(preds, target)
¶
Compute the metric with the method self.compute
__getstate__()
¶
Serialize the class for pickling.
__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)
¶
Parameters
metric:
The metric to use. See METRICS_DICT
threshold_kwargs:
If `None`, no threshold is applied.
Otherwise, a `Thresholder` is initialized with the provided arguments
and applied before calling `compute`.
target_nan_mask:
- None: Do not change behaviour if there are NaNs
- int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
all NaNs will be replaced by zeros
- 'ignore': The NaN values will be removed from the tensor before computing the metrics.
Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.
multitask_handling:
- None: Do not process the tensor before passing it to the metric.
Cannot use the option `multitask_handling=None` when `target_nan_mask='ignore'`.
Use either 'flatten' or 'mean-per-label'.
- 'flatten': Flatten the tensor to produce the equivalent of a single task
- 'mean-per-label': Loop over all the label columns, process each as a single task,
and average the results over the tasks.
*This option might slow down the computation if there are too many labels*
squeeze_targets:
If true, targets will be squeezed prior to computing the metric.
Required in "classifigression" tasks (classification treated as regression).
target_to_int:
If true, targets will be converted to integers prior to computing the metric.
kwargs:
Other arguments to call with the metric
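A minimal usage sketch, assuming a torchmetrics functional metric is an acceptable Callable: NaN targets are removed and each label column is scored separately, then averaged.

```python
import torch
from torchmetrics.functional import mean_absolute_error
from graphium.trainer.metrics import MetricWrapper

metric = MetricWrapper(
    metric=mean_absolute_error,
    target_nan_mask="ignore",             # drop NaN targets before computing
    multitask_handling="mean-per-label",  # score each label column, then average
)

preds  = torch.randn(16, 2)
target = torch.randn(16, 2)
target[3, 1] = float("nan")               # this entry will be ignored
print(metric(preds, target))              # calls self.compute under the hood
```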
__repr__()
¶
Control how the class is printed
__setstate__(state)
¶
Reload the class from pickling.
compute(preds, target)
¶
Compute the metric, apply the thresholder if provided, and manage the NaNs
Predictor Summaries¶
graphium.trainer.predictor_summaries
¶
Summary
¶
Bases: SummaryInterface
Results
¶
__init__(targets=None, preds=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)
¶
This inner class is used as a container for storing the results of the summary.
Parameters:
targets: the targets
preds: the prediction tensor
loss: the loss, float or tensor
metrics: the metrics
monitored_metric: the monitored metric
n_epochs: the number of epochs
__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)
¶
A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.
Parameters:
loss_fun:
Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'.
Otherwise, a callable object must be provided, with a method `loss_fun._get_name()`.
metrics:
A dictionary of metrics to compute on the prediction, other than the loss function.
These metrics will be logged into WandB or similar.
metrics_on_training_set:
The metric names from `metrics` to be computed on the training set for each iteration.
If `None`, all the metrics are computed. Using fewer metrics can significantly improve
performance, depending on the number of readouts.
metrics_on_progress_bar:
The metric names from `metrics` to also display on the training progress bar.
monitor:
`str` metric to track (Default=`"loss/val"`)
task_name:
name of the task (Default=`None`)
get_best_results(step_name)
¶
Retrieve the best results for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
The best results for the given step.
get_dict_summary()
¶
Retrieve the full summary as a dictionary.
Returns:
The full summary as a dictionary.
get_metrics_logs()
¶
Get the data about metrics to log.
Note: this function requires that `self.update_predictor_state()` be called before it.
Returns:
A dictionary of metrics to log.
get_results(step_name)
¶
Retrieve the results for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
The results for the given step.
get_results_on_progress_bar(step_name)
¶
Retrieve the results to be displayed on the progress bar for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
The results to be displayed on the progress bar for the given step.
is_best_epoch(step_name, loss, metrics)
¶
Check whether the current epoch is the best epoch, based on the `self.mode` criterion.
Parameters:
step_name: which stage you are in, e.g. "train"
loss: the loss tensor
metrics: a dictionary of metrics
set_results(metrics)
¶
Set the results from the metrics.
Note: this function requires that `self.update_predictor_state()` be called before it.
Parameters:
metrics: a dictionary of metrics
update_predictor_state(step_name, targets, preds, loss, n_epochs)
¶
Update the state of the predictor.
Parameters:
step_name: which stage you are in, e.g. "train"
targets: the targets tensor
preds: the predictions tensor
loss: the loss tensor
n_epochs: the number of epochs
SummaryInterface
¶
Bases: object
An interface defining the functions to be implemented by summary classes.
TaskSummaries
¶
Bases: SummaryInterface
__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')
¶
Class to store the summaries of the tasks.
Parameters:
task_loss_fun: the loss function for each task
task_metrics: the metrics for each task
task_metrics_on_training_set: the metrics to use on the training set
task_metrics_on_progress_bar: the metrics to use on the progress bar
monitor: the metric to monitor
mode: the mode of the metric to monitor
concatenate_metrics_logs(metrics_logs)
¶
Concatenate the metrics logs.
Parameters:
metrics_logs: the metrics logs
Returns:
The concatenated metrics logs.
get_best_results(step_name)
¶
Retrieve the best results.
Parameters:
step_name: the name of the step, e.g. "train"
Returns:
The best results.
get_dict_summary()
¶
Get the task summaries as a dictionary.
Returns:
The task summaries.
get_metrics_logs()
¶
Get the logs for the metrics.
Returns:
The task logs for the metrics.
get_results(step_name)
¶
Retrieve the results.
Parameters:
step_name: the name of the step, e.g. "train"
Returns:
The results.
get_results_on_progress_bar(step_name)
¶
Return all results from all tasks for the progress bar.
Combines the dictionaries: instead of keying by task name, all the task-specific dictionaries are merged.
Parameters:
step_name: the name of the step
Returns:
The results for the progress bar.
metric_log_name(task_name, metric_name, step_name)
¶
Build the log name from the task name, metric name, and step name.
Returns:
The combined metric log name.
set_results(task_metrics)
¶
Set the results for all tasks.
Parameters:
task_metrics: the metrics for each task
update_predictor_state(step_name, targets, preds, loss, task_losses, n_epochs)
¶
Update the state for all predictors.
Parameters:
step_name: the name of the step
targets: the target tensors
preds: the prediction tensors
loss: the loss tensor
task_losses: the task losses
n_epochs: the number of epochs
Predictor Options¶
graphium.trainer.predictor_options
¶
EvalOptions
dataclass
¶
This data class stores the arguments necessary to configure the evaluation (loss and metrics) for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`loss_fun` | `Union[str, Dict, Callable]` | Loss function used during training. Acceptable strings are `graphium.utils.spaces.LOSS_DICT.keys()`. If a dict, it must contain a `'name'` key with one of the acceptable loss function strings as a value; the rest of the dict is used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method `loss_fun._get_name()` | *required* |
`metrics` | `Dict[str, Callable]` | A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or other loggers | `None` |
`metrics_on_progress_bar` | `List[str]` | The metric names from `metrics` to also display on the training progress bar | `field(default_factory=List[str])` |
`metrics_on_training_set` | `Optional[List[str]]` | The metric names from `metrics` to compute on the training set for each iteration. If `None`, all the metrics are computed | `None` |
check_metrics_validity()
¶
Check that the metrics for the progress bar and the training set are valid.
parse_loss_fun(loss_fun)
staticmethod
¶
Parse the loss function from a string, a dict, or a callable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`loss_fun` | `Union[str, Dict, Callable]` | A callable corresponding to the loss function, a string specifying the loss function from `LOSS_DICT`, or a dict containing a `'name'` key and the arguments to pass to the loss | *required* |
Returns:
Name | Type | Description |
---|---|---|
`Callable` | `Callable` | Function or callable to compute the loss, taking the predictions and targets as inputs |
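A sketch of the accepted forms, assuming `'mse'` is one of the `LOSS_DICT` keys as stated above:

```python
import torch
from graphium.trainer.predictor_options import EvalOptions

loss_a = EvalOptions.parse_loss_fun("mse")               # string form
loss_b = EvalOptions.parse_loss_fun({"name": "mse"})     # dict form; extra keys are
                                                         # passed to the loss object
loss_c = EvalOptions.parse_loss_fun(torch.nn.MSELoss())  # callable form
```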
FlagOptions
dataclass
¶
This data class stores the arguments necessary to configure the FLAG adversarial augmentation for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`flag_kwargs` | `Dict[str, Any]` | Keyword arguments used for FLAG, an adversarial data augmentation method for graph networks. See: https://arxiv.org/abs/2010.09891 | `None` |
ModelOptions
dataclass
¶
This data class stores the arguments necessary to instantiate a model for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`model_class` | `Type[Module]` | PyTorch module used to create a model | *required* |
`model_kwargs` | `Dict[str, Any]` | Keyword arguments used to initialize the model from `model_class` | *required* |
OptimOptions
dataclass
¶
This data class stores the arguments necessary to configure the optimizer for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`optim_kwargs` | `Optional[Dict[str, Any]]` | Dictionary used to initialize the optimizer | `None` |
`torch_scheduler_kwargs` | `Optional[Dict[str, Any]]` | Dictionary for the scheduling of the learning rate with a torch scheduler | `None` |
`scheduler_kwargs` | `Optional[Dict[str, Any]]` | Dictionary for the scheduling of the learning rate modification used by pytorch-lightning | `None` |
`scheduler_class` | `Optional[Union[str, Type]]` | The class to use for the scheduler, or the str representing the scheduler | `None` |
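An illustrative construction; the exact keys accepted inside each dictionary are defined by `OptimOptions` and are not fully reproduced on this page, so `lr` below is only a typical optimizer key:

```python
from graphium.trainer.predictor_options import OptimOptions

optim_options = OptimOptions(
    optim_kwargs={"lr": 1e-3},    # assumed optimizer key (learning rate)
    torch_scheduler_kwargs=None,  # fall back to the scheduler defaults
    scheduler_kwargs=None,        # pytorch-lightning scheduler config
    scheduler_class=None,         # or a torch.optim.lr_scheduler class / its name
)
```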