graphium.trainer¶
Code for training models
Predictor¶
graphium.trainer.predictor¶
PredictorModule¶
Bases: LightningModule
__init__(model_class, model_kwargs, loss_fun, task_levels, random_seed=42, featurization=None, optim_kwargs=None, torch_scheduler_kwargs=None, scheduler_kwargs=None, target_nan_mask=None, multitask_handling=None, metrics=None, metrics_on_progress_bar=[], metrics_on_training_set=None, flag_kwargs=None, task_norms=None, metrics_every_n_train_steps=None)¶
The Lightning module responsible for handling the predictions, losses, metrics, optimization, etc. It works in a multi-task setting, with different losses and metrics per task. A minimal construction sketch follows the parameter table below.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_class` | `Type[Module]` | The torch Module containing the main forward function | required |
| `model_kwargs` | `Dict[str, Any]` | The arguments used to initialize the model from `model_class` | required |
| `loss_fun` | `Dict[str, Union[str, Callable]]` | A dict mapping each task name to its loss function, given either as a string (see `EvalOptions` for the acceptable strings) or as a callable | required |
| `random_seed` | `int` | The seed for random initialization | `42` |
| `optim_kwargs` | `Optional[Dict[str, Any]]` | The optimization arguments. See class `OptimOptions` | `None` |
| `torch_scheduler_kwargs` | `Optional[Dict[str, Any]]` | The torch scheduler arguments. See class `OptimOptions` | `None` |
| `scheduler_kwargs` | `Optional[Dict[str, Any]]` | The lightning scheduler arguments. See class `OptimOptions` | `None` |
| `target_nan_mask` | `Optional[Union[str, int]]` | How to handle the NaNs. See `MetricWrapper` for the accepted options | `None` |
| `metrics` | `Dict[str, Callable]` | A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or similar | `None` |
| `metrics_on_progress_bar` | `Dict[str, List[str]]` | A dict of the metric names from `metrics` to also display on the training progress bar, per task | `[]` |
| `metrics_on_training_set` | `Optional[Dict[str, List[str]]]` | A dict of the metric names from `metrics` to compute on the training set for each iteration, per task; if `None`, all metrics are computed | `None` |
| `flag_kwargs` | `Dict[str, Any]` | Arguments related to using the FLAG adversarial augmentation | `None` |
| `task_norms` | `Optional[Dict[Callable, Any]]` | The normalization for each task | `None` |
| `metrics_every_n_train_steps` | `Optional[int]` | Compute and log metrics every n training steps; set to `None` to disable | `None` |
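To make the parameters concrete, here is a minimal construction sketch. The tiny stand-in model, the task name `my_task`, the `"graph"` task level, and the `lr` optimizer key are illustrative assumptions, not values confirmed by this page:

```python
import torch
from graphium.trainer.predictor import PredictorModule

class TinyModel(torch.nn.Module):
    """Stand-in network; any torch Module with a forward function works."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.lin(x)

# Hypothetical single-task setup; the dictionary keys are assumptions
# (see OptimOptions and MetricWrapper below for the documented options).
predictor = PredictorModule(
    model_class=TinyModel,
    model_kwargs={"in_dim": 32, "out_dim": 1},
    loss_fun={"my_task": "mse"},        # one loss per task
    task_levels={"my_task": "graph"},   # assumed task-level name
    optim_kwargs={"lr": 1e-3},          # assumed optimizer key
    target_nan_mask="ignore",
    multitask_handling="flatten",
)
```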
__repr__()¶
Controls how the class is printed
compute_loss(preds, targets, weights, loss_fun, target_nan_mask=None, multitask_handling=None) staticmethod¶
Compute the loss using the specified loss function, while handling the NaNs in the targets.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `preds` | `Dict[str, Tensor]` | Predicted values | required |
| `targets` | `Dict[str, Tensor]` | Target values | required |
| `weights` | `Optional[Tensor]` | No longer supported, will raise an error | required |
| `target_nan_mask` | `Optional[Union[str, int]]` | How to handle the NaNs. See `MetricWrapper` for the accepted options | `None` |
| `multitask_handling` | `Optional[str]` | How to handle the multi-task tensors. See `MetricWrapper` for the accepted options | `None` |
| `loss_fun` | `Dict[str, Callable]` | Loss function to use for each task | required |

Returns:

| Type | Description |
|---|---|
| `Tuple[Tensor, Dict[str, Tensor]]` | `weighted_loss`: the resulting weighted loss; `all_task_losses`: the loss per task |
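As a sketch of calling `compute_loss` directly, assuming a single task named `my_task` (a placeholder) and a torch loss module as the callable:

```python
import torch
from graphium.trainer.predictor import PredictorModule

# Illustrative tensors for one hypothetical task; shapes are arbitrary.
preds = {"my_task": torch.randn(8, 2)}
targets = {"my_task": torch.randn(8, 2)}
targets["my_task"][0, 0] = float("nan")  # a NaN to be masked out

weighted_loss, all_task_losses = PredictorModule.compute_loss(
    preds=preds,
    targets=targets,
    weights=None,                # weights are no longer supported
    loss_fun={"my_task": torch.nn.MSELoss()},
    target_nan_mask="ignore",    # drop NaN targets before computing the loss
    multitask_handling="flatten",
)
```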
flag_step(batch, step_name, to_cpu)¶
Perform adversarial data augmentation during one training step using FLAG.
Paper: https://arxiv.org/abs/2010.09891
GitHub: https://github.com/devnkong/FLAG
forward(inputs)¶
Returns the result of `self.model.forward(*inputs)` on the inputs. If the output `out = self.model.forward(...)` is a dictionary with a `"preds"` key, it is returned directly. Otherwise, a new dictionary `{"preds": out}` is returned.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Union[Tensor, Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]]` | A dict with a key `"preds"` containing the model output |
get_num_graphs(data)¶
Method to compute the number of graphs in a Batch. Essential to estimate throughput in graphs/s.
list_pretrained_models() staticmethod¶
List available pretrained models.
load_pretrained_model(name_or_path, device=None) staticmethod¶
Load a pretrained model from its name or a valid checkpoint path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name_or_path` |  | Name of the model to load, or a valid checkpoint path. List the available models with `list_pretrained_models()` | required |
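A usage sketch; the model name below is a placeholder, and the real names come from `list_pretrained_models()`:

```python
from graphium.trainer.predictor import PredictorModule

# Discover the available pretrained models, then load one by name
# or from a checkpoint path.
print(PredictorModule.list_pretrained_models())

# "some-model-name" is a placeholder, not a real registered model.
predictor = PredictorModule.load_pretrained_model("some-model-name", device="cpu")
```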
Metrics¶
graphium.trainer.metrics¶
MetricWrapper¶
Allows initializing a metric from a name or `Callable`, and initializes the `Thresholder` in case the metric requires a threshold.
__call__(preds, target)¶
Compute the metric with the method `self.compute`.
__getstate__()¶
Serialize the class for pickling.
__init__(metric, threshold_kwargs=None, target_nan_mask=None, multitask_handling=None, squeeze_targets=False, target_to_int=False, **kwargs)¶
Parameters:
metric:
The metric to use. See METRICS_DICT
threshold_kwargs:
If `None`, no threshold is applied.
Otherwise, the class `Thresholder` is initialized with the
provided arguments and applied before calling `compute`.
target_nan_mask:
- None: Do not change behaviour if there are NaNs
- int, float: Value used to replace NaNs. For example, if `target_nan_mask==0`, then
all NaNs will be replaced by zeros
- 'ignore': The NaN values will be removed from the tensor before computing the metrics.
Must be coupled with the `multitask_handling='flatten'` or `multitask_handling='mean-per-label'`.
multitask_handling:
- None: Do not process the tensor before passing it to the metric.
Cannot use the option `multitask_handling=None` when `target_nan_mask=ignore`.
Use either 'flatten' or 'mean-per-label'.
- 'flatten': Flatten the tensor to produce the equivalent of a single task
- 'mean-per-label': Loop all the labels columns, process them as a single task,
and average the results over each task
*This option might slow down the computation if there are too many labels*
squeeze_targets:
If true, targets will be squeezed prior to computing the metric.
Required in classifigression task.
target_to_int:
If true, targets will be converted to integers prior to computing the metric.
kwargs:
Other arguments to call with the metric
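A minimal sketch of wrapping a metric with NaN handling; the `"mae"` name is assumed to be a valid key of `METRICS_DICT`:

```python
import torch
from graphium.trainer.metrics import MetricWrapper

# Wrap a metric so that NaN targets are removed before computation.
metric = MetricWrapper(
    metric="mae",                  # assumed to be a key of METRICS_DICT
    target_nan_mask="ignore",      # remove the NaNs from the tensors
    multitask_handling="flatten",  # required when target_nan_mask="ignore"
)

preds = torch.randn(16, 3)
target = torch.randn(16, 3)
target[0, 0] = float("nan")    # this entry will be ignored
score = metric(preds, target)  # __call__ delegates to self.compute
```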
__repr__()¶
Controls how the class is printed
__setstate__(state)¶
Reload the class from pickling.
compute(preds, target)¶
Compute the metric, apply the thresholder if provided, and manage the NaNs.
Predictor Summaries¶
graphium.trainer.predictor_summaries¶
Classes to store information about resulting evaluation metrics when using a Predictor Module.
Summary¶
Bases: SummaryInterface
Results¶
__init__(targets=None, preds=None, loss=None, metrics=None, monitored_metric=None, n_epochs=None)¶
This inner class is used as a container for storing the results of the summary.
Parameters:
targets: the targets
preds: the prediction tensor
loss: the loss, float or tensor
metrics: the metrics
monitored_metric: the monitored metric
n_epochs: the number of epochs
__init__(loss_fun, metrics, metrics_on_training_set=[], metrics_on_progress_bar=[], monitor='loss', mode='min', task_name=None)¶
A container to be used by the Predictor Module that stores the results for the given metrics on the predictions and targets provided.
Parameters:
loss_fun:
Loss function used during training. Acceptable strings are 'mse', 'bce', 'mae', 'cosine'.
Otherwise, a callable object must be provided, with a method loss_fun._get_name()
.
metrics:
A dictionary of metrics to compute on the prediction, other than the loss function.
These metrics will be logged into WandB or similar.
metrics_on_training_set:
The metric names from `metrics` to be computed on the training set for each iteration.
If `None`, all the metrics are computed. Using fewer metrics can significantly improve
performance, depending on the number of readouts.
metrics_on_progress_bar:
The metric names from `metrics` to also display on the training progress bar
monitor:
`str` metric to track (Default=`"loss/val"`)
task_name:
name of the task (Default=`None`)
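A construction sketch, assuming `"mae"` is a valid `MetricWrapper` name and that the monitor string follows the `"loss/val"` pattern shown above:

```python
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor_summaries import Summary

# Hypothetical single-task summary; the metric name "mae" is an assumption.
summary = Summary(
    loss_fun="mse",  # one of the acceptable strings listed above
    metrics={"mae": MetricWrapper(metric="mae")},
    metrics_on_training_set=["mae"],
    metrics_on_progress_bar=["mae"],
    monitor="loss/val",
    mode="min",
    task_name="my_task",
)
```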
get_best_results(step_name)¶
Retrieve the best results for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
the best results for the given step
get_dict_summary()¶
Retrieve the full summary in a dictionary.
Returns:
the full summary in a dictionary
get_metrics_logs()¶
Get the data about metrics to log.
Note: This function requires that self.update_predictor_state() be called before it.
Returns:
A dictionary of metrics to log.
get_results(step_name)¶
Retrieve the results for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
the results for the given step
get_results_on_progress_bar(step_name)¶
Retrieve the results to be displayed on the progress bar for a given step.
Parameters:
step_name: which stage you are in, e.g. "train"
Returns:
the results to be displayed on the progress bar for the given step
is_best_epoch(step_name, loss, metrics)¶
Check whether the current epoch is the best epoch, based on the self.mode criterion.
Parameters:
step_name: which stage you are in, e.g. "train"
loss: the loss tensor
metrics: a dictionary of metrics
set_results(metrics)¶
Set the results from the metrics.
Note: This function requires that self.update_predictor_state() be called before it.
Parameters:
metrics: a dictionary of metrics
update_predictor_state(step_name, targets, preds, loss, n_epochs)¶
Update the state of the predictor.
Parameters:
step_name: which stage you are in, e.g. "train"
targets: the targets tensor
preds: the predictions tensor
loss: the loss tensor
n_epochs: the number of epochs
SummaryInterface¶
Bases: object
An interface defining the functions to be implemented by summary classes.
TaskSummaries¶
Bases: SummaryInterface
__init__(task_loss_fun, task_metrics, task_metrics_on_training_set, task_metrics_on_progress_bar, monitor='loss', mode='min')¶
A class to store the summaries of the tasks.
Parameters:
task_loss_fun: the loss function for each task
task_metrics: the metrics for each task
task_metrics_on_training_set: the metrics to use on the training set
task_metrics_on_progress_bar: the metrics to use on the progress bar
monitor: the metric to monitor
mode: the mode of the metric to monitor
concatenate_metrics_logs(metrics_logs)¶
Concatenate the metrics logs.
Parameters:
metrics_logs: the metrics logs
Returns:
the concatenated metrics logs
get_best_results(step_name)¶
Retrieve the best results.
Parameters:
step_name: the name of the step, e.g. "train"
Returns:
the best results
get_dict_summary()¶
Get the task summaries in a dictionary.
Returns:
the task summaries
get_metrics_logs()¶
Get the logs for the metrics.
Returns:
the task logs for the metrics
get_results(step_name)¶
Retrieve the results.
Parameters:
step_name: the name of the step, e.g. "train"
Returns:
the results
get_results_on_progress_bar(step_name)¶
Return all results from all tasks for the progress bar. Combines the dictionaries: instead of having keys as task names, all the task-specific dictionaries are merged.
Parameters:
step_name: the name of the step
Returns:
the results for the progress bar
metric_log_name(task_name, metric_name, step_name)¶
Build the log name combining the task name, metric name and step name.
Returns:
the combined metric, task and step name
set_results(task_metrics)¶
Set the results for all tasks.
Parameters:
task_metrics: the metrics for each task
update_predictor_state(step_name, targets, preds, loss, task_losses, n_epochs)¶
Update the state for all predictors.
Parameters:
step_name: the name of the step
targets: the target tensors
preds: the prediction tensors
loss: the loss tensor
task_losses: the task losses
n_epochs: the number of epochs
Predictor Options¶
graphium.trainer.predictor_options¶
Data classes to group together related arguments for the creation of a Predictor Module.
EvalOptions dataclass¶
This data class stores the arguments necessary to configure the evaluation for the Predictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fun` | `Union[str, Dict, Callable]` | Loss function used during training. Acceptable strings are `graphium.utils.spaces.LOSS_DICT.keys()`. If a dict, it must contain a 'name' key with one of the acceptable loss function strings as a value; the rest of the dict is used as the arguments passed to the loss object. Otherwise, a callable object must be provided, with a method `loss_fun._get_name()` | required |
| `metrics` | `Dict[str, Callable]` | A dictionary of metrics to compute on the prediction, other than the loss function. These metrics will be logged into WandB or other | `None` |
| `metrics_on_progress_bar` | `List[str]` | The metric names from `metrics` to also display on the training progress bar | `field(default_factory=List[str])` |
| `metrics_on_training_set` | `Optional[List[str]]` | The metric names from `metrics` to be computed on the training set for each iteration; if `None`, all metrics are computed | `None` |
check_metrics_validity()¶
Check that the metrics for the progress_bar and training_set are valid.
parse_loss_fun(loss_fun) staticmethod¶
Parse the loss function from a string or a dict.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_fun` | `Union[str, Dict, Callable]` | A callable corresponding to the loss function, a string specifying the loss function from `graphium.utils.spaces.LOSS_DICT`, or a dict with a 'name' key and the arguments to pass to the loss object | required |

Returns:

| Type | Description |
|---|---|
| `Callable` | Function or callable to compute the loss; takes the predictions and targets as inputs |
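A sketch of the three accepted forms, assuming `"mse"` is among `graphium.utils.spaces.LOSS_DICT.keys()`:

```python
import torch
from graphium.trainer.predictor_options import EvalOptions

# From a string key of LOSS_DICT ("mse" is assumed to be one of them).
loss_a = EvalOptions.parse_loss_fun("mse")

# From a dict: "name" selects the loss; the remaining entries are passed
# as arguments to the loss object ("reduction" assumes a torch-style loss).
loss_b = EvalOptions.parse_loss_fun({"name": "mse", "reduction": "sum"})

# From a callable exposing _get_name(), e.g. a torch loss module.
loss_c = EvalOptions.parse_loss_fun(torch.nn.MSELoss())
```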
FlagOptions dataclass¶
This data class stores the arguments necessary to configure FLAG for the Predictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `flag_kwargs` | `Dict[str, Any]` | Keyword arguments used for FLAG, an adversarial data augmentation method for graph networks. See: https://arxiv.org/abs/2010.09891 | `None` |
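As a rough illustration of the shape of `flag_kwargs`; the key names are assumptions based on the FLAG paper (number of ascent steps and step size) and are not confirmed by this page:

```python
# Hypothetical FLAG configuration; "n_steps" and "alpha" are assumed key
# names for the paper's ascent steps and perturbation step size.
flag_kwargs = {"n_steps": 3, "alpha": 0.01}
```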
ModelOptions dataclass¶
This data class stores the arguments necessary to instantiate a model for the Predictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_class` | `Type[Module]` | PyTorch module used to create a model | required |
| `model_kwargs` | `Dict[str, Any]` | Key-word arguments used to initialize the model from `model_class` | required |
OptimOptions dataclass¶
This data class stores the arguments necessary to configure the optimizer for the Predictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `optim_kwargs` | `Optional[Dict[str, Any]]` | Dictionary used to initialize the optimizer | `None` |
| `torch_scheduler_kwargs` | `Optional[Dict[str, Any]]` | Dictionary for the scheduling of the learning rate | `None` |
| `scheduler_kwargs` | `Optional[Dict[str, Any]]` | Dictionary for the scheduling of the learning rate modification used by pytorch-lightning | `None` |
| `scheduler_class` | `Optional[Union[str, Type]]` | The class to use for the scheduler, or the str representing the scheduler | `None` |
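A construction sketch; since the possible dictionary keys were not listed on this page, the keys below (`lr`, `weight_decay`, `monitor`) and the scheduler name are assumptions:

```python
from graphium.trainer.predictor_options import OptimOptions

# Hypothetical configuration; all dictionary keys below are assumptions.
optim_options = OptimOptions(
    optim_kwargs={"lr": 1e-3, "weight_decay": 1e-5},
    torch_scheduler_kwargs=None,               # keep the default scheduler
    scheduler_kwargs={"monitor": "loss/val"},  # pytorch-lightning options
    scheduler_class="ReduceLROnPlateau",       # str form of the scheduler
)
```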