Building and training a simple model from configurations¶
This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.
The workflow for testing your code on the entire pipeline is as follows:
- Select a subset of the available configs as a starting point.
- Create additional configs or modify the existing configs to suit your needs.
- Train or fine-tune a model with the graphium-train CLI.
Creating the YAML file¶
The first step is to create a YAML file containing all the required configurations; an example is given at graphium/expts/hydra-configs/main.yaml. We will go through each part of the configuration. See also the README here.
import yaml
import omegaconf
from hydra import compose, initialize
def print_config_with_key(config, key):
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the yaml configuration file
with initialize(version_base=None, config_path="../../../expts/hydra-configs"):
    yaml_config = compose(config_name="main")
print("Yaml file loaded")
Yaml file loaded
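Hydra's compose API also accepts overrides, which is handy for quick experiments without editing the YAML files. Below is a minimal sketch (assuming the same config_path as above; the overridden keys are just examples):
# Compose the same config, but override a couple of values from Python.
with initialize(version_base=None, config_path="../../../expts/hydra-configs"):
    overridden_config = compose(
        config_name="main",
        overrides=["constants.seed=123", "constants.max_epochs=10"],
    )
print(overridden_config.constants.seed, overridden_config.constants.max_epochs)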
Constants¶
First, we define the constants, such as the run name, the random seed, the number of epochs, the data directory, and whether an error during training should be raised or ignored.
print_config_with_key(yaml_config, "constants")
constants:
  name: neurips2023_small_data_gcn
  seed: 42
  max_epochs: 100
  data_dir: expts/data/neurips2023/small-dataset
  raise_train_error: true
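Other parts of the configuration refer to these constants through OmegaConf interpolations such as ${constants.seed}. A quick way to check that they resolve as expected (plain OmegaConf, nothing graphium-specific):
# Resolve all interpolations and spot-check a couple of them.
resolved = omegaconf.OmegaConf.to_container(yaml_config, resolve=True)
print(resolved["predictor"]["random_seed"])          # 42, from ${constants.seed}
print(resolved["trainer"]["trainer"]["max_epochs"])  # 100, from ${constants.max_epochs}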
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset paths, whether to cache the processed graphs, the columns used for training, the molecular featurization to use, the train/val/test splits, and the batch size.
For more details, see the class MultitaskFromSmilesDataModule.
print_config_with_key(yaml_config, "datamodule")
datamodule:
  module_type: MultitaskFromSmilesDataModule
  args:
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 4
    featurization_progress: true
    featurization_backend: loky
    processed_graph_data_path: ../datacache/neurips2023-small/
    num_workers: 4
    persistent_workers: false
    featurization:
      atom_property_list_onehot:
      - atomic-number
      - group
      - period
      - total-valence
      atom_property_list_float:
      - degree
      - formal-charge
      - radical-electron
      - aromatic
      - in-ring
      edge_property_list:
      - bond-type-onehot
      - stereo
      - in-ring
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false
      pos_encoding_as_features:
        pos_types:
          lap_eigvec:
            pos_level: node
            pos_type: laplacian_eigvec
            num_pos: 8
            normalization: none
            disconnected_comp: true
          lap_eigval:
            pos_level: node
            pos_type: laplacian_eigval
            num_pos: 8
            normalization: none
            disconnected_comp: true
          rw_pos:
            pos_level: node
            pos_type: rw_return_probs
            ksteps: 16
    task_specific_args:
      qm9:
        df: null
        df_path: ${constants.data_dir}/qm9.csv.gz
        smiles_col: smiles
        label_cols:
        - A
        - B
        - C
        - mu
        - alpha
        - homo
        - lumo
        - gap
        - r2
        - zpve
        - u0
        - u298
        - h298
        - g298
        - cv
        - u0_atom
        - u298_atom
        - h298_atom
        - g298_atom
        splits_path: ${constants.data_dir}/qm9_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
        label_normalization:
          normalize_val_test: true
          method: normal
      tox21:
        df: null
        df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
        smiles_col: smiles
        label_cols:
        - NR-AR
        - NR-AR-LBD
        - NR-AhR
        - NR-Aromatase
        - NR-ER
        - NR-ER-LBD
        - NR-PPAR-gamma
        - SR-ARE
        - SR-ATAD5
        - SR-HSE
        - SR-MMP
        - SR-p53
        splits_path: ${constants.data_dir}/Tox21_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
      zinc:
        df: null
        df_path: ${constants.data_dir}/ZINC12k.csv.gz
        smiles_col: smiles
        label_cols:
        - SA
        - logp
        - score
        splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt
        seed: ${constants.seed}
        task_level: graph
        label_normalization:
          normalize_val_test: true
          method: normal
    batch_size_training: 200
    batch_size_inference: 200
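Since the composed config is just a nested mapping, it is easy to sanity-check the multi-task setup before launching a long featurization job. A small sketch that lists each task with its number of labels:
# List the tasks defined under task_specific_args and how many label columns each one has.
task_args = yaml_config["datamodule"]["args"]["task_specific_args"]
for task_name, task_cfg in task_args.items():
    print(task_name, len(task_cfg["label_cols"]), "labels, from", task_cfg["df_path"])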
Architecture¶
The architecture is based on FullGraphMultiTaskNetwork.
Here, we define all the layers of the model, including the pre-processing MLPs (input layers pre_nn and pre_nn_edges), the positional encoders (pe_encoders), the post-processing MLPs (output layers graph_output_nn), and the main GNN (graph neural network gnn).
You can find more details in the following:
- Info about the positional encoders in graphium.nn.encoders
- Info about the GNN layers in graphium.nn.pyg_layers
- Info about the architecture in FullGraphMultiTaskNetwork
- Main class for the GNN layers: BaseGraphStructure
The parameters let you choose the feature sizes, the depth, the skip connections, the pooling, and the virtual node. Different GNN layers are also supported, such as GatedGCNPyg, GINConvPyg, GINEConvPyg, GPSLayerPyg, and MPNNPlusPyg.
print_config_with_key(yaml_config, "architecture")
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 64
    hidden_dims: 256
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.18
    normalization: layer_norm
    last_normalization: ${architecture.pre_nn.normalization}
    residual_type: none
  pre_nn_edges: null
  pe_encoders:
    out_dim: 32
    pool: sum
    last_norm: None
    encoders:
      la_pos:
        encoder_type: laplacian_pe
        input_keys:
        - laplacian_eigvec
        - laplacian_eigval
        output_keys:
        - feat
        hidden_dim: 64
        out_dim: 32
        model_type: DeepSet
        num_layers: 2
        num_layers_post: 1
        dropout: 0.1
        first_normalization: none
      rw_pos:
        encoder_type: mlp
        input_keys:
        - rw_return_probs
        output_keys:
        - feat
        hidden_dim: 64
        out_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
  gnn:
    in_dim: 64
    out_dim: 96
    hidden_dims: 96
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: ${architecture.pre_nn.normalization}
    residual_type: simple
    virtual_node: none
    layer_type: pyg:gcn
    layer_kwargs: null
  graph_output_nn:
    graph:
      pooling:
      - sum
      out_dim: 96
      hidden_dims: 96
      depth: 1
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
  task_heads:
    qm9:
      task_level: graph
      out_dim: 19
      hidden_dims: 128
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
    tox21:
      task_level: graph
      out_dim: 12
      hidden_dims: 64
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
    zinc:
      task_level: graph
      out_dim: 3
      hidden_dims: 32
      depth: 2
      activation: relu
      last_activation: none
      dropout: ${architecture.pre_nn.dropout}
      normalization: ${architecture.pre_nn.normalization}
      last_normalization: none
      residual_type: none
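To experiment with a different GNN, override the relevant keys instead of editing main.yaml. A sketch using Hydra overrides; note that the layer alias string (e.g. pyg:gin) is an assumption here and should be checked against the layer types supported by your graphium version:
# Compose the config again, swapping the GNN layer type and making the network deeper.
with initialize(version_base=None, config_path="../../../expts/hydra-configs"):
    deeper_config = compose(
        config_name="main",
        overrides=["architecture.gnn.layer_type=pyg:gin", "architecture.gnn.depth=6"],
    )
print(deeper_config.architecture.gnn.layer_type, deeper_config.architecture.gnn.depth)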
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
print_config_with_key(yaml_config, "predictor")
predictor:
  metrics_on_progress_bar:
    qm9:
    - mae
    tox21:
    - auroc
    zinc:
    - mae
  loss_fun:
    qm9: mae_ipu
    tox21: bce_logits_ipu
    zinc: mae_ipu
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.0e-05
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: ${constants.max_epochs}
    warmup_epochs: 10
    verbose: false
  scheduler_kwargs: null
  target_nan_mask: null
  multitask_handling: flatten
  metrics_every_n_train_steps: 300
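Existing keys of the composed config can also be modified in place before the model is built, which is convenient for quick hyper-parameter tweaks (a minimal sketch; the values below are arbitrary):
# Bump the learning rate and shorten the warm-up. Note that this mutates yaml_config in place.
yaml_config.predictor.optim_kwargs.lr = 1.0e-4
yaml_config.predictor.torch_scheduler_kwargs.warmup_epochs = 5
print(omegaconf.OmegaConf.to_yaml(yaml_config.predictor.optim_kwargs))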
Metrics¶
All the metrics can be defined here. If we want to use a classification metric, we can also define a threshold.
See the class graphium.trainer.metrics.MetricWrapper for more details.
print_config_with_key(yaml_config, "metrics")
metrics:
  qm9:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
  - name: r2_score
    metric: r2_score_ipu
    target_nan_mask: null
    multitask_handling: mean-per-label
    threshold_kwargs: null
  tox21:
  - name: auroc
    metric: auroc_ipu
    task: binary
    multitask_handling: mean-per-label
    threshold_kwargs: null
  - name: avpr
    metric: average_precision_ipu
    task: binary
    multitask_handling: mean-per-label
    threshold_kwargs: null
  - name: f1 > 0.5
    metric: f1
    multitask_handling: mean-per-label
    target_to_int: true
    num_classes: 2
    average: micro
    threshold_kwargs:
      operator: greater
      threshold: 0.5
      th_on_preds: true
      th_on_target: true
  - name: precision > 0.5
    metric: precision
    multitask_handling: mean-per-label
    average: micro
    threshold_kwargs:
      operator: greater
      threshold: 0.5
      th_on_preds: true
      th_on_target: true
  zinc:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
  - name: r2_score
    metric: r2_score_ipu
    target_nan_mask: null
    multitask_handling: mean-per-label
    threshold_kwargs: null
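To see at a glance which metrics will be logged for each task, and which ones apply a threshold, you can iterate over this block (a small sketch using the config loaded above):
# Summarize the metrics per task and flag the thresholded ones.
for task_name, task_metrics in yaml_config["metrics"].items():
    names = [m["name"] for m in task_metrics]
    thresholded = [m["name"] for m in task_metrics if m["threshold_kwargs"] is not None]
    print(task_name, "->", names, "| thresholded:", thresholded)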
Trainer¶
Finally, the trainer defines the parameters of the training loop, such as the maximum number of epochs, the checkpointing, and the validation frequency.
print_config_with_key(yaml_config, "trainer")
trainer:
  seed: ${constants.seed}
  model_checkpoint:
    filename: ${constants.name}
    save_last: true
    dirpath: models_checkpoints/neurips2023-small-gcn/
  trainer:
    precision: 32
    max_epochs: ${constants.max_epochs}
    min_epochs: 1
    check_val_every_n_epoch: 20
    accumulate_grad_batches: 1
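For intuition, the model_checkpoint and inner trainer blocks map closely onto standard PyTorch Lightning objects. Graphium's own loaders build these for you; the sketch below (assuming pytorch_lightning is installed and importable under that name) only illustrates the correspondence:
# Rough correspondence only -- graphium's loaders handle this internally.
import pytorch_lightning as pl

checkpoint_cb = pl.callbacks.ModelCheckpoint(
    filename="neurips2023_small_data_gcn",
    save_last=True,
    dirpath="models_checkpoints/neurips2023-small-gcn/",
)
lightning_trainer = pl.Trainer(
    precision=32,
    max_epochs=100,
    min_epochs=1,
    check_val_every_n_epoch=20,
    accumulate_grad_batches=1,
    callbacks=[checkpoint_cb],
)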
Training the model¶
Now that we have defined all the configurations, we can train the model. The steps are straightforward using the config loaders and are given below.
First, make sure the dataset file is downloaded. Using config_gps_10M_pcqm4m.yaml as an example, make sure the file specified by df_path in the config is available. In this case, we need to download pcqm4mv2-20k.csv into the specified directory graphium/data/PCQM4M/pcqm4mv2-20k.csv.
After that, we can simply run the training through the CLI:
graphium-train
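Since graphium-train is built on Hydra, the same override syntax used with compose above should also work on the command line. For example (the keys shown are illustrative; adapt them to the config you are running):
graphium-train constants.max_epochs=10 trainer.trainer.check_val_every_n_epoch=1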