Building and training a simple model from configurations¶
This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.
The workflow for testing your code on the entire pipeline is as follows:
- select a corresponding YAML file in expts/main_run_multitask.py, e.g. by setting
CONFIG_FILE = "expts/configs/config_gps_10M_pcqm4m_mod.yaml"
- modify the YAML config file
- run python expts/main_run_multitask.py
There are multiple examples of YAML files located in the folder graphium/expts/configs
that one can refer to when training a new model. The file config_gps_10M_pcqm4m_mod.yaml
shows an example of running the GPS model on the pcqm4m dataset.
Creating the yaml file¶
The first step is to create a YAML file containing all the required configurations, with an example given at graphium/expts/configs/config_gps_10M_pcqm4m_mod.yaml. We will go through each part of the configurations.
import yaml
import omegaconf

def print_config_with_key(config, key):
    # Pretty-print a single top-level section of the configuration
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the yaml configuration file
with open("../../../expts/configs/config_gps_10M_pcqm4m_mod.yaml", "r") as file:
    yaml_config = yaml.load(file, Loader=yaml.FullLoader)

print("Yaml file loaded")
Yaml file loaded
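As a quick sanity check, you can list the top-level sections of the loaded configuration; the sections walked through in the rest of this tutorial (constants, datamodule, architecture, predictor, metrics, trainer) should all appear. This is a minimal, optional inspection step, not part of the training script itself.
# Optional: inspect which top-level sections the config defines
print(list(yaml_config.keys()))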
Constants¶
First, we define the constants such as the random seed and whether the model should raise or ignore an error.
print_config_with_key(yaml_config, "constants")
constants:
  name: pcqm4mv2_mpnn_4layer
  seed: 42
  raise_train_error: true
  accelerator:
    type: gpu
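Since yaml_config is a plain Python dictionary at this point, any of these constants can be overridden in code before training. A minimal sketch using the keys shown above (the new values are purely illustrative):
# Hypothetical overrides of the constants section
yaml_config["constants"]["seed"] = 123                  # change the random seed
yaml_config["constants"]["raise_train_error"] = False   # ignore errors instead of raising them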
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns used for training, the molecular featurization to use, the train/val/test splits, and the batch size.
For more details, see the class MultitaskFromSmilesDataModule.
print_config_with_key(yaml_config, "datamodule")
datamodule:
  module_type: MultitaskFromSmilesDataModule
  args:
    task_specific_args:
      homolumo:
        df: null
        task_level: graph
        df_path: ~/scratch/data/graphium/data/PCQM4M/pcqm4mv2-20k.csv
        smiles_col: cxsmiles
        label_cols:
        - homo_lumo_gap
        split_val: 0.1
        split_test: 0.1
    prepare_dict_or_graph: pyg:graph
    featurization_n_jobs: 30
    featurization_progress: true
    featurization_backend: loky
    featurization:
      mask_nan: 0
      atom_property_list_onehot:
      - atomic-number
      - group
      - period
      - total-valence
      atom_property_list_float:
      - degree
      - formal-charge
      - radical-electron
      - aromatic
      - in-ring
      edge_property_list:
      - bond-type-onehot
      - stereo
      - in-ring
      conformer_property_list:
      - positions_3d
      add_self_loop: false
      explicit_H: false
      use_bonds_weights: false
      pos_encoding_as_features:
        pos_types:
          node_laplacian_eigvec:
            pos_type: laplacian_eigvec
            pos_level: node
            num_pos: 8
            normalization: none
            disconnected_comp: true
          node_laplacian_eigval:
            pos_type: laplacian_eigval
            pos_level: node
            num_pos: 8
            normalization: none
            disconnected_comp: true
          rw_return_probs:
            pos_type: rw_return_probs
            pos_level: node
            ksteps:
            - 4
            - 8
          nodepair_rw_transition_probs:
            pos_type: rw_transition_probs
            pos_level: edge
            ksteps:
            - 2
            - 4
          nodepair_rw_return_probs:
            pos_type: rw_return_probs
            pos_level: nodepair
            ksteps:
            - 4
          electrostatic:
            pos_type: electrostatic
            pos_level: node
          edge_commute:
            pos_type: commute
            pos_level: edge
          nodepair_graphormer:
            pos_type: graphormer
            pos_level: nodepair
    batch_size_training: 64
    batch_size_inference: 16
    num_workers: 0
    persistent_workers: false
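Before featurization starts, it can be useful to check that the CSV referenced by df_path actually contains the columns named in smiles_col and label_cols. A minimal sketch, assuming pandas is installed and the file has been downloaded to a local graphium/data/PCQM4M/ folder (adjust the path to wherever your copy of the CSV lives):
import pandas as pd

# Path is an assumption; use the df_path from your own config
df = pd.read_csv("graphium/data/PCQM4M/pcqm4mv2-20k.csv")
assert "cxsmiles" in df.columns        # smiles_col
assert "homo_lumo_gap" in df.columns   # label_cols
print(df[["cxsmiles", "homo_lumo_gap"]].head())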
Architecture¶
The architecture is based on FullGraphMultiTaskNetwork.
Here, we define all the layers of the model, including the pre-processing MLPs (input layers pre_nn and pre_nn_edges), the positional encoders (pe_encoders), the post-processing MLP (output layers post_nn), and the main GNN (graph neural network gnn).
You can find details in the following:
- info about the positional encoders in graphium.nn.encoders
- info about the GNN layers in graphium.nn.pyg_layers
- info about the architecture in FullGraphMultiTaskNetwork
- the main class for the GNN layers, BaseGraphStructure
The parameters allow you to choose the feature sizes, the depth, the skip connections, the pooling, and the virtual node. The network also supports different GNN layers such as GatedGCNPyg, GINConvPyg, GINEConvPyg, GPSLayerPyg, and MPNNPlusPyg.
print_config_with_key(yaml_config, "architecture")
architecture:
  model_type: FullGraphMultiTaskNetwork
  mup_base_path: null
  pre_nn:
    out_dim: 32
    hidden_dims: 64
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: none
  pre_nn_edges:
    out_dim: 16
    hidden_dims: 32
    depth: 2
    activation: relu
    last_activation: none
    dropout: 0.1
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: none
  pe_encoders:
    out_dim: 32
    edge_out_dim: 16
    pool: sum
    last_norm: None
    encoders:
      emb_la_pos:
        encoder_type: laplacian_pe
        input_keys:
        - laplacian_eigvec
        - laplacian_eigval
        output_keys:
        - feat
        hidden_dim: 32
        model_type: DeepSet
        num_layers: 2
        num_layers_post: 1
        dropout: 0.1
        first_normalization: none
      emb_rwse:
        encoder_type: mlp
        input_keys:
        - rw_return_probs
        output_keys:
        - feat
        hidden_dim: 32
        num_layers: 2
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
      emb_electrostatic:
        encoder_type: mlp
        input_keys:
        - electrostatic
        output_keys:
        - feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
        first_normalization: layer_norm
      emb_edge_rwse:
        encoder_type: mlp
        input_keys:
        - edge_rw_transition_probs
        output_keys:
        - edge_feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
      emb_edge_pes:
        encoder_type: cat_mlp
        input_keys:
        - edge_rw_transition_probs
        - edge_commute
        output_keys:
        - edge_feat
        hidden_dim: 32
        num_layers: 1
        dropout: 0.1
        normalization: layer_norm
      gaussian_pos:
        encoder_type: gaussian_kernel
        input_keys:
        - positions_3d
        output_keys:
        - feat
        - nodepair_gaussian_bias_3d
        num_heads: 2
        num_layers: 2
        embed_dim: 32
        use_input_keys_prefix: false
  gnn:
    out_dim: 32
    hidden_dims: 32
    depth: 4
    activation: gelu
    last_activation: none
    dropout: 0.0
    normalization: layer_norm
    last_normalization: layer_norm
    residual_type: simple
    pooling:
    - sum
    virtual_node: none
    layer_type: pyg:gps
    layer_kwargs:
      node_residual: false
      mpnn_type: pyg:mpnnplus
      mpnn_kwargs:
        in_dim: 32
        out_dim: 32
        in_dim_edges: 16
        out_dim_edges: 16
      attn_type: full-attention
      attn_kwargs:
        num_heads: 2
      biased_attention_key: nodepair_gaussian_bias_3d
  post_nn: null
  task_heads:
    homolumo:
      out_dim: 1
      hidden_dims: 256
      depth: 2
      activation: relu
      last_activation: none
      dropout: 0.1
      normalization: layer_norm
      last_normalization: none
      residual_type: none
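As with the other sections, the architecture can be tweaked directly on the loaded dictionary before the model is built. A hypothetical change that makes the GNN deeper and adds a bit of dropout (values chosen for illustration only):
# Hypothetical architecture tweaks
yaml_config["architecture"]["gnn"]["depth"] = 6
yaml_config["architecture"]["gnn"]["dropout"] = 0.1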
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
print_config_with_key(yaml_config, "predictor")
predictor:
  metrics_on_progress_bar:
    homolumo:
    - mae
    - pearsonr
  loss_fun:
    homolumo: mse_ipu
  random_seed: 42
  optim_kwargs:
    lr: 0.0004
  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: 5
    warmup_epochs: 10
    verbose: false
  scheduler_kwargs: null
  target_nan_mask: null
  flag_kwargs:
    n_steps: 0
    alpha: 0.0
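The optimizer and scheduler settings can be adjusted in the same way. A hypothetical example that lowers the learning rate and lengthens the scheduler for a longer run (values for illustration only):
# Hypothetical optimizer/scheduler tweaks
yaml_config["predictor"]["optim_kwargs"]["lr"] = 1.0e-4
yaml_config["predictor"]["torch_scheduler_kwargs"]["max_num_epochs"] = 20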
Metrics¶
All the metrics can be defined here. If we want to use a classification metric, we can also define a threshold.
See the class graphium.trainer.metrics.MetricWrapper for more details.
print_config_with_key(yaml_config, "metrics")
metrics:
  homolumo:
  - name: mae
    metric: mae_ipu
    target_nan_mask: null
    multitask_handling: flatten
    threshold_kwargs: null
  - name: pearsonr
    metric: pearsonr_ipu
    threshold_kwargs: null
    target_nan_mask: null
    multitask_handling: mean-per-label
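The two metrics above are for regression. For a classification task, a metric entry can additionally carry threshold_kwargs; the sketch below shows what such an entry could look like. The field names inside threshold_kwargs are assumptions based on graphium.trainer.metrics.MetricWrapper and should be checked against its documentation.
# Sketch of a classification metric entry with a threshold (field names are assumptions)
classification_metric = {
    "name": "accuracy",
    "metric": "accuracy",
    "threshold_kwargs": {
        "operator": "greater",   # apply "prediction > threshold"
        "threshold": 0.5,
        "th_on_preds": True,     # threshold the predictions
        "th_on_target": False,   # leave the targets untouched
    },
}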
Trainer¶
Finally, the trainer defines parameters such as the number of epochs to train, the checkpointing, and the patience.
print_config_with_key(yaml_config, "trainer")
trainer:
  logger:
    save_dir: logs/PCQMv2
    name: pcqm4mv2_mpnn_4layer
    project: PCQMv2_mpnn
  model_checkpoint:
    dirpath: models_checkpoints/PCMQv2/
    filename: pcqm4mv2_mpnn_4layer
    save_top_k: 1
    every_n_epochs: 100
  trainer:
    precision: 32
    max_epochs: 5
    min_epochs: 1
    accumulate_grad_batches: 2
    check_val_every_n_epoch: 20
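Note that with max_epochs: 5 and check_val_every_n_epoch: 20, validation would never run during this short example; for a quick experiment you may want to validate more often. A hypothetical adjustment:
# Hypothetical trainer tweaks: train a bit longer and validate every epoch
yaml_config["trainer"]["trainer"]["max_epochs"] = 10
yaml_config["trainer"]["trainer"]["check_val_every_n_epoch"] = 1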
Training the model¶
Now that we have defined all the configurations, we can train the model. The steps are fairly easy using the config loaders, and are given below.
First, make sure the dataset file is downloaded. Using config_gps_10M_pcqm4m.yaml as an example, check that the file referenced by df_path in the config is available; in this case, we need to download pcqm4mv2-20k.csv into the specified directory graphium/data/PCQM4M/pcqm4mv2-20k.csv.
Then, run the training script:
python expts/main_run_multitask.py
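If you edited yaml_config in the sections above, one simple way to persist those changes is to dump the dictionary back to a new YAML file and point CONFIG_FILE in expts/main_run_multitask.py at it. The output path below is only an example:
# Write the (possibly modified) configuration back to disk
with open("expts/configs/my_modified_config.yaml", "w") as f:
    yaml.dump(yaml_config, f, sort_keys=False)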