graphium.config¶
Helper functions to load and convert config files.
graphium.config._loader¶
get_accelerator(config_acc)¶
Get the accelerator from the config file, and ensure that the accelerator settings are consistent.
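For example (a minimal sketch; the `"type"` key in the accelerator sub-config is an assumption about the schema):

```python
from graphium.config._loader import get_accelerator

# Hypothetical accelerator sub-config; the "type" key is an assumption
accelerator_type = get_accelerator({"type": "cpu"})
```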
get_checkpoint_path(config)¶
Get the checkpoint path from a config file. If the path is a valid name or a valid path, return it. Otherwise, assume it refers to a file in the checkpointing dir.
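A hedged sketch, assuming `config` is the full config dictionary returned by `load_yaml_config` (documented below) and that it declares a checkpoint name or path:

```python
from graphium.config._loader import get_checkpoint_path, load_yaml_config

config = load_yaml_config("expts/configs/config.yaml")  # hypothetical path
ckpt_path = get_checkpoint_path(config)  # a bare name resolves inside the checkpointing dir
```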
load_architecture(config, in_dims)¶
Loading the architecture used for training.
Parameters:
config: The config file, with key `architecture`
in_dims: Dictionary of the input dimensions for the various input features
Returns:
architecture: The architecture used for training
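A sketch under stated assumptions: the `in_dims` keys and values below are hypothetical, and in practice the input dimensions are derived from the datamodule.

```python
from graphium.config._loader import load_architecture, load_yaml_config

config = load_yaml_config("expts/configs/config.yaml")  # hypothetical path

# Hypothetical input dimensions for the node and edge features
in_dims = {"feat": 55, "edge_feat": 13}
architecture = load_architecture(config, in_dims=in_dims)
```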
load_datamodule(config, accelerator_type)¶
Load the datamodule from the configuration specified at the key `datamodule: args`. If the accelerator is IPU, load the IPU options as well.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `Union[DictConfig, Dict[str, Any]]` | The config file, with key `datamodule: args` | required |
| `accelerator_type` | `str` | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |

Returns:
datamodule: The datamodule used to process and load the data
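For instance (the `"accelerator"` key used to fetch the accelerator sub-config is an assumption about the schema):

```python
from graphium.config._loader import get_accelerator, load_datamodule, load_yaml_config

config = load_yaml_config("expts/configs/config.yaml")     # hypothetical path
accelerator_type = get_accelerator(config["accelerator"])  # "accelerator" key is an assumption
datamodule = load_datamodule(config, accelerator_type)
```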
load_metrics(config)¶
Loading the metrics to be tracked.
Parameters:
config: The config file, with key `metrics`
Returns:
metrics: A dictionary of all the metrics
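For example, reusing a `config` loaded with `load_yaml_config`:

```python
from graphium.config._loader import load_metrics, load_yaml_config

config = load_yaml_config("expts/configs/config.yaml")  # hypothetical path
metrics = load_metrics(config)
print(list(metrics.keys()))  # one entry per tracked metric
```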
load_mup(mup_base_path, predictor)¶
Load the base shapes for mup, based either on a `.ckpt` or `.yaml` file. If a `.yaml` file, it should be generated by `mup.save_base_shapes`.
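A hedged sketch; `base_shapes.yaml` is a hypothetical file previously written by `mup.save_base_shapes`, and `predictor` stands for the module returned by `load_predictor` (documented below):

```python
from graphium.config._loader import load_mup

# Hypothetical base-shapes file; `predictor` comes from load_predictor
predictor = load_mup("base_shapes.yaml", predictor)
```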
load_predictor(config, model_class, model_kwargs, metrics, task_levels, accelerator_type, featurization=None, task_norms=None, replicas=1, gradient_acc=1, global_bs=1)¶
Defining the predictor module, which handles the training logic from `lightning.LightningModule`.
Parameters:
model_class: The torch Module containing the main forward function
accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
Returns:
predictor: The predictor module
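A sketch of the call; the names below (`model_class`, `model_kwargs`, `task_levels`, ...) are placeholders for objects that, under the assumptions made here, come from `load_architecture`, `load_metrics`, and the task configuration:

```python
from graphium.config._loader import load_predictor

predictor = load_predictor(
    config,
    model_class=model_class,            # the torch Module with the main forward function
    model_kwargs=model_kwargs,          # keyword arguments used to instantiate the model
    metrics=metrics,                    # from load_metrics(config)
    task_levels=task_levels,            # placeholder; depends on the task configuration
    accelerator_type=accelerator_type,  # from get_accelerator(...)
)
```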
load_trainer(config, accelerator_type, date_time_suffix='')¶
Defining the pytorch-lightning Trainer module.
Parameters:
config: The config file, with key `trainer`
accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
date_time_suffix: The date and time of the current run. To be used for logging.
Returns:
trainer: The trainer module
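For example, continuing the running sketch (`trainer.fit` is the standard pytorch-lightning entry point):

```python
from graphium.config._loader import load_trainer

trainer = load_trainer(config, accelerator_type)
trainer.fit(model=predictor, datamodule=datamodule)  # standard pytorch-lightning training call
```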
load_yaml_config(config_path, main_dir=None, unknown_args=None)¶
Load a YAML config file and return it as a dictionary. Also returns the anchors `&` and aliases `*` of the YAML file. Then, update the config with the unknown arguments. Finally, update the config with the config override file specified in `constants.config_override`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config_path` | `Union[str, PathLike]` | The path to the YAML config file | required |
| `main_dir` | `Optional[Union[str, PathLike]]` | The main directory of the project. If specified, the config override file will be loaded from this directory | `None` |
| `unknown_args` | | The unknown arguments to update the config with | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `config` | `Dict[str, Any]` | The config dictionary |
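A minimal sketch (the config path is hypothetical; `unknown_args` would typically come from an argument parser):

```python
from graphium.config._loader import load_yaml_config

config = load_yaml_config(
    "expts/configs/config.yaml",  # hypothetical path
    main_dir=".",                 # look for the override file in the current directory
    unknown_args=None,
)
```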
merge_dicts(dict_a, dict_b, previous_dict_path='', on_exist='raise')¶
Recursively merges `dict_b` into `dict_a`. If a key is missing from `dict_a`, it is added from `dict_b`. If a key exists in both, the behavior depends on `on_exist`: by default, an error is raised. `dict_a` is modified in-place.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dict_a` | `Dict[str, Any]` | The dictionary to merge into. Modified in-place. | required |
| `dict_b` | `Dict[str, Any]` | The dictionary to merge from. | required |
| `previous_dict_path` | `str` | The key path of the parent dictionary | `''` |
| `on_exist` | `str` | What to do if a key already exists in `dict_a`. Options are "raise", "overwrite", "ignore". | `'raise'` |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If a key path already exists in `dict_a` and `on_exist` is "raise". |
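A self-contained sketch of the documented behavior, assuming nested dictionaries are merged recursively and `on_exist` governs conflicting keys:

```python
from graphium.config._loader import merge_dicts

base = {"model": {"hidden_dim": 64}, "seed": 42}
extra = {"model": {"dropout": 0.1}, "seed": 7}

# "seed" exists in both dictionaries, so the default on_exist="raise" would
# raise a ValueError; "ignore" keeps the existing value instead
merge_dicts(base, extra, on_exist="ignore")

assert base == {"model": {"hidden_dim": 64, "dropout": 0.1}, "seed": 42}
```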
save_params_to_wandb(logger, config, predictor, datamodule, unresolved_config=None)¶
Save a selection of run parameters to Weights & Biases (WandB).
Parameters:
logger: The object used to log the training. Usually WandbLogger
config: The config file, with key `trainer`
predictor: The predictor used to handle the train/val/test steps logic
datamodule: The datamodule used to load the data into training
unresolved_config: The unresolved config file
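For instance, at the end of the running sketch (`trainer.logger` is typically a `WandbLogger` when WandB logging is enabled):

```python
from graphium.config._loader import save_params_to_wandb

save_params_to_wandb(trainer.logger, config, predictor, datamodule)
```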