Building and training a simple model from configurations
This tutorial walks you through using a configuration file to define all the parameters of a model and of its trainer. It focuses on training from SMILES data stored in CSV format.
The workflow for testing your code on the entire pipeline is as follows:
- Select a subset of the available configs as a starting point.
- Create additional configs or modify the existing ones to suit your needs.
- Train or fine-tune a model with the graphium-train CLI.
Creating the YAML file
The first step is to create a YAML file containing all the required configurations; an example is given in graphium/expts/hydra-configs/main.yaml. We go through each part of the configuration below. See also the README here.
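import omegaconf
from hydra import compose, initialize
def print_config_with_key(config, key):
    """Print only the `key` section of a config as YAML (interpolations are left unresolved)."""
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the YAML configuration file with Hydra's compose API.
# config_path is a relative path; adjust it to the location you run this code from.
with initialize(version_base=None, config_path="../../../expts/hydra-configs"):
    yaml_config = compose(config_name="main")
print("Yaml file loaded")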
Yaml file loaded
Constants
We start by defining the constants: the run name, the random seed, the maximum number of epochs, the data directory, and whether training errors should be raised or ignored.
print_config_with_key(yaml_config, "constants")
constants:
  name: neurips2023_small_data_gcn
  seed: 42
  max_epochs: 100
  data_dir: expts/data/neurips2023/small-dataset
  raise_train_error: true
Datamodule
Here, we define all the parameters the datamodule needs to run correctly: the dataset path and the SMILES and label columns for each task, where to cache the processed graphs, the molecular featurization to use (including positional encodings), the train/val/test splits, and the batch sizes.
For more details, see the class MultitaskFromSmilesDataModule.
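The sketch below makes the roles of df_path, smiles_col and label_cols concrete. It pulls those columns out of a CSV with the standard library only. This is not how MultitaskFromSmilesDataModule reads its files: the real files are gzipped, and the datamodule also handles splits, featurization and caching. The helper read_task_columns and the two rows of data are made up for illustration, and turning empty label cells into NaN is a choice of this sketch.

```python
import csv
import io
import math


def read_task_columns(csv_text, smiles_col, label_cols):
    """Return the SMILES strings and a row-wise list of float labels (NaN for empty cells)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    smiles, labels = [], []
    for row in reader:
        smiles.append(row[smiles_col])
        labels.append([float(row[c]) if row[c] != "" else math.nan for c in label_cols])
    return smiles, labels


# Two made-up rows that reuse the zinc task's column names from the config.
toy_csv = "smiles,SA,logp,score\nCCO,1.9,-0.1,0.5\nc1ccccc1,1.0,1.7,\n"
smiles, labels = read_task_columns(toy_csv, smiles_col="smiles", label_cols=["SA", "logp", "score"])
print(smiles)  # ['CCO', 'c1ccccc1']
print(labels)  # [[1.9, -0.1, 0.5], [1.0, 1.7, nan]]
```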
print_config_with_key(yaml_config, "datamodule")
datamodule:
  module_type: MultitaskFromSmilesDataModule
  args:
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 4
    featurization_progress: true
    featurization_backend: loky
    processed_graph_data_path: ../datacache/neurips2023-small/
    num_workers: 4
    persistent_workers: false
    featurization:
      atom_property_list_onehot:
      - atomic-number
      - group
      - period
      - total-valence
      atom_property_list_float:
      - degree
      - formal-charge
      - radical-electron
      - aromatic
      - in-ring
      edge_property_list:
      - bond-type-onehot
      - stereo
      - in-ring
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            normalization: none
            disconnected_comp: true
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            normalization: none
            disconnected_comp: true
          rw_pos:
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16
    task_specific_args:
      qm9:
        df: null
        df_path: ${constants.data_dir}/qm9.csv.gz
        smiles_col: smiles
        label_cols:
        - A
        - B
        - C
        - mu
        - alpha
        - homo
        - lumo
        - gap
        - r2
        - zpve
        - u0
        - u298
        - h298
        - g298
        - cv
        - u0_atom
        - u298_atom
        - h298_atom
        - g298_atom
        splits_path: ${constants.data_dir}/qm9_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
        label_normalization:
          normalize_val_test: true
          method: normal
      tox21:
        df: null
        df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
        smiles_col: smiles
        label_cols:
        - NR-AR
        - NR-AR-LBD
        - NR-AhR
        - NR-Aromatase
        - NR-ER
        - NR-ER-LBD
        - NR-PPAR-gamma
        - SR-ARE
        - SR-ATAD5
        - SR-HSE
        - SR-MMP
        - SR-p53
        splits_path: ${constants.data_dir}/Tox21_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
      zinc:
        df: null
        df_path: ${constants.data_dir}/ZINC12k.csv.gz
        smiles_col: smiles
        label_cols:
        - SA
        - logp
        - score
        splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
        label_normalization:
          normalize_val_test: true
          method: normal
    batch_size_training: 200
    batch_size_inference: 200
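The rw_pos entry requests random-walk return probabilities (pos_type: rw_return_probs) with ksteps: 16. A common definition of this encoding works as follows. For each node i and step k, take the probability that a simple random walk starting at i is back at i after k steps. This is the diagonal of (D^-1 A)^k, where A is the adjacency matrix and D the degree matrix.

The plain-Python sketch below computes that definition for steps 1 to ksteps. Several details are assumptions rather than confirmed Graphium behavior: that the steps start at k = 1, and that isolated nodes get all-zero features. The function name is hypothetical.

```python
def random_walk_return_probs(adj, ksteps):
    """For each node, return [P^1[i][i], ..., P^ksteps[i][i]] with P = D^-1 A."""
    n = len(adj)
    deg = [sum(row) for row in adj]
    # Row-normalized transition matrix; isolated nodes get an all-zero row (assumption).
    P = [[adj[i][j] / deg[i] if deg[i] else 0.0 for j in range(n)] for i in range(n)]

    def matmul(X, Y):
        return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

    features = [[] for _ in range(n)]
    Pk = P
    for step in range(ksteps):
        if step > 0:
            Pk = matmul(Pk, P)  # Pk now holds P^(step + 1)
        for i in range(n):
            features[i].append(Pk[i][i])
    return features


# Triangle graph: a walker can never return in 1 step, returns in 2 steps with probability 1/2.
triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
print(random_walk_return_probs(triangle, 3)[0])  # [0.0, 0.5, 0.25]
```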
Architecture
The architecture is based on FullGraphMultiTaskNetwork.
Here, we define all the layers of the model: the pre-processing MLPs (input layers pre_nn and pre_nn_edges), the positional encoders (pe_encoders), the main graph neural network (gnn), and the post-processing MLPs (output layers graph_output_nn and the per-task task_heads).
You can find more details in:
- graphium.nn.encoders, for the positional encoders
- graphium.nn.pyg_layers, for the GNN layers
- FullGraphMultiTaskNetwork, for the architecture
- BaseGraphStructure, the main class for the GNN layers
The parameters let you choose the feature sizes, the depth, the skip (residual) connections, the pooling, and whether to use a virtual node. The architecture also supports different GNN layers, such as GatedGCNPyg, GINConvPyg, GINEConvPyg, GPSLayerPyg and MPNNPlusPyg; the configuration below uses a GCN layer (layer_type: pyg:gcn).
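As an illustration of the pooling option, the graph_output_nn section uses pooling: [sum]. The minimal sketch below implements sum pooling in plain Python. The node embeddings are summed per graph, using a batch vector that assigns each node to its graph, following the PyTorch Geometric batching convention. The function name is hypothetical; Graphium's own pooling layers operate on tensors.

```python
def sum_pool(node_feats, batch):
    """Sum node feature vectors per graph; batch[i] is the graph index of node i."""
    num_graphs = max(batch) + 1
    dim = len(node_feats[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for feat, graph_idx in zip(node_feats, batch):
        for d, value in enumerate(feat):
            pooled[graph_idx][d] += value
    return pooled


# Two graphs in one batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1.
node_feats = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
batch = [0, 0, 1]
print(sum_pool(node_feats, batch))  # [[4.0, 6.0], [10.0, 20.0]]
```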
print_config_with_key(yaml_config, "architecture")
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 64
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.18
    normalization: layer_norm
    last_normalization: ${architecture.pre_nn.normalization}
    residual_type: none
  pre_nn_edges: null
  pe_encoders:
    out_dim: 32
    pool: sum
    last_norm: None
    encoders:
      la_pos:
        encoder_type: laplacian_pe
        input_keys:
        - laplacian_eigvec
        - laplacian_eigval
        output_keys:
        - feat
        hidden_dim: 64
        out_dim: 32
        model_type: DeepSet
        num_layers: 2
        num_layers_post: 1
        dropout: 0.1
        first_normalization: none
      rw_pos:
        encoder_type: mlp
        input_keys:
        - rw_return_probs
        output_keys:
        - feat
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
  gnn:
    in_dim: 64
    out_dim: 96
    hidden_dims: 96
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: ${architecture.pre_nn.normalization}
    residual_type: simple
    virtual_node: none
    layer_type: pyg:gcn
    layer_kwargs: null
  graph_output_nn:
    graph:
      pooling:
      - sum
      out_dim: 96
      hidden_dims: 96
      depth: 1
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
  task_heads:
    qm9:
      task_level: graph
      out_dim: 19
      hidden_dims: 128
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
    tox21:
      task_level: graph
      out_dim: 12
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
    zinc:
      task_level: graph
      out_dim: 3
      hidden_dims: 32
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
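In this configuration, each task head's out_dim equals the number of label_cols for that task in the datamodule section: 19 for qm9, 12 for tox21 and 3 for zinc. When you adapt the configs, a quick check like the sketch below can catch a mismatch before training. It works on plain dicts, so you would first convert the Hydra config, for example with OmegaConf.to_container(yaml_config). The function name is hypothetical, and the dicts are trimmed copies of the entries above.

```python
def check_task_head_dims(datamodule_cfg, architecture_cfg):
    """Return (task, out_dim, n_labels) for every head whose out_dim differs from its label count."""
    tasks = datamodule_cfg["args"]["task_specific_args"]
    mismatches = []
    for task, head in architecture_cfg["task_heads"].items():
        n_labels = len(tasks[task]["label_cols"])
        if head["out_dim"] != n_labels:
            mismatches.append((task, head["out_dim"], n_labels))
    return mismatches


# Trimmed copies of the tox21 and zinc entries above (only the keys the check reads).
datamodule_cfg = {"args": {"task_specific_args": {
    "tox21": {"label_cols": ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER",
                             "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5",
                             "SR-HSE", "SR-MMP", "SR-p53"]},
    "zinc": {"label_cols": ["SA", "logp", "score"]},
}}}
architecture_cfg = {"task_heads": {"tox21": {"out_dim": 12}, "zinc": {"out_dim": 3}}}
print(check_task_head_dims(datamodule_cfg, architecture_cfg))  # []
```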
Predictor
In the predictor, we define the loss function for each task, the metrics to show on the progress bar, and the parameters of the optimizer and the learning-rate scheduler.
print_config_with_key(yaml_config, "predictor")
predictor:
  metrics_on_progress_bar:
    qm9:
    - mae
    tox21:
    - auroc
    zinc:
    - mae
  loss_fun:
    qm9: mae_ipu
    tox21: bce_logits_ipu
    zinc: mae_ipu
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.0e-05
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: ${constants.max_epochs}
    warmup_epochs: 10
    verbose: false
  scheduler_kwargs: null
  target_nan_mask: null
  multitask_handling: flatten
  metrics_every_n_train_steps: 300
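The torch_scheduler_kwargs select WarmUpLinearLR with the following settings:
- 10 warm-up epochs
- max_num_epochs set to 100, through constants.max_epochs
- a base learning rate of 4e-5, taken from optim_kwargs

The sketch below shows a generic schedule of this kind: a linear warm-up followed by a linear decay. Some details of WarmUpLinearLR are assumptions here: its exact starting value, its end value (min_lr in the sketch), and whether it steps per epoch or per batch. Check the class before relying on specific numbers. The function name is hypothetical.

```python
def warmup_linear_lr(epoch, base_lr, warmup_epochs, max_num_epochs, min_lr=0.0):
    """Learning rate for a linear warm-up followed by a linear decay (generic sketch)."""
    if epoch < warmup_epochs:
        # Ramp linearly from 0 up to base_lr during the warm-up.
        return base_lr * epoch / warmup_epochs
    # Decay linearly from base_lr to min_lr between warmup_epochs and max_num_epochs,
    # then stay at min_lr.
    progress = min((epoch - warmup_epochs) / max(1, max_num_epochs - warmup_epochs), 1.0)
    return base_lr + (min_lr - base_lr) * progress


for epoch in (0, 5, 10, 55, 100):
    print(epoch, warmup_linear_lr(epoch, base_lr=4e-5, warmup_epochs=10, max_num_epochs=100))
```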
Metrics
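All the metrics to compute are defined here, as a list per task. For classification metrics, we can also define a threshold (threshold_kwargs) that binarizes the predictions and/or the targets before the metric is computed.
See the class graphium.trainer.metrics.MetricWrapper for more details.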
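The "f1 > 0.5" entry in the tox21 metrics uses threshold_kwargs with operator: greater and threshold: 0.5. The threshold is applied to both the predictions (th_on_preds) and the targets (th_on_target). The standard-library sketch below shows that thresholding step, followed by a plain binary F1 on the positive class.

This is not MetricWrapper itself:
- The mapping from "greater" to operator.gt is an assumption.
- The helper names and toy values are hypothetical.
- The configured average: micro aggregation may give a different number than this binary F1.

```python
import operator

# Hypothetical mapping from the config's operator name to a Python comparison.
OPERATORS = {"greater": operator.gt}


def apply_threshold(values, operator_name, threshold):
    """Binarize values: 1 where `value <op> threshold` holds, else 0."""
    op = OPERATORS[operator_name]
    return [1 if op(v, threshold) else 0 for v in values]


def binary_f1(preds, target):
    """F1 score of the positive class from 0/1 predictions and targets."""
    tp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, target) if p == 0 and t == 1)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


probs = [0.9, 0.4, 0.7, 0.2, 0.6]   # toy predicted probabilities
labels = [1.0, 1.0, 0.0, 0.0, 1.0]  # toy targets
pred_bin = apply_threshold(probs, "greater", 0.5)     # th_on_preds: true
target_bin = apply_threshold(labels, "greater", 0.5)  # th_on_target: true
print(pred_bin, target_bin, binary_f1(pred_bin, target_bin))
```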
print_config_with_key(yaml_config, "metrics")
metrics:
  qm9:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
  - name: r2_score
    metric: r2_score_ipu
    target_nan_mask: null
    multitask_handling: mean-per-label
    threshold_kwargs: null
  tox21:
  - name: auroc
    metric: auroc_ipu
    task: binary
    multitask_handling: mean-per-label
    threshold_kwargs: null
  - name: avpr
    metric: average_precision_ipu
    task: binary
    multitask_handling: mean-per-label
    threshold_kwargs: null
  - name: f1 > 0.5
    metric: f1
    multitask_handling: mean-per-label
    target_to_int: true
    num_classes: 2
    average: micro
    threshold_kwargs:
      operator: greater
      threshold: 0.5
      th_on_preds: true
      th_on_target: true
  - name: precision > 0.5
    metric: precision
    multitask_handling: mean-per-label
    average: micro
    threshold_kwargs:
      operator: greater
      threshold: 0.5
      th_on_preds: true
      th_on_target: true
  zinc:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
  - name: r2_score
    metric: r2_score_ipu
    target_nan_mask: null
    multitask_handling: mean-per-label
    threshold_kwargs: null
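The metrics above use two values of multitask_handling: the MAE entries use flatten, while pearsonr and r2_score use mean-per-label. The sketch below contrasts the two on toy data, using the Pearson correlation:
- flatten pools every (sample, label) entry into one vector before computing the metric once;
- mean-per-label computes the metric for each label column and averages the results.

This reading of the options is an assumption based on their names. See MetricWrapper for the exact behavior, including NaN handling. All helper names are hypothetical.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(sxx * syy)


def flatten_metric(metric, preds, target):
    """Pool all (sample, label) entries, then compute the metric once."""
    flat_p = [v for row in preds for v in row]
    flat_t = [v for row in target for v in row]
    return metric(flat_p, flat_t)


def mean_per_label_metric(metric, preds, target):
    """Compute the metric per label column, then average over labels."""
    n_labels = len(target[0])
    scores = [metric([row[j] for row in preds], [row[j] for row in target]) for j in range(n_labels)]
    return sum(scores) / n_labels


# Rows are samples; the two label columns have very different scales (toy values).
target = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
preds = [[1.0, 30.0], [2.0, 20.0], [3.0, 10.0]]
print(flatten_metric(pearson, preds, target))         # ~0.419
print(mean_per_label_metric(pearson, preds, target))  # 0.0
```

Here the first label is predicted perfectly (r = 1) and the second is anti-correlated (r = -1). Averaging per label exposes this, whereas flattening mixes the two scales into a single, misleadingly positive correlation.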
Trainer
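Finally, the trainer section defines the seed, the model checkpointing options (file name, directory, and whether to save the last checkpoint), and the trainer arguments, such as the numerical precision, the number of epochs, and how often to run validation. Early-stopping options such as the patience can also be set here, although this configuration does not use them.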
print_config_with_key(yaml_config, "trainer")
trainer:
  seed: ${constants.seed}
  model_checkpoint:
    filename: ${constants.name}
    save_last: true
    dirpath: models_checkpoints/neurips2023-small-gcn/
  trainer:
    precision: 32
    max_epochs: ${constants.max_epochs}
    min_epochs: 1
    check_val_every_n_epoch: 20
    accumulate_grad_batches: 1
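Two settings here have simple arithmetic consequences:
- With check_val_every_n_epoch: 20 and max_epochs: 100, validation runs every 20 epochs.
- With accumulate_grad_batches: 1, each optimizer step uses a single batch of batch_size_training molecules (200 in the datamodule section).

The small sketch below spells this out. It assumes a single device and validation at the end of every n-th epoch, counting epochs from 1. The helper names are hypothetical.

```python
def validation_epochs(max_epochs, check_val_every_n_epoch):
    """Epochs (counted from 1) at the end of which validation runs."""
    return [e for e in range(1, max_epochs + 1) if e % check_val_every_n_epoch == 0]


def samples_per_optimizer_step(batch_size, accumulate_grad_batches, num_devices=1):
    """Number of molecules contributing to each optimizer step."""
    return batch_size * accumulate_grad_batches * num_devices


print(validation_epochs(100, 20))          # [20, 40, 60, 80, 100]
print(samples_per_optimizer_step(200, 1))  # 200
```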
Training the model
Now that we have defined all the configurations, we can train the model. The steps are straightforward and are given below.
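First, make sure the dataset files are available: every df_path and splits_path in the configuration must point to an existing file. With the configuration above, qm9.csv.gz, Tox21-7k-12-labels.csv.gz and ZINC12k.csv.gz, together with their split files, must be in expts/data/neurips2023/small-dataset (constants.data_dir).
Similarly, with config_gps_10M_pcqm4m.yaml, you would need to download pcqm4mv2-20k.csv into the specified path graphium/data/PCQM4M/pcqm4mv2-20k.csv.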
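A quick pre-flight check can report a missing dataset file before training starts. The sketch below lists the file names from the df_path and splits_path entries of the datamodule config, with ${constants.data_dir} resolved by hand. It assumes you run it from the directory that the relative data_dir is resolved against, typically the repository root. The function name is hypothetical.

```python
from pathlib import Path


def missing_files(data_dir, filenames):
    """Return the entries of `filenames` that are not regular files inside `data_dir`."""
    root = Path(data_dir)
    return [name for name in filenames if not (root / name).is_file()]


# File names from the df_path and splits_path entries of the datamodule config.
required = [
    "qm9.csv.gz", "qm9_random_splits.pt",
    "Tox21-7k-12-labels.csv.gz", "Tox21_random_splits.pt",
    "ZINC12k.csv.gz", "ZINC12k_random_splits.pt",
]
missing = missing_files("expts/data/neurips2023/small-dataset", required)
if missing:
    print("Missing dataset files:", missing)
```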
After that, we can run training through the CLI:
graphium-train