graphium.config
Helper functions to load and convert config files.
graphium.config._loader
get_accelerator(config_acc)
Get the accelerator from the config file and ensure that its settings are consistent.
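The consistency check described above can be sketched as follows. This is a minimal, hypothetical sketch, not graphium's implementation: the `"type"` key, the name `get_accelerator_sketch`, and passing hardware availability in as booleans (instead of probing CUDA or IPU hardware) are all assumptions.

```python
from typing import Any, Dict, Optional

# Accepted accelerator names, taken from the examples used on this page.
SUPPORTED_ACCELERATORS = {"cpu", "gpu", "ipu"}


def get_accelerator_sketch(
    config_acc: Dict[str, Any],
    gpu_available: bool = False,
    ipu_available: bool = False,
) -> str:
    """Return a validated accelerator type from an accelerator config.

    Hypothetical: assumes the type sits under a "type" key and that hardware
    availability is passed in explicitly rather than detected.
    """
    accelerator_type: Optional[str] = config_acc.get("type")

    # Fall back to the CPU when no accelerator is requested.
    if accelerator_type is None:
        return "cpu"

    accelerator_type = accelerator_type.lower()
    if accelerator_type not in SUPPORTED_ACCELERATORS:
        raise ValueError(f"Unknown accelerator type: {accelerator_type!r}")

    # Consistency checks: the requested hardware must actually be present.
    if accelerator_type == "gpu" and not gpu_available:
        raise ValueError("GPU selected, but no GPU is available")
    if accelerator_type == "ipu" and not ipu_available:
        raise ValueError("IPU selected, but no IPU is available")

    return accelerator_type


print(get_accelerator_sketch({"type": "gpu"}, gpu_available=True))  # gpu
print(get_accelerator_sketch({}))  # cpu
```

Failing early on a mismatch between the requested and the available hardware avoids discovering the problem only once training starts.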
load_architecture(config, in_dims)
Load the architecture used for training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
in_dims | Dict[str, int] | Dictionary of the input dimensions for various | required |

Returns:
Name | Type | Description |
---|---|---|
architecture | Union[FullGraphMultiTaskNetwork, torch.nn.Module] | The model architecture used for training |

load_datamodule(config, accelerator_type)
Load the datamodule from the specified configurations at the key `datamodule: args`.
If the accelerator is IPU, load the IPU options as well.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |

Returns:
Name | Type | Description |
---|---|---|
datamodule | BaseDataModule | The datamodule used to process and load the data |

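As a rough illustration of the lookup described above, the sketch below pulls the arguments from the nested `datamodule` → `args` key of a plain-dict config and flags whether IPU options are also needed. The helper name `select_datamodule_args` and the example argument values are hypothetical; the documented function also builds the datamodule and loads the actual IPU options, which this sketch does not attempt.

```python
from typing import Any, Dict, Tuple


def select_datamodule_args(
    config: Dict[str, Any], accelerator_type: str
) -> Tuple[Dict[str, Any], bool]:
    """Pick the datamodule arguments stored under config["datamodule"]["args"].

    Hypothetical helper: returns a shallow copy of the arguments and a flag
    telling the caller whether IPU-specific options must also be loaded.
    """
    args = dict(config["datamodule"]["args"])  # copy so the config is untouched
    needs_ipu_options = accelerator_type.lower() == "ipu"
    return args, needs_ipu_options


# Illustrative values only; they are not graphium's real argument names.
config = {"datamodule": {"args": {"batch_size": 64, "num_workers": 4}}}
args, needs_ipu = select_datamodule_args(config, "ipu")
print(args, needs_ipu)  # {'batch_size': 64, 'num_workers': 4} True
```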
load_metrics(config)
Load the metrics to be tracked.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |

Returns:
Name | Type | Description |
---|---|---|
metrics | Dict[str, MetricWrapper] | A dictionary of all the metrics |

load_mup(mup_base_path, predictor)
Load the base shapes for muP, based either on a `.ckpt` or a `.yaml` file.
If a `.yaml` file is used, it should be generated by `mup.save_base_shapes`.
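The two accepted sources can be told apart by file extension. The following is a small sketch of that dispatch step only; the helper `base_shapes_format` and its return labels are hypothetical, and it does not read the file or call the `mup` library.

```python
import os
from pathlib import Path
from typing import Union


def base_shapes_format(mup_base_path: Union[str, os.PathLike]) -> str:
    """Classify the muP base-shape source by its file extension.

    Hypothetical helper: ".ckpt" means the shapes come from a saved model
    checkpoint, ".yaml" means a file written by mup.save_base_shapes.
    """
    suffix = Path(mup_base_path).suffix.lower()
    if suffix == ".ckpt":
        return "checkpoint"
    if suffix == ".yaml":
        return "yaml"
    raise ValueError(f"Expected a .ckpt or .yaml file, got {str(mup_base_path)!r}")


print(base_shapes_format("base_shapes.yaml"))  # yaml
print(base_shapes_format("runs/model.ckpt"))   # checkpoint
```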
load_predictor(config, model_class, model_kwargs, metrics, accelerator_type, task_norms=None)
Define the predictor module, which handles the training logic from `lightning.LightningModule`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_class | Type[torch.nn.Module] | The torch Module containing the main forward function | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |

Returns:
Name | Type | Description |
---|---|---|
predictor | PredictorModule | The predictor module |

load_trainer(config, run_name, accelerator_type, date_time_suffix='')
Define the PyTorch Lightning `Trainer` module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
run_name | str | The name of the current run. To be used for logging. | required |
accelerator_type | str | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |
date_time_suffix | str | The date and time of the current run. To be used for logging. | '' |

Returns:
Name | Type | Description |
---|---|---|
trainer | Trainer | The trainer module |

load_yaml_config(config_path, main_dir=None, unknown_args=None)
Load a YAML config file and return it as a dictionary.
The YAML anchors (`&`) and aliases (`*`) are resolved in the returned dictionary.
Then, update the config with the unknown arguments.
Finally, update the config with the config override file specified in `constants.config_override`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_path | Union[str, os.PathLike] | The path to the YAML config file | required |
main_dir | Optional[Union[str, os.PathLike]] | The main directory of the project. If specified, the config override file will be loaded from this directory | None |
unknown_args | | The unknown arguments to update the config with, taken from | None |

Returns:
Name | Type | Description |
---|---|---|
config | Dict[str, Any] | The config dictionary |

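The step that updates the config with the unknown arguments can be sketched with the standard-library `argparse.ArgumentParser.parse_known_args`, which returns the recognized options together with a list of leftover strings. The `--section.key=value` format and the helper `apply_unknown_args` are hypothetical conventions chosen for illustration; this page does not state which format graphium expects. Values stay strings because the sketch does not attempt type conversion.

```python
import argparse
from typing import Any, Dict, List


def apply_unknown_args(config: Dict[str, Any], unknown_args: List[str]) -> Dict[str, Any]:
    """Override nested config entries from leftover CLI arguments, in place.

    Hypothetical convention: each argument looks like --section.key=value.
    """
    for arg in unknown_args:
        if not arg.startswith("--") or "=" not in arg:
            raise ValueError(f"Unsupported argument format: {arg!r}")
        dotted_key, value = arg[2:].split("=", 1)
        *parents, leaf = dotted_key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})  # create missing sections
        node[leaf] = value
    return config


parser = argparse.ArgumentParser()
parser.add_argument("--config", default="config.yaml")
# parse_known_args keeps unrecognized options instead of failing on them.
known, unknown = parser.parse_known_args(["--config", "my.yaml", "--constants.seed=7"])

config = {"constants": {"seed": 42, "name": "run"}}
print(apply_unknown_args(config, unknown))
# {'constants': {'seed': '7', 'name': 'run'}}
```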
merge_dicts(dict_a, dict_b, previous_dict_path='', on_exist='raise')
Recursively merges `dict_b` into `dict_a`. If a key is missing from `dict_a`,
it is added from `dict_b`. If a key exists in both, the outcome is controlled by `on_exist`;
with the default `"raise"`, an error is raised.
`dict_a` is modified in place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dict_a | Dict[str, Any] | The dictionary to merge into. Modified in place. | required |
dict_b | Dict[str, Any] | The dictionary to merge from. | required |
previous_dict_path | str | The key path of the parent dictionary | '' |
on_exist | str | What to do if a key already exists in dict_a. Options are "raise", "overwrite", "ignore". | 'raise' |

Raises:
Type | Description |
---|---|
ValueError | If a key path already exists in dict_a and on_exist is "raise". |

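The merge rules above can be illustrated with a small, self-contained sketch. It is not graphium's implementation: the name `merge_dicts_sketch`, the `/`-separated key paths in error messages, and the choice to recurse when both values are dictionaries (instead of treating them as a clash) are assumptions. Any non-dictionary key present in both inputs counts as a conflict here, even when the two values are equal.

```python
from typing import Any, Dict


def merge_dicts_sketch(
    dict_a: Dict[str, Any],
    dict_b: Dict[str, Any],
    previous_dict_path: str = "",
    on_exist: str = "raise",
) -> None:
    """Recursively merge dict_b into dict_a, modifying dict_a in place (sketch)."""
    if on_exist not in ("raise", "overwrite", "ignore"):
        raise ValueError(f"Invalid on_exist: {on_exist!r}")

    for key, value_b in dict_b.items():
        # Full key path for error messages, e.g. "trainer/max_epochs" (assumed format).
        path = f"{previous_dict_path}/{key}" if previous_dict_path else key
        if key not in dict_a:
            dict_a[key] = value_b  # missing key: copy it over
        elif isinstance(dict_a[key], dict) and isinstance(value_b, dict):
            # Both sides are dicts: descend instead of reporting a conflict.
            merge_dicts_sketch(dict_a[key], value_b, path, on_exist)
        elif on_exist == "raise":
            raise ValueError(f"Key path already exists: {path}")
        elif on_exist == "overwrite":
            dict_a[key] = value_b
        # on_exist == "ignore": keep the value already in dict_a


a = {"trainer": {"max_epochs": 10}}
merge_dicts_sketch(a, {"trainer": {"seed": 0}, "name": "run"})
print(a)  # {'trainer': {'max_epochs': 10, 'seed': 0}, 'name': 'run'}
```

Carrying the key path through the recursion lets the error name the exact nested key that clashed, rather than only its top-level parent.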
save_params_to_wandb(logger, config, predictor, datamodule)
Save a few parameters to Weights & Biases (WandB).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logger | Logger | The object used to log the training. Usually WandbLogger | required |
config | Union[omegaconf.DictConfig, Dict[str, Any]] | The config file, with key | required |
predictor | PredictorModule | The predictor used to handle the train/val/test steps logic | required |
datamodule | MultitaskFromSmilesDataModule | The datamodule used to load the data into training | required |