graphium.trainer¶
Code for training models
Predictor¶
graphium.trainer.predictor
¶
PredictorModule
¶
Bases: lightning.LightningModule
__init__(model_class, model_kwargs, loss_fun, random_seed=42, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None)
¶
The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_class | Type[nn.Module] | The torch Module containing the main forward function | required |
model_kwargs | Dict[str, Any] | The arguments to initialize the model from | required |
loss_fun | Dict[str, Union[str, Callable]] | A | required |
random_seed | int | The seed for random initialization | 42 |
optim_kwargs | Optional[Dict[str, Any]] | The optimization arguments. See class | None |
torch_scheduler_kwargs | Optional[Dict[str, Any]] | The torch scheduler arguments. See class | None |
scheduler_kwargs | Optional[Dict[str, Any]] | The lightning scheduler arguments. See class | None |
target_nan_mask | Optional[Union[str, int]] | How to handle the NaNs. See | None |
metrics | Dict[str, Callable] | A | None |
metrics_on_progress_bar | Dict[str, List[str]] | A | [] |
metrics_on_training_set | Optional[Dict[str, List[str]]] | A | None |
flag_kwargs | Dict[str, Any] | Arguments related to using the FLAG adversarial augmentation | None |
task_norms | Optional[Dict[Callable, Any]] | the normalization for each task | None |
__repr__()
¶
Controls how the class is printed
compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None)
staticmethod
¶
Compute the loss using the specified loss function, handling the NaNs in the targets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds | Dict[str, Tensor] | Predicted values | required |
targets | Dict[str, Tensor] | Target values | required |
weights | Optional[Tensor] | No longer supported, will raise an error. | required |
target_nan_mask | Optional[Union[str, int]] | | None |
multitask_handling | Optional[str] | | None |
loss_fun | Dict[str, Callable] | Loss function to use | required |
Returns:
Name | Type | Description |
---|---|---|
Tensor | Tuple[Tensor, Dict[str, Tensor]] | weighted_loss: the resulting weighted loss; all_task_losses: the loss per task |
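The per-task structure of this return value can be illustrated with a minimal, library-free sketch. It is not the graphium implementation: plain Python lists stand in for tensors, NaN targets are simply dropped, and the combined loss is assumed to be an unweighted mean over tasks. The names `mse` and `compute_loss_sketch` are hypothetical.

```python
import math
from typing import Callable, Dict, List, Tuple


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error over two equally sized lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def compute_loss_sketch(
    preds: Dict[str, List[float]],
    targets: Dict[str, List[float]],
    loss_fun: Dict[str, Callable[[List[float], List[float]], float]],
) -> Tuple[float, Dict[str, float]]:
    """Return (combined_loss, per_task_losses), skipping NaN targets."""
    all_task_losses: Dict[str, float] = {}
    for task, task_loss in loss_fun.items():
        # Keep only the (prediction, target) pairs whose target is not NaN.
        kept = [(p, t) for p, t in zip(preds[task], targets[task]) if not math.isnan(t)]
        all_task_losses[task] = task_loss([p for p, _ in kept], [t for _, t in kept])
    # Assumption: tasks are combined with an unweighted mean.
    combined = sum(all_task_losses.values()) / len(all_task_losses)
    return combined, all_task_losses


example_preds = {"task_a": [1.0, 2.0, 3.0], "task_b": [0.0, 1.0]}
example_targets = {"task_a": [1.0, float("nan"), 5.0], "task_b": [0.0, 3.0]}
combined, per_task = compute_loss_sketch(
    example_preds, example_targets, {"task_a": mse, "task_b": mse}
)
```

A task whose targets are all NaN would divide by zero in this sketch; a real implementation needs an explicit policy for that case.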
flag_step(batch, step_name, to_cpu)
¶
Perform adversarial data augmentation during one training step using FLAG. Paper: https://arxiv.org/abs/2010.09891 GitHub: https://github.com/devnkong/FLAG
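As a rough intuition for what a FLAG-style step does, here is a toy scalar sketch. It is not the code of `flag_step`: the model is a single weight `w` with prediction `w * (x + delta)`, gradients are written out by hand, and the perturbation update uses the sign of the gradient. The function name and all hyperparameter names (`n_steps`, `alpha`, `lr`) are assumptions made for illustration.

```python
import random


def sign(value: float) -> int:
    """Return -1, 0 or 1 depending on the sign of value."""
    return (value > 0) - (value < 0)


def flag_step_sketch(w, x, y, n_steps=3, alpha=0.01, lr=0.1, rng=None):
    """One toy training step with a FLAG-style input perturbation.

    Loss: (w * (x + delta) - y) ** 2, with delta the perturbation on x.
    """
    rng = rng or random.Random(0)
    # Start from a small random perturbation of the input.
    delta = rng.uniform(-alpha, alpha)
    grad_w_accumulated = 0.0
    for _ in range(n_steps):
        error = w * (x + delta) - y
        # Accumulate the averaged weight gradient over the ascent steps.
        grad_w_accumulated += 2.0 * error * (x + delta) / n_steps
        # Gradient ascent on the perturbation, to make the input harder.
        grad_delta = 2.0 * error * w
        delta += alpha * sign(grad_delta)
    # A single descent step on the weight using the accumulated gradient.
    return w - lr * grad_w_accumulated


new_w = flag_step_sketch(w=0.0, x=1.0, y=2.0, rng=random.Random(0))
```

The point of the sketch is the loop structure: the perturbation is refined for several ascent steps while the weight gradients are accumulated, and the weights are updated only once at the end.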
forward(inputs)
¶
Returns the result of self.model.forward(*inputs) on the inputs. If the output out = self.model.forward(*inputs) is a dictionary with a "preds" key, it is returned directly. Otherwise, a new dictionary {"preds": out} is created and returned.
Returns:
Type | Description |
---|---|
Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]] | A dict with a key |
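The wrapping rule described for `forward` can be sketched in a few lines of plain Python. The helper name `wrap_preds` is hypothetical and only mirrors the behaviour stated above.

```python
from typing import Any, Dict


def wrap_preds(out: Any) -> Dict[str, Any]:
    """Return out unchanged if it is a dict with a "preds" key, else wrap it."""
    if isinstance(out, dict) and "preds" in out:
        return out
    return {"preds": out}


already_wrapped = {"preds": [0.1, 0.9], "extra": "kept as-is"}
raw_output = [0.1, 0.9]
```

Here `wrap_preds(already_wrapped)` returns the same dictionary object, while `wrap_preds(raw_output)` returns `{"preds": [0.1, 0.9]}`.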
get_num_graphs(data)
¶
Method to compute number of graphs in a Batch. Essential to estimate throughput in graphs/s.
list_pretrained_models()
staticmethod
¶
List available pretrained models.
load_pretrained_models(name)
staticmethod
¶
Load a pretrained model from its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Name of the model to load. List available from | required |
Metrics¶
graphium.trainer.metrics
¶
MetricWrapper
¶
Initializes a metric from a name or a Callable, and initializes the Thresholder if the metric requires a threshold.
__call__(preds, target)
¶
Compute the metric with the method self.compute
__getstate__()
¶
Serialize the class for pickling.
__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)
¶
Parameters
metric:
The metric to use. See METRICS_DICT
threshold_kwargs:
If `None`, no threshold is applied.
Otherwise, the class `Thresholder` is initialized with the
provided arguments, and called before `compute`
target_nan_mask:
- None: Do not change behaviour if there are NaNs
- int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
all NaNs will be replaced by zeros
- 'ignore': The NaN values will be removed from the tensor before computing the metrics.
Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.
multitask_handling:
- None: Do not process the tensor before passing it to the metric.
Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
Use either 'flatten' or 'mean-per-label'.
- 'flatten': Flatten the tensor to produce the equivalent of a single task
- 'mean-per-label': Loop all the labels columns, process them as a single task,
and average the results over each task
*This option might slow down the computation if there are too many labels*
squeeze_targets:
If true, targets will be squeezed prior to computing the metric.
Required in classifigression task.
target_to_int:
If true, targets will be converted to integers prior to computing the metric.
kwargs:
Other arguments to call with the metric
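The interaction between `target_nan_mask` and `multitask_handling` can be made concrete with a library-free sketch. It only mirrors the options listed above and is not the `MetricWrapper` code: nested lists (rows are samples, columns are labels) stand in for tensors, and `mae` and `compute_metric_sketch` are hypothetical names.

```python
import math
from typing import Callable, List, Optional, Union

Matrix = List[List[float]]  # rows = samples, columns = labels


def mae(preds: List[float], targets: List[float]) -> float:
    """Mean absolute error over two equally sized lists."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def compute_metric_sketch(
    metric: Callable,
    preds: Matrix,
    targets: Matrix,
    target_nan_mask: Optional[Union[str, int, float]] = None,
    multitask_handling: Optional[str] = None,
):
    if target_nan_mask == "ignore" and multitask_handling is None:
        raise ValueError("'ignore' requires 'flatten' or 'mean-per-label'")
    if isinstance(target_nan_mask, (int, float)):
        # Replace every NaN target by the given value.
        targets = [[float(target_nan_mask) if math.isnan(t) else t for t in row] for row in targets]
    if multitask_handling is None:
        return metric(preds, targets)  # passed through unprocessed
    drop_nan = target_nan_mask == "ignore"

    def keep(p_vals, t_vals):
        # Optionally drop the pairs whose target is NaN.
        kept = [(p, t) for p, t in zip(p_vals, t_vals) if not (drop_nan and math.isnan(t))]
        return [p for p, _ in kept], [t for _, t in kept]

    if multitask_handling == "flatten":
        flat_p = [p for row in preds for p in row]
        flat_t = [t for row in targets for t in row]
        return metric(*keep(flat_p, flat_t))
    if multitask_handling == "mean-per-label":
        per_label = []
        for j in range(len(targets[0])):
            per_label.append(metric(*keep([r[j] for r in preds], [r[j] for r in targets])))
        return sum(per_label) / len(per_label)
    raise ValueError(f"Unknown multitask_handling: {multitask_handling!r}")


example_preds = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
example_targets = [[1.0, float("nan")], [2.0, 4.0], [5.0, 8.0]]
flat = compute_metric_sketch(mae, example_preds, example_targets, "ignore", "flatten")
per_label = compute_metric_sketch(mae, example_preds, example_targets, "ignore", "mean-per-label")
```

With one NaN target, `'flatten'` averages over the five remaining values, whereas `'mean-per-label'` first averages each column and then averages the two column scores, so the two options generally give different numbers.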
__repr__()
¶
Control how the class is printed
__setstate__(state)
¶
Reload the class from pickling.
compute(preds, target)
¶
Compute the metric, apply the thresholder if provided, and manage the NaNs
Predictor Summaries¶
graphium.trainer.predictor_summaries
¶
Classes to store information about resulting evaluation metrics when using a Predictor Module.
Summary
¶
Bases: SummaryInterface
Results
¶
__init__(targets=None, predictions=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)
¶
This inner class is used as a container for storing the results of the summary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
targets | Tensor | the targets | None |
predictions | Tensor | the prediction tensor | None |
loss | float | the loss, float or tensor | None |
metrics | dict | the metrics | None |
monitored_metric | str | the monitored metric | None |
n_epochs | int | the number of epochs | None |
__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)
¶
A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fun | Union[str, Callable] | | required |
metrics | Dict[str, Callable] | | required |
metrics_on_training_set | List[str] | | [] |
metrics_on_progress_bar | List[str] | | [] |
monitor | str | | 'loss' |
task_name | Optional[str] | | None |
get_best_results(step_name)
¶
Retrieve the best results for a given step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | which stage you are in, e.g. "train" | required |
Returns:
Type | Description |
---|---|
the best results for the given step |
get_dict_summary()
¶
Retrieve the full summary as a dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Any] | the full summary in a dictionary |
get_metrics_logs()
¶
Get the data about metrics to log. Note: This function requires that self.update_predictor_state() be called before it.
Returns:
Type | Description |
---|---|
Dict[str, Any] | A dictionary of metrics to log. |
get_results(step_name)
¶
Retrieve the results for a given step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | which stage you are in, e.g. "train" | required |
Returns:
Type | Description |
---|---|
the results for the given step |
get_results_on_progress_bar(step_name)
¶
Retrieve the results to be displayed on the progress bar for a given step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | which stage you are in, e.g. "train" | required |
Returns:
Type | Description |
---|---|
Dict[str, Tensor] | the results to be displayed on the progress bar for the given step |
is_best_epoch(step_name, loss, metrics)
¶
Check whether the current epoch is the best epoch, based on the self.mode criterion.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | which stage you are in, e.g. "train" | required |
loss | Tensor | the loss tensor | required |
metrics | Dict[str, Tensor] | a dictionary of metrics | required |
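A comparison driven by a `mode` of `'min'` or `'max'` might look like the sketch below. This is an assumption about the general pattern, not the body of `is_best_epoch`; the helper name `is_better` and the treatment of a missing previous best are hypothetical.

```python
from typing import Optional


def is_better(current: float, best: Optional[float], mode: str = "min") -> bool:
    """Return True if current improves on best under the given mode."""
    if best is None:
        # Assumption: the first recorded value always counts as the best.
        return True
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
```

With `mode='min'` (the default shown in the signatures here, together with `monitor='loss'`), a lower monitored value counts as an improvement; with `mode='max'`, a higher one does.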
set_results(metrics)
¶
Set the results from the metrics. Note: This function requires that self.update_predictor_state() be called before it.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metrics | Dict[str, Tensor] | a dictionary of metrics | required |
update_predictor_state(step_name, targets, predictions, loss, n_epochs)
¶
Update the state of the predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | which stage you are in, e.g. "train" | required |
targets | Tensor | the targets tensor | required |
predictions | Tensor | the predictions tensor | required |
loss | Tensor | the loss tensor | required |
n_epochs | int | the number of epochs | required |
SummaryInterface
¶
Bases: object
An interface to define the functions implemented by summary classes that implement SummaryInterface.
TaskSummaries
¶
Bases: SummaryInterface
__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')
¶
Class to store the summaries of the tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task_loss_fun | Callable | the loss function for each task | required |
task_metrics | Dict[str, Callable] | the metrics for each task | required |
task_metrics_on_training_set | List[str] | the metrics to use on the training set | required |
task_metrics_on_progress_bar | List[str] | the metrics to use on the progress bar | required |
monitor | str | the metric to monitor | 'loss' |
mode | str | the mode of the metric to monitor | 'min' |
concatenate_metrics_logs(metrics_logs)
¶
Concatenate the metrics logs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
metrics_logs | Dict[str, Dict[str, Tensor]] | the metrics logs | required |
Returns:
Type | Description |
---|---|
Dict[str, Tensor] | the concatenated metrics logs |
get_best_results(step_name)
¶
Retrieve the best results.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | the name of the step, e.g. "train" | required |
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Any]] | the best results |
get_dict_summary()
¶
Get the task summaries as a dictionary.
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Any]] | the task summaries |
get_metrics_logs()
¶
Get the logs for the metrics.
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Tensor]] | the task logs for the metrics |
get_results(step_name)
¶
Retrieve the results.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | the name of the step, e.g. "train" | required |
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Any]] | the results |
get_results_on_progress_bar(step_name)
¶
Return all results from all tasks for the progress bar. The dictionaries are combined: instead of having keys as task names, all the task-specific dictionaries are merged.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | the name of the step | required |
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Any]] | the results for the progress bar |
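Merging per-task dictionaries into a single flat dictionary can be sketched as follows. The helper name `merge_task_results` and the `task/metric` key format are hypothetical; the actual key naming used by `TaskSummaries` is not specified in this section.

```python
from typing import Any, Dict


def merge_task_results(task_results: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten {task: {metric: value}} into a single {key: value} dictionary."""
    merged: Dict[str, Any] = {}
    for task_name, results in task_results.items():
        for metric_name, value in results.items():
            # Assumption: prefixing with the task name keeps the keys unique.
            merged[f"{task_name}/{metric_name}"] = value
    return merged


flat_results = merge_task_results({"task_a": {"mae": 0.1}, "task_b": {"mae": 0.2}})
```

Without some task-specific component in the key, two tasks reporting the same metric name would overwrite each other in the merged dictionary.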
metric_log_name(task_name, metric_name, step_name)
¶
print the metric name, task name and step name
Returns:
Type | Description |
---|---|
str | the metric name, task name and step name |
set_results(task_metrics)
¶
Set the results for all tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
task_metrics | Dict[str, Dict[str, Tensor]] | the metrics for each task | required |
update_predictor_state(step_name, targets, predictions, loss, task_losses, n_epochs)
¶
Update the state for all predictors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name | str | the name of the step | required |
targets | Dict[str, Tensor] | the target tensors | required |
predictions | Dict[str, Tensor] | the prediction tensors | required |
loss | Tensor | the loss tensor | required |
task_losses | Dict[str, Tensor] | the task losses | required |
n_epochs | int | the number of epochs | required |
Predictor Options¶
graphium.trainer.predictor_options
¶
Data classes to group together related arguments for the creation of a Predictor Module.
EvalOptions
dataclass
¶
This data class stores the arguments that configure evaluation (the loss function and the metrics) for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fun | Union[str, Dict, Callable] | Loss function used during training. Acceptable strings are graphium.utils.spaces.LOSS_DICT.keys(). If a dict, it must contain a 'name' key with one of the acceptable loss function strings as a value; the rest of the dict is used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method | required |
metrics | Dict[str, Callable] | A dictionary of metrics to compute on the predictions, other than the loss function. These metrics will be logged to WandB or another logger. | None |
metrics_on_progress_bar | List[str] | The metrics names from | field(default_factory=List[str]) |
metrics_on_training_set | Optional[List[str]] | The metrics names from | None |
check_metrics_validity()
¶
Check that the metrics for the progress_bar and training_set are valid
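One plausible form of such a validity check is to require that every requested name is a key of the `metrics` dictionary. The sketch below assumes exactly that rule; the helper name `check_metric_names` is hypothetical and the real method may accept additional names.

```python
from typing import Callable, Dict, List, Optional


def check_metric_names(
    metrics: Dict[str, Callable],
    metrics_on_progress_bar: List[str],
    metrics_on_training_set: Optional[List[str]] = None,
) -> None:
    """Raise ValueError if a requested metric name is not defined in metrics."""
    available = set(metrics)
    requested = {
        "metrics_on_progress_bar": metrics_on_progress_bar,
        "metrics_on_training_set": metrics_on_training_set or [],
    }
    for field_name, names in requested.items():
        unknown = [name for name in names if name not in available]
        if unknown:
            raise ValueError(f"{field_name} contains unknown metrics: {unknown}")


example_metrics = {"mae": abs, "r2": abs}  # placeholder callables
check_metric_names(example_metrics, ["mae"], ["mae", "r2"])  # passes silently
```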
parse_loss_fun(loss_fun)
staticmethod
¶
Parse the loss function from a string or a dict
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fun | Union[str, Dict, Callable] | A callable corresponding to the loss function, a string specifying the loss function from | required |
Returns:
Name | Type | Description |
---|---|---|
Callable | Callable | Function or callable to compute the loss, takes |
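The three accepted forms of `loss_fun` (string, dict with a 'name' key, or callable) can be illustrated with a small sketch. It does not use graphium's `LOSS_DICT`: the registry `HYPOTHETICAL_LOSS_DICT`, the factory `make_mse`, its `scale` argument, and the function name `parse_loss_fun_sketch` are all assumptions made for the example.

```python
from typing import Callable, Dict, List, Union


def make_mse(scale: float = 1.0) -> Callable[[List[float], List[float]], float]:
    """Build a (scaled) mean squared error function."""
    def mse(preds: List[float], targets: List[float]) -> float:
        return scale * sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
    return mse


# Hypothetical registry standing in for the library's loss dictionary.
HYPOTHETICAL_LOSS_DICT: Dict[str, Callable[..., Callable]] = {"mse": make_mse}


def parse_loss_fun_sketch(loss_fun: Union[str, Dict, Callable]) -> Callable:
    if isinstance(loss_fun, str):
        # A plain string selects a loss with its default arguments.
        return HYPOTHETICAL_LOSS_DICT[loss_fun]()
    if isinstance(loss_fun, dict):
        # The 'name' key selects the loss; the remaining keys are its arguments.
        kwargs = dict(loss_fun)
        name = kwargs.pop("name")
        return HYPOTHETICAL_LOSS_DICT[name](**kwargs)
    if callable(loss_fun):
        return loss_fun
    raise TypeError(f"Unsupported loss_fun type: {type(loss_fun).__name__}")


from_string = parse_loss_fun_sketch("mse")
from_dict = parse_loss_fun_sketch({"name": "mse", "scale": 0.5})
```

Copying the dict before popping 'name' keeps the caller's configuration untouched, which matters when the same config object is reused.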
FlagOptions
dataclass
¶
This data class stores the arguments that configure the FLAG adversarial data augmentation for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flag_kwargs | Dict[str, Any] | Keyword arguments used for FLAG, an adversarial data augmentation for graph networks. See: https://arxiv.org/abs/2010.09891 | None |
ModelOptions
dataclass
¶
This data class stores the arguments necessary to instantiate a model for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_class | Type[nn.Module] | pytorch module used to create a model | required |
model_kwargs | Dict[str, Any] | Key-word arguments used to initialize the model from | required |
OptimOptions
dataclass
¶
This data class stores the arguments necessary to configure the optimizer for the Predictor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optim_kwargs | Optional[Dict[str, Any]] | Dictionary used to initialize the optimizer, with possible keys below. | None |
torch_scheduler_kwargs | Optional[Dict[str, Any]] | Dictionary for the scheduling of learning rate, with possible keys below. | None |
scheduler_kwargs | Optional[Dict[str, Any]] | Dictionary for the scheduling of the learning rate modification used by pytorch-lightning | None |
scheduler_class | Optional[Union[str, Type]] | The class to use for the scheduler, or the str representing the scheduler. | None |