Building and training a simple model from configurations¶
This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.
The workflow for testing your code on the entire pipeline is as follows:
- select a corresponding YAML file in expts/main_run_multitask.py, e.g. by setting CONFIG_FILE = "expts/configs/config_gps_10M_pcqm4m_mod.yaml"
- modify the YAML config file as needed
- run python expts/main_run_multitask.py
There are multiple example YAML files located in the folder graphium/expts/configs that one can refer to when training a new model. The file config_gps_10M_pcqm4m_mod.yaml shows an example of running the GPS model on the PCQM4M dataset.
Creating the YAML file¶
The first step is to create a YAML file containing all the required configurations, with an example given at graphium/expts/configs/config_gps_10M_pcqm4m_mod.yaml. We will go through each part of the configuration.
import yaml
import omegaconf

def print_config_with_key(config, key):
    """Print a single top-level section of the config as YAML."""
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the yaml configuration file
with open("../../../expts/configs/config_gps_10M_pcqm4m_mod.yaml", "r") as file:
    yaml_config = yaml.load(file, Loader=yaml.FullLoader)
    print("Yaml file loaded")
Yaml file loaded
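As a small aside, since omegaconf is already imported, the same file can also be loaded directly with OmegaConf.load, which returns a DictConfig supporting attribute-style access. A minimal optional sketch, not required for the rest of the tutorial:

# Optional alternative: load the config directly with OmegaConf.
cfg = omegaconf.OmegaConf.load("../../../expts/configs/config_gps_10M_pcqm4m_mod.yaml")
print(cfg.constants.seed)  # attribute-style access; prints 42 for this config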
Constants¶
First, we define the constants, such as the random seed and whether training should raise or ignore errors.
print_config_with_key(yaml_config, "constants")
constants:
  name: pcqm4mv2_mpnn_4layer
  seed: 42
  raise_train_error: true
  accelerator:
    type: gpu
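These constants are consumed automatically by the training script, but you can also read them back from the loaded config. A minimal illustrative sketch (manual seeding like this is not something the pipeline requires; it just shows how to access the values):

import random
import numpy as np
import torch

# Read the seed from the loaded config and apply it manually.
seed = yaml_config["constants"]["seed"]  # 42 in this config
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)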
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns used for training, the molecular featurization to use, the train/val/test splits, and the batch size.
For more details, see class MultitaskFromSmilesDataModule
print_config_with_key(yaml_config, "datamodule")
datamodule:
  module_type: MultitaskFromSmilesDataModule
  args:
    task_specific_args:
      homolumo:
        df: null
        task_level: graph
        df_path: ~/scratch/data/graphium/data/PCQM4M/pcqm4mv2-20k.csv
        smiles_col: cxsmiles
        label_cols:
        - homo_lumo_gap
        split_val: 0.1
        split_test: 0.1
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 30
    featurization_progress: true
    featurization_backend: loky
    featurization:
      mask_nan: 0
      atom_property_list_onehot:
      - atomic-number
      - group
      - period
      - total-valence
      atom_property_list_float:
      - degree
      - formal-charge
      - radical-electron
      - aromatic
      - in-ring
      edge_property_list:
      - bond-type-onehot
      - stereo
      - in-ring
      conformer_property_list:
      - positions_3d
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false
      pos_encoding_as_features:
        pos_types:
          node_laplacian_eigvec:
            pos_type: laplacian_eigvec
            pos_level: node
            num_pos: 8
            normalization: none
            disconnected_comp: true
          node_laplacian_eigval:
            pos_type: laplacian_eigval
            pos_level: node
            num_pos: 8
            normalization: none
            disconnected_comp: true
          rw_return_probs:
            pos_type: rw_return_probs
            pos_level: node
            ksteps:
            - 4
            - 8
          nodepair_rw_transition_probs:
            pos_type: rw_transition_probs
            pos_level: edge
            ksteps:
            - 2
            - 4
          nodepair_rw_return_probs:
            pos_type: rw_return_probs
            pos_level: nodepair
            ksteps:
            - 4
          electrostatic:
            pos_type: electrostatic
            pos_level: node
          edge_commute:
            pos_type: commute
            pos_level: edge
          nodepair_graphormer:
            pos_type: graphormer
            pos_level: nodepair
    batch_size_training: 64
    batch_size_inference: 16
    num_workers: 0
    persistent_workers: false
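All of these values can also be overridden programmatically after loading, which is handy for quick experiments. A minimal sketch, editing the dict loaded above (the specific values are arbitrary examples):

# Override a few datamodule arguments on the loaded config.
dm_args = yaml_config["datamodule"]["args"]
dm_args["batch_size_training"] = 128   # larger training batches
dm_args["featurization_n_jobs"] = 4    # fewer parallel featurization jobs on a small machine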
Architecture¶
The architecture is based on FullGraphMultiTaskNetwork.
Here, we define all the layers of the model, including the pre-processing MLPs (input layers pre_nn and pre_nn_edges), the positional encoders (pe_encoders), the post-processing MLP (output layers post_nn), and the main GNN (graph neural network gnn).
You can find details in the following:
- info about the positional encoders in graphium.nn.encoders
- info about the GNN layers in graphium.nn.pyg_layers
- info about the architecture in FullGraphMultiTaskNetwork
- the main class for the GNN layers in BaseGraphStructure
The parameters let you choose the feature sizes, the depth, the skip connections, the pooling and the virtual node. The network also supports different GNN layers, such as GatedGCNPyg, GINConvPyg, GINEConvPyg, GPSLayerPyg and MPNNPlusPyg (see the sketch after the printed config below).
print_config_with_key(yaml_config, "architecture")
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 32
    hidden_dims: 64
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: none
  pre_nn_edges:
    out_dim: 16
    hidden_dims: 32
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: none
  pe_encoders:
    out_dim: 32
    edge_out_dim: 16
    pool: sum
    last_norm: None
    encoders:
      emb_la_pos:
        encoder_type: laplacian_pe
        input_keys:
        - laplacian_eigvec
        - laplacian_eigval
        output_keys:
        - feat
        hidden_dim: 32
        model_type: DeepSet
        num_layers: 2
        num_layers_post: 1
        dropout: 0.1
        first_normalization: none
      emb_rwse:
        encoder_type: mlp
        input_keys:
        - rw_return_probs
        output_keys:
        - feat
        hidden_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
      emb_electrostatic:
        encoder_type: mlp
        input_keys:
        - electrostatic
        output_keys:
        - feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
      emb_edge_rwse:
        encoder_type: mlp
        input_keys:
        - edge_rw_transition_probs
        output_keys:
        - edge_feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
      emb_edge_pes:
        encoder_type: cat_mlp
        input_keys:
        - edge_rw_transition_probs
        - edge_commute
        output_keys:
        - edge_feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
      gaussian_pos:
        encoder_type: gaussian_kernel
        input_keys:
        - positions_3d
        output_keys:
        - feat
        - nodepair_gaussian_bias_3d
        num_heads: 2
        num_layers: 2
        embed_dim: 32
        use_input_keys_prefix: false
  gnn:
    out_dim: 32
    hidden_dims: 32
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.0
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: simple
    pooling:
    - sum
    virtual_node: none
    layer_type: pyg:gps
    layer_kwargs:
      node_residual: false
      mpnn_type: pyg:mpnnplus
      mpnn_kwargs:
        in_dim: 32
        out_dim: 32
        in_dim_edges: 16
        out_dim_edges: 16
      attn_type: full-attention
      attn_kwargs:
        num_heads: 2
      biased_attention_key: nodepair_gaussian_bias_3d
  post_nn: null
  task_heads:
    homolumo:
      out_dim: 1
      hidden_dims: 256
      depth: 2
      activation: relu
      last_activation: none
      dropout: 0.1
      normalization: layer_norm
      last_normalization: none
      residual_type: none
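As mentioned above, the gnn section accepts several layer types. A minimal sketch of swapping the GPS layer for a GINE layer by editing the loaded config; the registry key pyg:gine is our assumption for GINEConvPyg (the config itself uses pyg:gps and pyg:mpnnplus, so the naming pattern is consistent, but check graphium.nn.pyg_layers for the exact keys in your version):

# Swap the GPS layer for a GINE layer on the loaded config.
gnn_cfg = yaml_config["architecture"]["gnn"]
gnn_cfg["layer_type"] = "pyg:gine"  # assumed registry key for GINEConvPyg
gnn_cfg["layer_kwargs"] = None      # the GPS-specific kwargs no longer apply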
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
print_config_with_key(yaml_config, "predictor")
predictor:
  metrics_on_progress_bar:
    homolumo:
    - mae
    - pearsonr
  loss_fun:
    homolumo: mse_ipu
  random_seed: 42
  optim_kwargs:
    lr: 0.0004
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: 5
    warmup_epochs: 10
    verbose: false
  scheduler_kwargs: null
  target_nan_mask: null
  flag_kwargs:
    n_steps: 0
    alpha: 0.0
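Note that in the printed config warmup_epochs (10) is larger than max_num_epochs (5), so a run this short would spend all its epochs in warm-up. A minimal sketch of adjusting these values on the loaded config (the new numbers are arbitrary examples):

# Tweak the optimizer and scheduler on the loaded config.
pred_cfg = yaml_config["predictor"]
pred_cfg["optim_kwargs"]["lr"] = 1.0e-4                  # smaller learning rate
pred_cfg["torch_scheduler_kwargs"]["warmup_epochs"] = 2  # warm-up shorter than the full run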
Metrics¶
All the metrics can be defined here. If we want to use a classification metric, we can also define a threshold.
See class graphium.trainer.metrics.MetricWrapper for more details.
print_config_with_key(yaml_config, "metrics")
metrics:
  homolumo:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
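For a classification task, threshold_kwargs replaces the null above with thresholding parameters. A hypothetical sketch of what such an entry could look like; the task name and the exact threshold keys are our assumptions, so check graphium.trainer.metrics.MetricWrapper for the arguments supported in your version:

# Hypothetical thresholded classification metric entry.
clf_metric = {
    "name": "f1",
    "metric": "f1",
    "target_nan_mask": None,
    "multitask_handling": "flatten",
    "threshold_kwargs": {      # assumed thresholding arguments
        "operator": "greater",
        "threshold": 0.5,
        "th_on_preds": True,
        "th_on_target": True,
    },
}
yaml_config["metrics"].setdefault("my_clf_task", []).append(clf_metric)  # "my_clf_task" is hypothetical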
Trainer¶
Finally, the trainer defines the parameters for the number of epochs to train, the checkpointing, and the logging.
print_config_with_key(yaml_config, "trainer")
trainer:
  logger:
    save_dir: logs/PCQMv2
    name: pcqm4mv2_mpnn_4layer
    project: PCQMv2_mpnn
  model_checkpoint:
    dirpath: models_checkpoints/PCMQv2/
    filename: pcqm4mv2_mpnn_4layer
    save_top_k: 1
    every_n_epochs: 100
  trainer:
    precision: 32
    max_epochs: 5
    min_epochs: 1
    accumulate_grad_batches: 2
    check_val_every_n_epoch: 20
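For a quick smoke test on a small machine, you may want to shrink the run. A minimal sketch of overriding the trainer settings on the loaded config:

# Shrink the run for a quick local smoke test.
trainer_cfg = yaml_config["trainer"]["trainer"]
trainer_cfg["max_epochs"] = 2
trainer_cfg["check_val_every_n_epoch"] = 1  # validate every epoch instead of every 20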
Training the model¶
Now that we have defined all the configurations, we can train the model. The steps are fairly simple using the config loaders, and are given below.
First, make sure the dataset file is downloaded. Using config_gps_10M_pcqm4m.yaml as an example, check that the file referenced by df_path in the config is available; in this case, we need to download pcqm4mv2-20k.csv into the specified directory graphium/data/PCQM4M/pcqm4mv2-20k.csv.
$ python expts/main_run_multitask.py
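If you have been editing yaml_config interactively as in the sketches above, remember that the script reads the file named by its CONFIG_FILE constant. A minimal sketch of persisting your edits to a new file that you can then point CONFIG_FILE at (the output filename is hypothetical):

# Persist the edited config so the training script can pick it up.
import yaml

out_path = "expts/configs/config_gps_10M_pcqm4m_local.yaml"  # hypothetical filename
with open(out_path, "w") as f:
    yaml.safe_dump(yaml_config, f, sort_keys=False)
print(f"Wrote modified config to {out_path}")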