graphium.config
Helper functions to load and convert config files.
graphium.config._loader
get_accelerator(config_acc)
Get the accelerator type from the config file and ensure that the accelerator settings are consistent.
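The docstring does not spell out what "consistent" means. A minimal sketch, assuming it means that the requested type is known and actually available on the machine, could look like this. The helper `check_accelerator`, the `type` key and the `available` argument are hypothetical, as is the fallback to `"cpu"` when no type is given.

```python
from typing import Any, Dict, Optional, Set

# Accelerator types mentioned elsewhere on this page.
VALID_TYPES = {"cpu", "gpu", "ipu"}


def check_accelerator(config_acc: Dict[str, Any], available: Set[str]) -> str:
    """Hypothetical helper: return the requested accelerator type after
    checking that it is known and available on this machine."""
    # Assumption: the accelerator type is stored under the "type" key.
    acc_type: Optional[str] = config_acc.get("type")
    # Assumption: a missing type falls back to the CPU.
    if acc_type is None:
        return "cpu"
    if acc_type not in VALID_TYPES:
        raise ValueError(f"Unknown accelerator type: {acc_type!r}")
    if acc_type not in available:
        raise ValueError(f"{acc_type.upper()}s selected, but none are available on this device")
    return acc_type


print(check_accelerator({"type": "gpu"}, available={"cpu", "gpu"}))  # prints gpu
```

In practice the `available` set would be built from runtime checks, for example `torch.cuda.is_available()` for GPUs.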
get_checkpoint_path(config)
Get the checkpoint path from a config file. If the path is a valid name or a valid path, return it. Otherwise, assume it refers to a file in the checkpointing directory.
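The resolution rule above can be sketched as follows. This is a hypothetical helper, not graphium's implementation. It assumes that a "valid name" is a member of some registry of known model names (passed here as `known_names`), treats any existing filesystem path as valid, and takes the checkpointing directory as an explicit argument rather than reading it from the config.

```python
from pathlib import Path
from typing import Iterable


def resolve_checkpoint_path(path: str, checkpoint_dir: str, known_names: Iterable[str] = ()) -> str:
    """Hypothetical helper: resolve a checkpoint reference to a usable name or path."""
    # 1. A known model name (assumed registry) is returned unchanged.
    if path in set(known_names):
        return path
    # 2. A path that already exists on disk is returned unchanged.
    if Path(path).exists():
        return path
    # 3. Otherwise, assume it is a file inside the checkpointing directory.
    return str(Path(checkpoint_dir) / path)


print(resolve_checkpoint_path("last.ckpt", "checkpoints"))
```

Checking the name registry before the filesystem means a registered name wins even if a file with the same name happens to exist in the working directory; the opposite order is an equally reasonable choice.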
load_architecture(config, in_dims)
Load the architecture used for training.
Parameters:
    config: The config file, with key `architecture`
    in_dims: Dictionary of the input dimensions for various
Returns:
    architecture: The architecture used for training
load_datamodule(config, accelerator_type)
Load the datamodule from the configuration under the key `datamodule: args`.
If the accelerator is IPU, also load the IPU options.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `Union[DictConfig, Dict[str, Any]]` | The config file, with key `datamodule` | required |
| `accelerator_type` | `str` | The accelerator type, e.g. "cpu", "gpu", "ipu" | required |
Returns:
    datamodule: The datamodule used to process and load the data
load_metrics(config)
Load the metrics to be tracked.
Parameters:
    config: The config file, with key `metrics`
Returns:
    metrics: A dictionary of all the metrics
load_mup(mup_base_path, predictor)
Load the base shapes for muP (the `mup` package) from either a `.ckpt` or a `.yaml` file.
If a `.yaml` file is used, it should be generated by `mup.save_base_shapes`.
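The choice between the two input formats can be sketched as a dispatch on the file extension. `base_shapes_format` is a hypothetical helper; the actual loading of each format is left out.

```python
from pathlib import Path


def base_shapes_format(mup_base_path: str) -> str:
    """Hypothetical helper: decide how the muP base shapes should be read."""
    suffix = Path(mup_base_path).suffix.lower()
    if suffix == ".ckpt":
        return "checkpoint"  # base shapes taken from a saved model checkpoint
    if suffix == ".yaml":
        return "yaml"  # file expected to be produced by mup.save_base_shapes
    raise ValueError(f"Expected a .ckpt or .yaml file, got {mup_base_path!r}")


print(base_shapes_format("base_shapes.yaml"))  # prints yaml
```

In the `mup` package's workflow, a file written by `mup.save_base_shapes` is typically passed back as the `base` argument of `mup.set_base_shapes`.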
load_predictor(config, model_class, model_kwargs, metrics, task_levels, accelerator_type, featurization=None, task_norms=None)
Define the predictor module, which handles the training logic from `lightning.LightningModule`.
Parameters:
    model_class: The torch Module containing the main forward function
    accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
Returns:
    predictor: The predictor module
load_trainer(config, accelerator_type, date_time_suffix='')
Define the PyTorch Lightning `Trainer`.
Parameters:
    config: The config file, with key `trainer`
    accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
    date_time_suffix: The date and time of the current run, used for logging.
Returns:
    trainer: The trainer module
load_yaml_config(config_path, main_dir=None, unknown_args=None)
Load a YAML config file and return it as a dictionary, together with the anchors (`&`) and aliases (`*`) of the YAML file.
The config is first updated with the unknown arguments, then with the config override file specified in `constants.config_override`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config_path` | `Union[str, PathLike]` | The path to the YAML config file | required |
| `main_dir` | `Optional[Union[str, PathLike]]` | The main directory of the project. If specified, the config override file will be loaded from this directory | `None` |
| `unknown_args` | | The unknown arguments to update the config with, taken from | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `config` | `Dict[str, Any]` | The config dictionary |
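A minimal sketch of the "update the config with the unknown arguments" step is shown below. It assumes the unknown arguments are `--dotted.key=value` strings; both this override syntax and the helper name `apply_unknown_args` are hypothetical and may not match what graphium accepts.

```python
import ast
from typing import Any, Dict, List


def _parse_value(text: str) -> Any:
    # Interpret Python literals (numbers, booleans, lists); keep anything else as a string.
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_unknown_args(config: Dict[str, Any], unknown_args: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: apply `--a.b.c=value` overrides to a nested dict in place."""
    for arg in unknown_args:
        if not arg.startswith("--") or "=" not in arg:
            raise ValueError(f"Expected '--key.path=value', got {arg!r}")
        key_path, raw_value = arg[2:].split("=", 1)
        *parents, leaf = key_path.split(".")
        node = config
        for key in parents:
            # Create intermediate dictionaries when the path does not exist yet.
            node = node.setdefault(key, {})
        node[leaf] = _parse_value(raw_value)
    return config


config = {"constants": {"seed": 0}, "trainer": {"max_epochs": 10}}
apply_unknown_args(config, ["--constants.seed=42", "--trainer.precision=bf16"])
print(config)  # {'constants': {'seed': 42}, 'trainer': {'max_epochs': 10, 'precision': 'bf16'}}
```

`argparse.ArgumentParser.parse_known_args()` returns a `(namespace, remaining_args)` pair; the second element is the kind of list this sketch consumes.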
merge_dicts(dict_a, dict_b, previous_dict_path='', on_exist='raise')
Recursively merges `dict_b` into `dict_a`. If a key is missing from `dict_a`,
it is added from `dict_b`. If a key exists in both, the behavior is set by `on_exist`; by default (`"raise"`), an error is raised.
`dict_a` is modified in-place.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dict_a` | `Dict[str, Any]` | The dictionary to merge into. Modified in-place. | required |
| `dict_b` | `Dict[str, Any]` | The dictionary to merge from. | required |
| `previous_dict_path` | `str` | The key path of the parent dictionary. | `''` |
| `on_exist` | `str` | What to do if a key already exists in `dict_a`. Options are `"raise"`, `"overwrite"`, `"ignore"`. | `'raise'` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If a key path already exists in `dict_a` and `on_exist` is `"raise"`. |
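The merge semantics described above, including the three `on_exist` policies, can be sketched as follows. This is an illustrative re-implementation named `merge_dicts_sketch`, not graphium's code. It follows the docstring literally: every non-dictionary key present in both inputs counts as a conflict, and when both values are dictionaries it merges one level deeper.

```python
from typing import Any, Dict


def merge_dicts_sketch(
    dict_a: Dict[str, Any], dict_b: Dict[str, Any], previous_dict_path: str = "", on_exist: str = "raise"
) -> None:
    """Recursively merge dict_b into dict_a, modifying dict_a in place."""
    if on_exist not in ("raise", "overwrite", "ignore"):
        raise ValueError(f"Invalid on_exist: {on_exist!r}")
    for key, value_b in dict_b.items():
        # Track the full key path, e.g. "trainer/max_epochs", for error messages.
        path = f"{previous_dict_path}/{key}" if previous_dict_path else key
        if key not in dict_a:
            dict_a[key] = value_b  # missing key: copy it over from dict_b
        elif isinstance(dict_a[key], dict) and isinstance(value_b, dict):
            # Both sides are dictionaries: merge one level deeper.
            merge_dicts_sketch(dict_a[key], value_b, previous_dict_path=path, on_exist=on_exist)
        elif on_exist == "raise":
            raise ValueError(f"Dict path already exists: {path}")
        elif on_exist == "overwrite":
            dict_a[key] = value_b
        # on_exist == "ignore": keep the existing value in dict_a


a = {"trainer": {"max_epochs": 10}, "seed": 0}
merge_dicts_sketch(a, {"trainer": {"precision": 16}, "seed": 1}, on_exist="overwrite")
print(a)  # {'trainer': {'max_epochs': 10, 'precision': 16}, 'seed': 1}
```

Because `dict_a` is modified in place, a `"raise"` merge that fails partway leaves the keys processed before the conflict already merged; copying `dict_a` first avoids that if atomicity matters.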
save_params_to_wandb(logger, config, predictor, datamodule)
Save a few items to Weights & Biases (WandB).
Parameters:
    logger: The object used to log the training, usually a `WandbLogger`
    config: The config file, with key `trainer`
    predictor: The predictor used to handle the train/val/test steps logic
    datamodule: The datamodule used to load the data into training