Building and training a simple model from configurations¶
This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.
The workflow for testing your code on the entire pipeline is as follows:
- Select a subset of the available configs as a starting point.
- Create additional configs or modify the existing configs to suit your needs.
- Train or fine-tune a model with the graphium-train CLI.
Creating the yaml file¶
The first step is to create a YAML file containing all the required configurations; an example is given at graphium/expts/hydra-configs/main.yaml. We will go through each part of the configuration below. See also the accompanying README.
import yaml
import omegaconf
from hydra import compose, initialize
# Helper to pretty-print a single top-level section of the configuration
def print_config_with_key(config, key):
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))

# First, let's read the yaml configuration file
with initialize(version_base=None, config_path="../../../expts/hydra-configs"):
    yaml_config = compose(config_name="main")

print("Yaml file loaded")
Yaml file loaded
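Before looking at each section, it can be useful to list the top-level sections of the composed configuration; a minimal sketch using the yaml_config object loaded above:
# List the top-level sections of the composed configuration
print(list(yaml_config.keys()))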
Constants¶
First, we define the constants, such as the random seed and whether the model should raise an error or ignore it during training (raise_train_error).
print_config_with_key(yaml_config, "constants")
constants:
  name: neurips2023_small_data_gcn
  seed: 42
  max_epochs: 100
  data_dir: expts/data/neurips2023/small-dataset
  raise_train_error: true
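These constants are reused throughout the rest of the file through OmegaConf interpolations such as ${constants.seed}. As an illustrative check, you can read individual values or resolve the whole section:
# Values under `constants` are referenced elsewhere via ${constants.*} interpolations,
# so changing them here propagates to the rest of the configuration.
print(yaml_config["constants"]["seed"])
print(omegaconf.OmegaConf.to_container(yaml_config["constants"], resolve=True))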
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset paths, whether to cache the processed data, the SMILES and label columns used for training, the molecular featurization to apply, the train/val/test splits, and the batch size.
For more details, see the class MultitaskFromSmilesDataModule.
print_config_with_key(yaml_config, "datamodule")
datamodule:
module_type: MultitaskFromSmilesDataModule
args:
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 4
featurization_progress: true
featurization_backend: loky
processed_graph_data_path: ../datacache/neurips2023-small/
num_workers: 4
persistent_workers: false
featurization:
atom_property_list_onehot:
- atomic-number
- group
- period
- total-valence
atom_property_list_float:
- degree
- formal-charge
- radical-electron
- aromatic
- in-ring
edge_property_list:
- bond-type-onehot
- stereo
- in-ring
add_self_loop: false
explicit_H: false
use_bonds_weights: false
pos_encoding_as_features:
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: none
disconnected_comp: true
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: none
disconnected_comp: true
rw_pos:
pos_level: node
pos_type: rw_return_probs
ksteps: 16
task_specific_args:
qm9:
df: null
df_path: ${constants.data_dir}/qm9.csv.gz
smiles_col: smiles
label_cols:
- A
- B
- C
- mu
- alpha
- homo
- lumo
- gap
- r2
- zpve
- u0
- u298
- h298
- g298
- cv
- u0_atom
- u298_atom
- h298_atom
- g298_atom
splits_path: ${constants.data_dir}/qm9_random_splits.pt
seed: ${constants.seed}
task_level: graph
label_normalization:
normalize_val_test: true
method: normal
tox21:
df: null
df_path: ${constants.data_dir}/Tox21-7k-12-labels.csv.gz
smiles_col: smiles
label_cols:
- NR-AR
- NR-AR-LBD
- NR-AhR
- NR-Aromatase
- NR-ER
- NR-ER-LBD
- NR-PPAR-gamma
- SR-ARE
- SR-ATAD5
- SR-HSE
- SR-MMP
- SR-p53
splits_path: ${constants.data_dir}/Tox21_random_splits.pt
seed: ${constants.seed}
task_level: graph
zinc:
df: null
df_path: ${constants.data_dir}/ZINC12k.csv.gz
smiles_col: smiles
label_cols:
- SA
- logp
- score
splits_path: ${constants.data_dir}/ZINC12k_random_splits.pt
seed: ${constants.seed}
task_level: graph
label_normalization:
normalize_val_test: true
method: normal
batch_size_training: 200
batch_size_inference: 200
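These arguments can also be adjusted programmatically before training; a minimal sketch (with an illustrative value) that lowers the training batch size using OmegaConf:
# Illustrative: override an existing datamodule argument in-place before training.
omegaconf.OmegaConf.update(yaml_config, "datamodule.args.batch_size_training", 128)
print(yaml_config["datamodule"]["args"]["batch_size_training"])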
Architecture¶
The architecture is based on FullGraphMultiTaskNetwork.
Here, we define all the layers of the model, including the pre-processing MLPs (input layers pre_nn and pre_nn_edges), the positional encoders (pe_encoders), the post-processing MLPs (output layers graph_output_nn and task_heads), and the main GNN (graph neural network gnn).
You can find more details in the following:
- info about the positional encoders in graphium.nn.encoders
- info about the GNN layers in graphium.nn.pyg_layers
- info about the architecture in FullGraphMultiTaskNetwork
- the main class for the GNN layers, BaseGraphStructure
The parameters let you choose the feature sizes, the depth, the skip connections, the pooling, and the virtual node. The network also supports different GNN layers such as GatedGCNPyg, GINConvPyg, GINEConvPyg, GPSLayerPyg, and MPNNPlusPyg.
print_config_with_key(yaml_config, "architecture")
architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
pre_nn:
out_dim: 64
hidden_dims: 256
depth: 2
activation: relu
last_activation: none
dropout: 0.18
normalization: layer_norm
last_normalization: ${architecture.pre_nn.normalization}
residual_type: none
pre_nn_edges: null
pe_encoders:
out_dim: 32
pool: sum
last_norm: None
encoders:
la_pos:
encoder_type: laplacian_pe
input_keys:
- laplacian_eigvec
- laplacian_eigval
output_keys:
- feat
hidden_dim: 64
out_dim: 32
model_type: DeepSet
num_layers: 2
num_layers_post: 1
dropout: 0.1
first_normalization: none
rw_pos:
encoder_type: mlp
input_keys:
- rw_return_probs
output_keys:
- feat
hidden_dim: 64
out_dim: 32
num_layers: 2
dropout: 0.1
normalization: layer_norm
first_normalization: layer_norm
gnn:
in_dim: 64
out_dim: 96
hidden_dims: 96
depth: 4
activation: gelu
last_activation: none
dropout: 0.1
normalization: layer_norm
last_normalization: ${architecture.pre_nn.normalization}
residual_type: simple
virtual_node: none
layer_type: pyg:gcn
layer_kwargs: null
graph_output_nn:
graph:
pooling:
- sum
out_dim: 96
hidden_dims: 96
depth: 1
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: none
residual_type: none
task_heads:
qm9:
task_level: graph
out_dim: 19
hidden_dims: 128
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: none
residual_type: none
tox21:
task_level: graph
out_dim: 12
hidden_dims: 64
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: none
residual_type: none
zinc:
task_level: graph
out_dim: 3
hidden_dims: 32
depth: 2
activation: relu
last_activation: none
dropout: ${architecture.pre_nn.dropout}
normalization: ${architecture.pre_nn.normalization}
last_normalization: none
residual_type: none
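Note that each task head's out_dim must match the number of label columns declared for that task in the datamodule. A small illustrative check on the config loaded above:
# Illustrative: each task head's `out_dim` should equal the number of label columns
# declared for the corresponding task in the datamodule.
for task, head in yaml_config["architecture"]["task_heads"].items():
    n_labels = len(yaml_config["datamodule"]["args"]["task_specific_args"][task]["label_cols"])
    print(f"{task}: out_dim={head['out_dim']}, label columns={n_labels}")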
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
print_config_with_key(yaml_config, "predictor")
predictor:
metrics_on_progress_bar:
qm9:
- mae
tox21:
- auroc
zinc:
- mae
loss_fun:
qm9: mae_ipu
tox21: bce_logits_ipu
zinc: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.0e-05
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: ${constants.max_epochs}
warmup_epochs: 10
verbose: false
scheduler_kwargs: null
target_nan_mask: null
multitask_handling: flatten
metrics_every_n_train_steps: 300
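Any of these hyper-parameters can be overridden before training; a minimal sketch (with an illustrative value) that changes the learning rate:
# Illustrative: override the optimizer learning rate before training.
omegaconf.OmegaConf.update(yaml_config, "predictor.optim_kwargs.lr", 1.0e-4)
print(yaml_config["predictor"]["optim_kwargs"]["lr"])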
Metrics¶
All the metrics can be defined here. If we want to use a classification metric, we can also define a decision threshold.
See the class graphium.trainer.metrics.MetricWrapper for more details.
print_config_with_key(yaml_config, "metrics")
metrics:
qm9:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
task: binary
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
metric: f1
multitask_handling: mean-per-label
target_to_int: true
num_classes: 2
average: micro
threshold_kwargs:
operator: greater
threshold: 0.5
th_on_preds: true
th_on_target: true
- name: precision > 0.5
metric: precision
multitask_handling: mean-per-label
average: micro
threshold_kwargs:
operator: greater
threshold: 0.5
th_on_preds: true
th_on_target: true
zinc:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
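To see at a glance which metrics are tracked for each task, you can iterate over this section (illustrative):
# Illustrative: list the metric names defined for each task.
for task, task_metrics in yaml_config["metrics"].items():
    print(task, [m["name"] for m in task_metrics])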
Trainer¶
Finally, the trainer defines the training loop parameters such as the number of epochs, the numerical precision, the validation frequency, and the model checkpointing.
print_config_with_key(yaml_config, "trainer")
trainer:
seed: ${constants.seed}
model_checkpoint:
filename: ${constants.name}
save_last: true
dirpath: models_checkpoints/neurips2023-small-gcn/
trainer:
precision: 32
max_epochs: ${constants.max_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
accumulate_grad_batches: 1
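Since trainer.trainer.max_epochs interpolates ${constants.max_epochs}, shortening a debug run only requires changing the constant; a minimal sketch with an illustrative value:
# Illustrative: shorten a debug run by changing the constant that the trainer interpolates.
omegaconf.OmegaConf.update(yaml_config, "constants.max_epochs", 10)
print(omegaconf.OmegaConf.to_container(yaml_config["trainer"]["trainer"], resolve=True))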
Training the model¶
Now that we have defined all the configurations, we want to train the model. The steps are straightforward and are given below.
First, make sure the dataset file has been downloaded. Using config_gps_10M_pcqm4m.yaml as an example, check that the file specified by df_path in the config is available.
In this case, we need to download pcqm4mv2-20k.csv into the specified directory, graphium/data/PCQM4M/pcqm4mv2-20k.csv.
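The same applies to the neurips2023 configuration loaded in this tutorial; a small illustrative snippet (assuming yaml_config from earlier is still in scope) that checks whether every df_path referenced in the composed config exists locally:
import os

# Illustrative: verify that the dataset file of every task is present before training.
for task, task_args in yaml_config["datamodule"]["args"]["task_specific_args"].items():
    print(task, task_args["df_path"], os.path.isfile(task_args["df_path"]))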
After that, we can simply run training through the CLI:
graphium-train
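Because graphium-train is a Hydra-based entry point, individual values from the configuration can also be overridden directly on the command line. A minimal sketch, assuming the main.yaml configuration walked through above is the one being composed (the exact config groups are documented in the hydra-configs README); the override keys below are taken from this tutorial's config and the values are illustrative:
graphium-train constants.max_epochs=10 constants.seed=43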