graphium.config

Helper functions to load and convert config files.

graphium.config._loader
get_accelerator(config_acc)

Get the accelerator from the config file, and ensure that the options are consistent. For example, specifying cpu as the accelerator but gpus>0 as a Trainer option will yield an error.
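A minimal usage sketch. The shape of config_acc (a plain mapping with a type key) and the string return value are assumptions, not taken from this page.

```python
from graphium.config._loader import get_accelerator

# Hypothetical accelerator sub-config; the exact keys are an assumption.
config_acc = {"type": "cpu"}

# Assumed to return the resolved accelerator type as a string, e.g. "cpu".
accelerator_type = get_accelerator(config_acc)
```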
load_architecture(config, in_dims)

Loading the architecture used for training.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
in_dims | Dict[str, int] | Dictionary of the input dimensions for various | required |

Returns:

Name | Type | Description |
---|---|---|
architecture | Union[FullGraphMultiTaskNetwork, torch.nn.Module] | The neural network architecture used for training |
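A sketch of a typical call, assuming the config has already been loaded (e.g. with omegaconf); the file path and input-dimension keys below are hypothetical.

```python
from omegaconf import OmegaConf
from graphium.config._loader import load_architecture

config = OmegaConf.load("config.yaml")  # hypothetical config path

# Hypothetical input dimensions; the real keys depend on the featurization used.
in_dims = {"feat": 64, "edge_feat": 16}

architecture = load_architecture(config, in_dims=in_dims)
```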
load_datamodule(config, accelerator_type)

Load the datamodule from the specified configurations, at the key datamodule: args. If the accelerator is IPU, load the IPU options as well.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key datamodule: args | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |

Returns:

Name | Type | Description |
---|---|---|
datamodule | BaseDataModule | The datamodule used to process and load the data |
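A hedged sketch of loading the datamodule; BaseDataModule is assumed here to follow the standard pytorch-lightning DataModule interface, and the config path is hypothetical.

```python
from omegaconf import OmegaConf
from graphium.config._loader import load_datamodule

config = OmegaConf.load("config.yaml")  # hypothetical config path

datamodule = load_datamodule(config, accelerator_type="cpu")
datamodule.prepare_data()  # standard Lightning DataModule hook, assumed available
```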
load_metrics(config)

Loading the metrics to be tracked.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |

Returns:

Name | Type | Description |
---|---|---|
metrics | Dict[str, MetricWrapper] | A dictionary of all the metrics |
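A short sketch of inspecting the loaded metrics; it relies only on the return type documented above (a dictionary keyed by metric name).

```python
from graphium.config._loader import load_metrics

metrics = load_metrics(config)  # config loaded as in the earlier examples

# Dict[str, MetricWrapper]: iterate to see which metrics will be tracked.
for name, metric in metrics.items():
    print(name, metric)
```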
load_mup(mup_base_path, predictor)

Load the base shapes for mup, based either on a .ckpt or a .yaml file. If a .yaml file, it should be generated by mup.save_base_shapes.
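A hedged sketch of applying the base shapes to an existing predictor; the file name is hypothetical, and the assumption that the function returns the updated predictor is not documented above.

```python
from graphium.config._loader import load_mup

# Hypothetical base-shapes file, generated beforehand by mup.save_base_shapes.
predictor = load_mup(mup_base_path="mup_base_shapes.yaml", predictor=predictor)
```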
load_predictor(config, model_class, model_kwargs, metrics, accelerator_type, task_norms=None)

Defining the predictor module, which handles the training logic from pytorch_lightning.LightningModule.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
model_class | Type[torch.nn.Module] | The torch Module containing the main forward function | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |

Returns:

Name | Type | Description |
---|---|---|
predictor | PredictorModule | The predictor module |
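A sketch of building the predictor from previously loaded pieces; model_class and model_kwargs are placeholders here, since how they are obtained is not spelled out on this page.

```python
from graphium.config._loader import load_metrics, load_predictor

metrics = load_metrics(config)  # config loaded as in the earlier examples

predictor = load_predictor(
    config,
    model_class=model_class,    # placeholder: a torch.nn.Module subclass with the main forward
    model_kwargs=model_kwargs,  # placeholder: kwargs used to instantiate that class
    metrics=metrics,
    accelerator_type="cpu",
)
```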
load_trainer(config, run_name, accelerator_type, date_time_suffix='')

Defining the pytorch-lightning Trainer module.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
run_name | str | The name of the current run. To be used for logging. | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |
date_time_suffix | str | The date and time of the current run. To be used for logging. | '' |

Returns:

Name | Type | Description |
---|---|---|
trainer | Trainer | The trainer module |
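A sketch of creating the trainer and launching training; the run name and timestamp are hypothetical, and the final call simply uses the standard pytorch-lightning Trainer.fit API with the objects loaded in the earlier examples.

```python
from graphium.config._loader import load_trainer

trainer = load_trainer(
    config,
    run_name="my-first-run",                 # hypothetical run name, used for logging
    accelerator_type="cpu",
    date_time_suffix="2024-01-01_00-00-00",  # optional; defaults to ""
)

# The returned object is a pytorch-lightning Trainer, so the usual API applies.
trainer.fit(model=predictor, datamodule=datamodule)
```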
merge_dicts(dict_a, dict_b, previous_dict_path='')

Recursively merges dict_b into dict_a. If a key is missing from dict_a, it is added from dict_b. If a key exists in both, an error is raised. dict_a is modified in-place.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
dict_a | Dict[str, Any] | The dictionary to merge into. Modified in-place. | required |
dict_b | Dict[str, Any] | The dictionary to merge from. | required |
previous_dict_path | str | The key path of the parent dictionary, | '' |

Raises:

Type | Description |
---|---|
ValueError | If a key path already exists in dict_a. |
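A small example of the documented behaviour: missing keys are merged in-place, and a key path present in both dictionaries raises a ValueError. The dictionary contents are illustrative only.

```python
from graphium.config._loader import merge_dicts

dict_a = {"model": {"hidden_dim": 64}}
dict_b = {"model": {"dropout": 0.1}, "seed": 42}

merge_dicts(dict_a, dict_b)
# dict_a is now {"model": {"hidden_dim": 64, "dropout": 0.1}, "seed": 42}

# A key path that exists in both dictionaries raises a ValueError:
# merge_dicts(dict_a, {"model": {"hidden_dim": 128}})
```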
save_params_to_wandb(logger, config, predictor, datamodule)

Save a few things to Weights and Biases (WandB).

Parameters:

Name | Type | Description | Default |
---|---|---|---|
logger | Logger | The object used to log the training. Usually WandbLogger | required |
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
predictor | PredictorModule | The predictor used to handle the train/val/test steps logic | required |
datamodule | MultitaskFromSmilesDataModule | The datamodule used to load the data into training | required |
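A final sketch tying the pieces together; WandbLogger comes from pytorch-lightning, the project name is hypothetical, and the remaining objects are the ones loaded in the previous examples.

```python
from pytorch_lightning.loggers import WandbLogger
from graphium.config._loader import save_params_to_wandb

logger = WandbLogger(project="my-project")  # hypothetical project name

save_params_to_wandb(
    logger=logger,
    config=config,
    predictor=predictor,
    datamodule=datamodule,
)
```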