graphium.config

Helper functions to load and convert config files.

graphium.config._loader

get_accelerator(config_acc)

Get the accelerator from the config file, and ensure that the settings are consistent.
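
A minimal usage sketch (the shape of the accelerator sub-config and the exact return value are assumptions inferred from the signature, not guaranteed by this page):

```python
from graphium.config._loader import get_accelerator

# Hypothetical accelerator sub-config; the real keys are project-specific.
config_acc = {"type": "cpu"}

accelerator_type = get_accelerator(config_acc)  # assumed to return e.g. "cpu"
```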

load_architecture(config, in_dims)

Load the architecture used for training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Union[omegaconf.DictConfig, Dict[str, Any]]` | The config file, with key `architecture` | required |
| `in_dims` | `Dict[str, int]` | Dictionary of the input dimensions for the various inputs | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `architecture` | `Union[FullGraphMultiTaskNetwork, torch.nn.Module]` | The architecture used for training |
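
A hedged sketch of a typical call; the config path and the input-dimension keys below are hypothetical:

```python
from graphium.config._loader import load_architecture, load_yaml_config

# Hypothetical path; the config must contain the key `architecture`.
config = load_yaml_config("expts/config.yaml")

# Hypothetical input dimensions, keyed by input name.
in_dims = {"feat": 64, "edge_feat": 12}

architecture = load_architecture(config, in_dims=in_dims)
```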

load_datamodule(config, accelerator_type)

Load the datamodule from the specified configurations at the key datamodule: args. If the accelerator is IPU, load the IPU options as well.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Union[omegaconf.DictConfig, Dict[str, Any]]` | The config file, with key `datamodule: args` | required |
| `accelerator_type` | `str` | The accelerator type, e.g. `"cpu"`, `"gpu"`, `"ipu"` | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `datamodule` | `BaseDataModule` | The datamodule used to process and load the data |
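
A hedged sketch, assuming a config file with a `datamodule: args` section (the path is hypothetical):

```python
from graphium.config._loader import load_datamodule, load_yaml_config

# Hypothetical path; the config must contain the keys `datamodule: args`.
config = load_yaml_config("expts/config.yaml")

datamodule = load_datamodule(config, accelerator_type="cpu")
datamodule.prepare_data()  # standard Lightning datamodule hook
```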

load_metrics(config)

Load the metrics to be tracked.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Union[omegaconf.DictConfig, Dict[str, Any]]` | The config file, with key `metrics` | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `metrics` | `Dict[str, MetricWrapper]` | A dictionary of all the metrics |
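
A hedged sketch, assuming the config file declares a `metrics` section (the path is hypothetical):

```python
from graphium.config._loader import load_metrics, load_yaml_config

# Hypothetical path; the config must contain the key `metrics`.
config = load_yaml_config("expts/config.yaml")

metrics = load_metrics(config)
print(list(metrics.keys()))  # one MetricWrapper per tracked metric
```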

load_mup(mup_base_path, predictor)

Load the base shapes for muP, based either on a .ckpt or a .yaml file. If a .yaml file is given, it should have been generated by mup.save_base_shapes.
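
A hedged sketch; whether the function returns the re-shaped predictor or modifies it in place is not stated on this page, so the assignment below is an assumption:

```python
from graphium.config._loader import load_mup

# `predictor` is assumed to come from `load_predictor` (documented below);
# "base_shapes.yaml" is a hypothetical file produced by `mup.save_base_shapes`.
predictor = load_mup("base_shapes.yaml", predictor)
```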

load_predictor(config, model_class, model_kwargs, metrics, accelerator_type, task_norms=None)

Define the predictor module, which handles the training logic, building on lightning.LightningModule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_class` | `Type[torch.nn.Module]` | The torch Module containing the main forward function | required |
| `accelerator_type` | `str` | The accelerator type, e.g. `"cpu"`, `"gpu"`, `"ipu"` | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `predictor` | `PredictorModule` | The predictor module |
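
A hedged sketch; `torch.nn.Linear` stands in as a placeholder for the real model class, and the config path is hypothetical:

```python
import torch

from graphium.config._loader import load_metrics, load_predictor, load_yaml_config

config = load_yaml_config("expts/config.yaml")  # hypothetical path
metrics = load_metrics(config)

predictor = load_predictor(
    config,
    model_class=torch.nn.Linear,  # placeholder; normally the network class
    model_kwargs={"in_features": 16, "out_features": 1},
    metrics=metrics,
    accelerator_type="cpu",
)
```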

load_trainer(config, run_name, accelerator_type, date_time_suffix='')

Define the PyTorch Lightning Trainer module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `Union[omegaconf.DictConfig, Dict[str, Any]]` | The config file, with key `trainer` | required |
| `run_name` | `str` | The name of the current run, used for logging | required |
| `accelerator_type` | `str` | The accelerator type, e.g. `"cpu"`, `"gpu"`, `"ipu"` | required |
| `date_time_suffix` | `str` | The date and time of the current run, used for logging | `''` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `trainer` | `Trainer` | The trainer module |
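
A hedged sketch of wiring the trainer together (the path and run name are hypothetical):

```python
from graphium.config._loader import load_trainer, load_yaml_config

# Hypothetical path; the config must contain the key `trainer`.
config = load_yaml_config("expts/config.yaml")

trainer = load_trainer(config, run_name="my-run", accelerator_type="cpu")
# trainer.fit(model=predictor, datamodule=datamodule) would then launch training.
```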

load_yaml_config(config_path, main_dir=None, unknown_args=None)

Load a YAML config file and return it as a dictionary. Also returns the anchors (`&`) and aliases (`*`) of the YAML file. Then, update the config with the unknown arguments. Finally, update the config with the config override file specified at the key `constants.config_override`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config_path` | `Union[str, os.PathLike]` | The path to the YAML config file | required |
| `main_dir` | `Optional[Union[str, os.PathLike]]` | The main directory of the project. If specified, the config override file will be loaded from this directory | `None` |
| `unknown_args` | | The unknown arguments to update the config with, taken from `argparse.parse_known_args` | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `config` | `Dict[str, Any]` | The config dictionary |
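
A hedged sketch; the file paths are hypothetical, and `unknown_args` is taken from argparse as the description suggests:

```python
import argparse

from graphium.config._loader import load_yaml_config

parser = argparse.ArgumentParser()
_, unknown_args = parser.parse_known_args()  # unrecognized CLI args update the config

config = load_yaml_config(
    "expts/config.yaml",  # hypothetical path
    main_dir="expts",     # hypothetical project directory
    unknown_args=unknown_args,
)
```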

merge_dicts(dict_a, dict_b, previous_dict_path='', on_exist='raise')

Recursively merges dict_b into dict_a. If a key is missing from dict_a, it is added from dict_b. If a key exists in both, the behaviour depends on on_exist: by default, an error is raised. dict_a is modified in-place.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dict_a` | `Dict[str, Any]` | The dictionary to merge into. Modified in-place. | required |
| `dict_b` | `Dict[str, Any]` | The dictionary to merge from. | required |
| `previous_dict_path` | `str` | The key path of the parent dictionary, used to report the full key path in error messages | `''` |
| `on_exist` | `str` | What to do if a key already exists in `dict_a`. Options are `"raise"`, `"overwrite"`, `"ignore"`. | `'raise'` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If a key path already exists in `dict_a` and `on_exist="raise"`. |
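
A small example of the documented behaviour (assuming nested dicts are merged recursively rather than treated as conflicting):

```python
from graphium.config._loader import merge_dicts

dict_a = {"model": {"hidden_dim": 64}, "seed": 42}
dict_b = {"model": {"dropout": 0.1}, "lr": 1e-3}

merge_dicts(dict_a, dict_b)  # modifies dict_a in-place
assert dict_a == {"model": {"hidden_dim": 64, "dropout": 0.1}, "seed": 42, "lr": 1e-3}

# A key present in both dicts raises a ValueError by default;
# on_exist="overwrite" or on_exist="ignore" changes that.
merge_dicts(dict_a, {"seed": 0}, on_exist="ignore")
assert dict_a["seed"] == 42
```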

save_params_to_wandb(logger, config, predictor, datamodule)

Save a few items to Weights & Biases (WandB).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `logger` | `Logger` | The object used to log the training. Usually `WandbLogger` | required |
| `config` | `Union[omegaconf.DictConfig, Dict[str, Any]]` | The config file, with key `trainer` | required |
| `predictor` | `PredictorModule` | The predictor used to handle the train/val/test steps logic | required |
| `datamodule` | `MultitaskFromSmilesDataModule` | The datamodule used to load the data into training | required |
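
A hedged sketch; the logger import path assumes a recent lightning release, and `config`, `predictor`, and `datamodule` are assumed to come from the loaders above:

```python
from lightning.pytorch.loggers import WandbLogger

from graphium.config._loader import save_params_to_wandb

logger = WandbLogger(project="my-project")  # hypothetical project name
save_params_to_wandb(logger, config, predictor, datamodule)
```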

graphium.config.config_convert

recursive_config_reformating(configs)

For a given configuration file, convert all `DictConfig` to `dict`, all `ListConfig` to `list`, and all `bytes` to `str`.

This helps avoid errors when dumping a YAML file.
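
A hedged sketch, assuming the function returns the converted config rather than modifying it in place:

```python
import yaml
from omegaconf import OmegaConf

from graphium.config.config_convert import recursive_config_reformating

cfg = OmegaConf.create({"model": {"hidden_dims": [64, 64], "name": "experiment"}})

plain = recursive_config_reformating(cfg)  # DictConfig -> dict, ListConfig -> list

with open("config_dump.yaml", "w") as f:
    yaml.dump(plain, f)  # plain dicts/lists/strings dump cleanly
```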

graphium.config._load

load_config(name)

Load a default config file by its name.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | `str` | Name of the config to load. | required |
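
A hedged sketch; the config name below is hypothetical and should be replaced by the name of one of the default configs shipped with graphium:

```python
from graphium.config._load import load_config

config = load_config(name="some_default_config")  # hypothetical name
```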