Copyright (c) 2023 Graphcore Ltd. All rights reserved.
Multitask Molecular Modelling with Graphium on the IPU
Graphium is a high-performance deep learning library specializing in graph representation learning for diverse chemical tasks. Graphium integrates state-of-the-art Graph Neural Network (GNN) architectures and a user-friendly API, enabling the easy construction and training of custom GNN models.
Graphium's distinctive feature is its robust featurization capabilities, enabling the extraction of comprehensive features from molecular structures for a range of applications, including property prediction, virtual screening, and drug discovery.
Graphium provides efficient tools and models for a wide array of chemical tasks, such as property prediction, toxicity evaluation, molecular generation and graph regression, whether you're working with simple molecules or intricate chemical graphs.
In this notebook, we'll illustrate Graphium's capabilities using the ToyMix dataset, an amalgamation of the QM9, Tox21, and ZINC12k datasets. We'll delve into the multitask scenario and display how Graphium trains models to concurrently predict multiple properties with chemical awareness.
Summary table
Domain | Tasks | Model | Datasets | Workflow | Number of IPUs | Execution time |
---|---|---|---|---|---|---|
Molecules | Multitask | GCN/GIN/GINE | QM9, Zinc, Tox21 | Training, Validation, Inference | recommended: 4x (min: 1x, max: 16x) | 20 mins |
Learning outcomes
In this notebook you will learn how to:
- Run multitask training for a variety of GNN models using the Graphium library on IPUs.
- Compare the results with single-task performance.
- Modify the Graphium config files.
Links to other resources
For this notebook, it is best if you are already familiar with GNNs and PyTorch Geometric. If you need a quick introduction, we suggest you refer to the Graphcore tutorials for using PyTorch Geometric on the IPU, which cover the basics of GNNs and PyTorch Geometric while introducing you to running models on the IPU.
Dependencies
We install the following dependencies for this notebook:
(This may take 5 - 10 minutes the first time.)
!pip install git+https://github.com/datamol-io/graphium.git
!pip install -r requirements.txt
from examples_utils import notebook_logging
%load_ext gc_logger
Running on Paperspace
The Paperspace environment lets you run this notebook with no setup. If a problem persists or you want to give us feedback on the content of this notebook, please reach out through our community of developers using Slack or raise a GitHub issue.
Requirements:
- Python packages installed with
pip install -r ./requirements.txt
# General imports
import argparse
import os
from os.path import dirname, abspath
from loguru import logger
from datetime import datetime
from lightning.pytorch.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import (
load_datamodule,
load_metrics,
load_architecture,
load_predictor,
load_trainer,
load_accelerator,
load_yaml_config,
)
from graphium.utils.safe_run import SafeRun
from graphium.utils.command_line_utils import update_config, get_anchors_and_aliases
import poptorch
from tqdm.auto import tqdm
# WandB
import wandb
# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
poptorch.setLogLevel("ERR")
Dataset: ToyMix Benchmark
The ToyMix dataset is an amalgamation of three widely recognized datasets in the field of machine learning for chemistry:
- QM9
- Tox21
- ZINC12k
Each dataset contributes a different aspect, and together they form a comprehensive benchmark.
QM9, well known in the field of 3D GNNs, provides 19 graph-level quantum properties associated with the energy-minimized 3D conformation of each molecule. Because these molecules have at most nine heavy atoms, the dataset is simple and enables fast model iteration, acting as a small-scale analogue of the larger proposed quantum datasets.
The Tox21 dataset is notable among researchers in machine learning for drug discovery. It is a multi-label classification task with 12 labels, exhibiting a high degree of missing labels and imbalance towards the negative class. This mirrors the characteristics of larger datasets, specifically in terms of sparsity and imbalance, making it invaluable for learning to handle real-world, skewed datasets.
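The sparsity and imbalance described above can be quantified per label. The sketch below is a minimal numpy illustration, assuming the labels sit in a (molecules × labels) array with NaN marking missing entries; the toy array is invented rather than taken from Tox21, and the helper name label_statistics is hypothetical. The negative-to-positive ratio it returns is a common choice of positive-class weight for binary cross-entropy; whether Graphium applies such a weighting is not covered here.

```python
import numpy as np

def label_statistics(labels: np.ndarray):
    """Return (missing fraction, negative/positive ratio) for each label column.

    labels: shape (n_molecules, n_labels), values 0.0, 1.0 or NaN (missing).
    """
    missing_fraction = np.isnan(labels).mean(axis=0)
    # Comparisons with NaN are False, so missing entries are never counted
    positives = np.sum(labels == 1.0, axis=0)
    negatives = np.sum(labels == 0.0, axis=0)
    # Guard against division by zero for labels with no positive examples
    neg_pos_ratio = negatives / np.maximum(positives, 1)
    return missing_fraction, neg_pos_ratio

# Toy example (invented): 4 molecules, 2 binary labels, some labels missing
toy_labels = np.array([
    [0.0, np.nan],
    [0.0, 1.0],
    [1.0, 0.0],
    [0.0, np.nan],
])
missing_fraction, neg_pos_ratio = label_statistics(toy_labels)
print(missing_fraction)
print(neg_pos_ratio)
```

For the toy array, the first label has no missing values and three negatives per positive, while the second label is half missing and balanced among its observed values.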
Finally, the ZINC12K dataset is recognized in the study of GNN expressivity, an essential feature for large-scale data performance. It's included in ToyMix with the expectation that successful performance on this task will correlate well with success on larger datasets.
Multitask learning
In multi-task learning, related tasks share information with each other during training, leading to a shared representation. This can often improve the model's performance on individual tasks, as information learned from one task can assist with others. Here the ToyMix dataset provides a sandbox to develop these models for molecular tasks.
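To make the idea of a shared representation concrete, the numpy sketch below applies one shared "backbone" transformation, then a separate linear head per task, and sums per-task losses that skip missing (NaN) targets. Everything here (the random weights, the task names, and the helpers masked_mse and multitask_loss) is hypothetical and only illustrates the general technique; Graphium's own heads and losses are set in its YAML configs. Only mean squared error is used for brevity; a classification task such as Tox21 would use a masked binary cross-entropy instead.

```python
import numpy as np

def masked_mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error over observed (non-NaN) targets only."""
    mask = ~np.isnan(target)
    if not mask.any():
        return 0.0
    diff = pred[mask] - target[mask]
    return float(np.mean(diff ** 2))

def multitask_loss(preds: dict, targets: dict, task_weights: dict) -> float:
    """Weighted sum of the per-task masked losses."""
    return sum(task_weights[t] * masked_mse(preds[t], targets[t]) for t in preds)

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))          # 3 molecules, 4 input features
w_shared = rng.normal(size=(4, 8))   # shared backbone weights (hypothetical)
h = np.tanh(x @ w_shared)            # shared representation used by every task
heads = {"qm9": rng.normal(size=(8, 2)), "zinc": rng.normal(size=(8, 1))}
preds = {task: h @ w for task, w in heads.items()}   # one head per task
targets = {
    "qm9": np.array([[0.1, 0.2], [np.nan, 0.5], [0.3, np.nan]]),
    "zinc": np.array([[1.0], [np.nan], [0.0]]),
}
loss = multitask_loss(preds, targets, {"qm9": 1.0, "zinc": 1.0})
```

Because the gradients of every task's loss flow back through w_shared, each task shapes the representation that the others use.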
Download the dataset and splits
The ToyMix dataset requires downloading the three datasets. Here we will create the expected data directory and download both the raw dataset CSV files, and the split files to specify training, validation and test splits.
# For compatibility with the Paperspace environment, the data is stored under $DATASETS_DIR
dataset_directory = os.getenv("DATASETS_DIR")
if not os.path.exists(dataset_directory + "/data/neurips2023/small-dataset"):
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt
    !wget -P $DATASETS_DIR/data/neurips2023/small-dataset/ https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt
    print("Datasets have been successfully downloaded.")
else:
    print("Datasets are already downloaded.")
from IPython.display import Image, display

display(Image(filename="ToyMix.png", width=600))
The representation of the datasets above shows that QM9 and ZINC are densely populated regression datasets, with every label provided for each molecule, while Tox21 is sparsely populated with binary labels. The relative sizes of these datasets are important: QM9 is approximately 10x larger than the other datasets, so we would expect it to dominate training.
It is interesting to note how different the datasets are: the molecules come from fairly different domains in terms of size and composition. Additionally, quantum properties calculated using density functional theory (DFT) are a very different type of property from results obtained in a wet lab, particularly in the context of complex systems.
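Because QM9 is roughly 10x larger than the other datasets, it supplies most of the molecules seen in each epoch. The sketch below makes that arithmetic explicit using illustrative round numbers that follow the 10x ratio, not the exact ToyMix sizes. It also shows one common mitigation, inverse-share loss weights, purely as a hypothetical option; the notebook's configs are not claimed to do this, and both helper names are made up.

```python
def dataset_shares(sizes: dict) -> dict:
    """Fraction of all training molecules contributed by each dataset."""
    total = sum(sizes.values())
    return {name: n / total for name, n in sizes.items()}

def inverse_share_weights(shares: dict) -> dict:
    """Loss weights w_t = 1 / (K * share_t), so share_t * w_t = 1 / K for every task."""
    k = len(shares)
    return {name: 1.0 / (k * s) for name, s in shares.items()}

# Illustrative sizes only (QM9 ~10x larger than the others), not the real counts
illustrative_sizes = {"qm9": 100_000, "tox21": 10_000, "zinc": 10_000}
shares = dataset_shares(illustrative_sizes)
weights = inverse_share_weights(shares)
print(shares)
print(weights)
```

With these numbers QM9 contributes five sixths of the molecules, and the inverse-share weights (0.4 for QM9, 4.0 for each smaller dataset) would give every task an equal expected contribution to the total loss.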
To get a sense of the dataset overlap, a Uniform Manifold Approximation and Projection (UMAP) of the molecules is shown below. This 2D projection of molecular similarity shows the purple QM9 molecules as largely separate from the Tox21 and ZINC molecules.
The question now is: how much can the large dataset of quantum properties be used to improve the prediction power of smaller datasets?
This approach of using a larger dataset to improve predictions for smaller datasets has yielded impressive results in the world of large language models, but it is unknown whether the same can be done with molecules.
display(Image(filename="UMAP.png", width=600))
Graphium model configuration
The Graphium model is made up of the following components:
- Processing input features of the molecules
- A set of encoders used for positional and structural features
- Featurization to produce a PyG Graph dataset
- A stack of GNN layers - the backbone of the model - made up of MPNN and transformer layers
- The output node, graph and edge embeddings. These are processed with MLPs to yield multitask predictions.
This is shown in the flowchart below.
display(Image(filename="flowchart.png", width=900))
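The MPNN layers in the backbone pass messages between neighbouring atoms. As a rough illustration of the GCN variant, here is a dense numpy sketch of the standard GCN propagation rule, ReLU(D^-1/2 (A + I) D^-1/2 H W), followed by mean pooling to obtain a graph-level embedding that task-specific MLP heads could consume. It is not Graphium's implementation, which builds on PyTorch Geometric with sparse edge indices; the toy three-atom chain, its features and its weights are made up.

```python
import numpy as np

def gcn_layer(adj: np.ndarray, h: np.ndarray, w: np.ndarray) -> np.ndarray:
    """One GCN layer: ReLU(D^-1/2 (A + I) D^-1/2 H W) on a dense adjacency matrix."""
    a_hat = adj + np.eye(adj.shape[0])          # add self-loops
    deg = a_hat.sum(axis=1)                      # degree of each node, including self-loop
    d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    norm_adj = d_inv_sqrt @ a_hat @ d_inv_sqrt   # symmetric normalisation
    return np.maximum(norm_adj @ h @ w, 0.0)     # aggregate, transform, ReLU

def mean_readout(h: np.ndarray) -> np.ndarray:
    """Graph-level embedding: average of the node embeddings."""
    return h.mean(axis=0)

# Toy molecule: three heavy atoms in a chain (0-1-2), 2 input features per atom
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
node_features = np.array([[1.0, 0.0],
                          [0.0, 1.0],
                          [1.0, 0.0]])
rng = np.random.default_rng(0)
weight = rng.normal(size=(2, 4))                 # 2 input -> 4 hidden features
node_embeddings = gcn_layer(adj, node_features, weight)
graph_embedding = mean_readout(node_embeddings)  # shape (4,)
```

Stacking several such layers lets information travel further across the molecule, one bond per layer.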
The available configs are listed in the cell below; you can switch between them by changing the key passed to CONFIG_FILE. The small_gcn, small_gin and small_gine configs run full multitask training on the ToyMix dataset for the three GNN architectures.
The small_gcn_qm9 and small_gcn_zinc configs train the same GCN model on only the QM9 or only the ZINC12k dataset, respectively.
## Multi-Task Training - ToyMix Dataset
CONFIG_FILE = {
"small_gcn": "config_small_gcn.yaml",
"small_gcn_qm9": "config_small_gcn_qm9.yaml",
"small_gcn_zinc": "config_small_gcn_zinc.yaml",
"small_gin": "config_small_gin.yaml",
"small_gine": "config_small_gine.yaml",
}
cfg = load_yaml_config(CONFIG_FILE["small_gcn"], None)
Updating the config
The Graphium library is designed to be used with a config file. This tutorial contains three examples (GCN, GIN, and GINE models) for training on the multitask ToyMix dataset.
The configs are split into groups, with the key options in accelerator, datamodule, architecture, predictor, metrics and trainer. It is recommended to look through the files to see how the model is structured.
Some parameters are more conveniently set in the notebook or via command-line arguments. For example, passing --predictor.torch_scheduler_kwargs.max_num_epochs=200 on the command line sets the number of epochs. In this notebook, we set these values by hand, as shown below.
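For intuition about how a dotted command-line override maps onto the nested config dict, here is a minimal sketch. The helper apply_dotted_override is hypothetical: it is not Graphium's update_config (imported above), whose exact parsing rules may differ. It uses a separate demo_cfg so the notebook's real cfg is left untouched.

```python
import ast

def apply_dotted_override(config: dict, arg: str) -> None:
    """Apply a '--a.b.c=value' style override to a nested dict, in place."""
    key, _, raw_value = arg.lstrip("-").partition("=")
    *parents, leaf = key.split(".")
    node = config
    for name in parents:          # walk down to the dict holding the final key
        node = node[name]
    try:
        value = ast.literal_eval(raw_value)   # "200" -> 200, "4e-5" -> 4e-05
    except (ValueError, SyntaxError):
        value = raw_value                     # fall back to a plain string, e.g. "gcn"
    node[leaf] = value

# Hypothetical demo config, kept separate from the notebook's real `cfg`
demo_cfg = {"predictor": {"torch_scheduler_kwargs": {"max_num_epochs": 100}}}
apply_dotted_override(demo_cfg, "--predictor.torch_scheduler_kwargs.max_num_epochs=200")
print(demo_cfg["predictor"]["torch_scheduler_kwargs"]["max_num_epochs"])  # 200
```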
In the cell below, we set the number of replicas to speed up training. The default value is 4, but if you are using a Pod-16 you can increase the replica count to 16 for significantly faster training.
We can also set the number of epochs and the learning rate. Because the YAML config file has been loaded into a dict, we can't rely on the YAML anchor and alias system to keep parameters such as the number of epochs in sync, so this is done manually.
cfg["accelerator"]["ipu_config"] = [
"deviceIterations(5)",
"replicationFactor(4)", # <------ 4x IPUs for a Pod-4, increase to 16 if using a Pod-16
"TensorLocations.numIOTiles(128)",
'_Popart.set("defaultBufferingDepth", 128)',
"Precision.enableStochasticRounding(True)",
]
# For Paperspace we need to prefix the paths to the downloaded data and splits files
# with the dataset directory, once for each task and for both path types
task_args = cfg["datamodule"]["args"]["task_specific_args"]
for task in ["qm9", "tox21", "zinc"]:
    for path_key in ["df_path", "splits_path"]:
        task_args[task][path_key] = dataset_directory + "/" + task_args[task][path_key]
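Raising replicationFactor speeds up training because each replica processes its own micro-batches in parallel. In PopTorch, the number of samples consumed per step is the product of the micro-batch size, the device iterations, the replication factor and the gradient accumulation factor; the sketch below spells that out. The micro-batch size of 16 is a hypothetical value, not read from the config, and gradient accumulation is assumed to be 1.

```python
def samples_per_step(micro_batch_size: int, device_iterations: int,
                     replication_factor: int, gradient_accumulation: int = 1) -> int:
    """Number of samples (here, molecular graphs) consumed per step."""
    return micro_batch_size * device_iterations * replication_factor * gradient_accumulation

# Hypothetical micro-batch size of 16 graphs, with deviceIterations(5) as configured above
pod4 = samples_per_step(16, device_iterations=5, replication_factor=4)    # 320
pod16 = samples_per_step(16, device_iterations=5, replication_factor=16)  # 1280
```

Going from 4 to 16 replicas quadruples the samples per step, so an epoch needs a quarter of the steps, assuming the host-side data pipeline keeps up.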
Key Hyper-Parameters:
These values are shown here for simplicity; for testing, you can change them or add other parameters here. When building a full training script, it is recommended to set values with either the config file or the command line.
cfg["trainer"]["trainer"]["max_epochs"] = 5
cfg["predictor"]["torch_scheduler_kwargs"]["max_num_epochs"] = cfg["trainer"][
"trainer"
]["max_epochs"]
cfg["predictor"]["optim_kwargs"]["lr"] = 4e-5
# Initialize the accelerator
cfg, accelerator_type = load_accelerator(cfg)
wandb.init(mode="disabled")
# Load and initialize the dataset
datamodule = load_datamodule(cfg, accelerator_type)
# Initialize the network
model_class, model_kwargs = load_architecture(
cfg,
in_dims=datamodule.in_dims,
)
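The manual sync of max_epochs above is needed because YAML anchors and aliases are resolved when the file is parsed. The sketch below uses PyYAML's safe_load on a made-up config snippet (not one of the notebook's config files) to show that, once loaded, the aliased values are independent dict entries, so changing one leaves the others untouched.

```python
import yaml

# Made-up snippet: one anchor (&max_epochs) referenced by two aliases (*max_epochs)
CONFIG_TEXT = """
constants:
  max_epochs: &max_epochs 100
trainer:
  max_epochs: *max_epochs
predictor:
  max_num_epochs: *max_epochs
"""

demo_config = yaml.safe_load(CONFIG_TEXT)
demo_config["trainer"]["max_epochs"] = 5             # edit one entry in the loaded dict
print(demo_config["predictor"]["max_num_epochs"])    # 100: the alias is not re-resolved
```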
First, build the predictor, which wraps the model. Look at the printed summary to see the layer types and number of parameters.
datamodule.prepare_data()
metrics = load_metrics(cfg)
logger.info(metrics)
predictor = load_predictor(
cfg,
model_class,
model_kwargs,
metrics,
datamodule.get_task_levels(),
accelerator_type,
datamodule.featurization,
datamodule.task_norms
)
logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))
Then build the trainer - which controls the actual training loop.
date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
run_name = "main"
trainer = load_trainer(cfg, run_name, accelerator_type, date_time_suffix)
Train the model
Now we can start training the model. Training is handled with PyTorch Lightning, which makes it simple to set up the training loop.
This sets up the training loop after calculating the maximum number of nodes and edges per graph. We use fixed-size batches on the IPU, so we need to add padding to account for variable molecule sizes before compiling the model.
The training is displayed in a progress bar.
Exercise for the reader: we currently run training for only 5 epochs, to illustrate training the model and so you can see the validation and test output format. Try increasing the number of epochs for better results. Using around 30 epochs should give results of about 95% of the performance of the model trained for the full 300 epochs.
# Determine the max num nodes and edges in training and validation
logger.info("About to set the max nodes etc.")
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])
# Run the model training
with SafeRun(
    name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True
):
    trainer.fit(model=predictor, datamodule=datamodule)
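The max-nodes and max-edges step exists because of fixed-size batches: every graph (or batch) is padded up to a known size so the IPU program can be compiled once. The numpy sketch below shows the idea for node features, returning a mask that marks real nodes. The toy graphs and the helper pad_node_features are hypothetical, and Graphium's internal padding scheme (for example, how edges and whole batches are padded) is not reproduced here.

```python
import numpy as np

def pad_node_features(x: np.ndarray, max_nodes: int):
    """Pad an (n, f) node-feature array to (max_nodes, f); the mask marks real nodes."""
    n, f = x.shape
    if n > max_nodes:
        raise ValueError(f"graph has {n} nodes, more than max_nodes={max_nodes}")
    padded = np.zeros((max_nodes, f), dtype=x.dtype)
    padded[:n] = x
    mask = np.zeros(max_nodes, dtype=bool)
    mask[:n] = True
    return padded, mask

# Two toy graphs: (node features of shape (n, 2), edge_index of shape (2, num_edges))
toy_graphs = [
    (np.ones((3, 2)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]])),
    (np.ones((5, 2)), np.array([[0, 1], [1, 0]])),
]
toy_max_nodes = max(x.shape[0] for x, _ in toy_graphs)   # 5
toy_max_edges = max(e.shape[1] for _, e in toy_graphs)   # 4
padded, mask = pad_node_features(toy_graphs[0][0], toy_max_nodes)
```

The error branch illustrates why the limits are computed from the data before compiling: a graph larger than the padded size would not fit.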
Validation and test dataset splits
Finally, we evaluate the trained model on the held-out validation and test splits.
# Determine the max num nodes and edges in validation
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["val"])
# Run the model validation
with SafeRun(
    name="VALIDATING", raise_error=cfg["constants"]["raise_train_error"], verbose=True
):
    trainer.validate(model=predictor, datamodule=datamodule)
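# Determine the max num nodes and edges in testing
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["test"])
# Run the model testing
with SafeRun(
    name="TESTING", raise_error=cfg["constants"]["raise_train_error"], verbose=True
):
    trainer.test(
        model=predictor,
        datamodule=datamodule,
    )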
Conclusions and next steps
We have demonstrated how you can use Graphium for multitask molecular modelling on the IPU. We used the ToyMix dataset, which combines the QM9, Tox21, and ZINC12k datasets to show how Graphium trains models to concurrently predict multiple properties with chemical awareness.
Looking at the validation and test results we can see the impact of the larger QM9 dataset in improving the single task performance of the smaller datasets. The accuracy and MAE scores are much better than we would expect just from training on the small datasets, confirming the training benefit from a multitask setting even on a very short training run. To further verify this we can look at the same model trained for the full number of epochs on the multitask dataset and the individual datasets and compare the final validation results.
You can try out the following:
- Increase the number of epochs for training to around 30. This should yield reasonably good results. If you train for 300 epochs, your results should match the benchmarking results.
- Change the config file to use a different model. GIN will achieve higher accuracy than GCN, and GINE will have a higher accuracy than GIN.
- Adapt the multitask config to use only a single dataset. You can use this to compare multitask performance with single-task performance. Remember to remove the relevant sections from the datamodule, the losses and the metrics. Example YAML config files showing how to do this for GCN with the QM9 (config_small_gcn_qm9.yaml) and ZINC12k (config_small_gcn_zinc.yaml) datasets are provided.
- Explore more molecular featurisers in the molfeat notebook, pytorch_geometric_molfeat.ipynb.