Building and training on IPU from configurations¶
This tutorial walks you through using a configuration file to define all the parameters of a model and of the trainer. It focuses on training from SMILES data in CSV format.
There are multiple examples of YAML files located in the folder graphium/expts that one can refer to when training a new model. The file docs/tutorials/model_training/config_ipu_tutorials.yaml shows an example of multi-task regression from a CSV file provided by graphium.
If you are not familiar with PyTorch or PyTorch-Lightning, we highly recommend going through their tutorials first.
The IPU config file¶
The IPU config file can be found at expts/configs/ipu.config and is shown below.
from os.path import dirname, join
import graphium
GRAPHIUM_PATH = dirname(dirname(graphium.__file__))
with open(join(GRAPHIUM_PATH, "expts/configs/ipu.config")) as f:
    lines = f.readlines()
print("".join(lines))
deviceIterations(16)
replicationFactor(1)
Training.gradientAccumulation(1) # How many gradient steps to accumulate
modelName("reproduce_mixed") # TODO: Use the name from the main file, and add date/time
enableProfiling("pro_vision") # The folder where the profile will be stored
Jit.traceModel(True)
enableExecutableCaching("pop_compiler_cache")
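Under the hood, each line of this file corresponds to a method call on a poptorch.Options object, which Graphium builds for you when loading the configuration. Below is a rough, hand-written sketch of the equivalent Python, given only as an illustration of what a few of these options mean (the actual loading logic lives inside Graphium and may differ):

import poptorch

# Minimal sketch: build IPU options similar to the config file above.
opts = poptorch.Options()
opts.deviceIterations(16)              # iterations the device runs before returning to the host
opts.replicationFactor(1)              # number of model replicas (data parallelism)
opts.Training.gradientAccumulation(1)  # how many gradient steps to accumulate
opts.enableExecutableCaching("pop_compiler_cache")  # cache compiled executables between runs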
Creating the YAML file¶
The first step is to create a YAML file containing all the required configurations, with an example given at graphium/docs/tutorials/model_training/config_ipu_tutorials.yaml. We will go through each part of the configurations.
import yaml
import omegaconf
def print_config_with_key(config, key):
    new_config = {key: config[key]}
    print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the yaml configuration file
with open("config_ipu_tutorials.yaml", "r") as file:
yaml_config = yaml.load(file, Loader=yaml.FullLoader)
print("Yaml file loaded")
Yaml file loaded
Constants¶
First, we define the constants, such as the random seed and whether training should raise an error or ignore it.
The name defined here is used to log the metrics and the models into WandB.
The accelerator defines the device to run on; it supports cpu, ipu, or gpu. This is the only part of the configuration that needs to change when working with IPUs.
print_config_with_key(yaml_config, "constants")
constants:
name: tutorial_model
seed: 42
raise_train_error: true
accelerator:
type: ipu
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.
The MultitaskFromSmilesDataModule allows us to define a set of tasks, each coming from a different CSV or Parquet file, and to train the model on all of them simultaneously. The tasks are declared under the key datamodule: args: task_specific_args.
For more details, see the class graphium.data.datamodule.MultitaskFromSmilesDataModule.
Reading a CSV file and train/val/test splits¶
Here is an example configuration for the task named "homo". It reads the file located at https://storage.googleapis.com/graphium-public/datasets/QM9/norm_micro_qm9.csv, selects the columns "homo" and "lumo", and splits off validation and test sets with ratios of 20% each.
print_config_with_key(yaml_config["datamodule"]["args"]["task_specific_args"], "homo")
homo:
  df: null
  df_path: https://storage.googleapis.com/graphium-public/datasets/QM9/norm_micro_qm9.csv
  smiles_col: smiles
  label_cols:
  - homo
  - lumo
  split_val: 0.2
  split_test: 0.2
  split_seed: 42
  splits_path: null
  sample_size: null
  idx_col: null
  weights_col: null
  weights_type: null
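To make the column selection concrete, here is a minimal pandas sketch of what the datamodule reads from that CSV (Graphium handles this internally; the column names come from the config above, only the variable names are ours):

import pandas as pd

# The columns selected by `smiles_col` and `label_cols` above
df = pd.read_csv("https://storage.googleapis.com/graphium-public/datasets/QM9/norm_micro_qm9.csv")
smiles = df["smiles"]          # molecules as SMILES strings
labels = df[["homo", "lumo"]]  # regression targets for the "homo" task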
Featurizing a molecule¶
Molecules can be featurized using various properties and positional / structural encodings. A list of all available features is given below:
- graphium.features.featurizer.get_mol_atomic_features_onehot
- graphium.features.featurizer.get_mol_atomic_features_float
- graphium.features.featurizer.get_mol_edge_features
- graphium.features.spectral.compute_laplacian_positional_eigvecs
- graphium.features.rw.compute_rwse
Below is an example configuration for the featurization. Notice the lists of atomic and edge properties, and the pos_encoding_as_features section, which defines both the Laplacian and random-walk positional encodings.
print_config_with_key(yaml_config["datamodule"]["args"], "featurization")
featurization:
atom_property_list_onehot:
- atomic-number
- valence
atom_property_list_float:
- mass
- electronegativity
- in-ring
edge_property_list:
- bond-type-onehot
- stereo
- in-ring
add_self_loop: false
explicit_H: false
use_bonds_weights: false
pos_encoding_as_features:
pos_types:
la_pos:
pos_type: laplacian_eigvec_eigval
num_pos: 3
normalization: none
disconnected_comp: true
rw_pos:
pos_type: rwse
ksteps: 16
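To build some intuition for what these properties correspond to at the atom and bond level, here is a rough RDKit sketch. Graphium's featurizer additionally one-hot encodes, normalizes, and stacks these values into tensors, so this is purely illustrative (the example molecule is arbitrary):

from rdkit import Chem

mol = Chem.MolFromSmiles("c1ccccc1O")  # phenol, as an arbitrary example

# Atom-level properties similar to the ones listed above
atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]  # -> "atomic-number" (one-hot)
masses = [atom.GetMass() for atom in mol.GetAtoms()]               # -> "mass" (float)
atom_in_ring = [atom.IsInRing() for atom in mol.GetAtoms()]        # -> "in-ring" (float)

# Bond-level properties similar to "bond-type-onehot" and "in-ring"
bond_types = [bond.GetBondType() for bond in mol.GetBonds()]
bond_in_ring = [bond.IsInRing() for bond in mol.GetBonds()]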
Architecture¶
In the architecture, we define all the layers for the model, including the layers for the pre-processing MLP (input layers pre-nn), the post-processing MLP (output layers post-nn), and the main GNN (graph neural network gnn).
The parameters let you choose the feature size, the depth, the skip connections, the pooling, and the virtual node. Different GNN layers are also supported, such as pyg:gcn, pyg:gin, pyg:gine, pyg:gated-gcn, pyg:pna-msgpass, and pyg:gps.
For more details, see the following classes:
- graphium.nn.global_architecture.FullGraphMultiTaskNetwork: Main class for the architecture
- graphium.nn.global_architecture.FeedForwardNN: Main class for the inputs and outputs MLP
- graphium.nn.pyg_architecture.FeedForwardPyg: Main class for the GNN layers
Parameters for the node pre-processing NN¶
print_config_with_key(yaml_config["architecture"], "pre_nn")
pre_nn:
  out_dim: 32
  hidden_dims: 32
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
Parameters for the edge pre-processing NN¶
print_config_with_key(yaml_config["architecture"], "pre_nn_edges")
pre_nn_edges:
  out_dim: 16
  hidden_dims: 16
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
Parameters for the GNN¶
Here is an example of a GraphGPS layer whose MPNN component is a GINE model.
print_config_with_key(yaml_config["architecture"], "gnn")
gnn:
out_dim: 32
hidden_dims: 32
depth: 3
activation: relu
last_activation: none
dropout: 0.1
normalization: none
last_normalization: none
residual_type: simple
pooling:
- sum
- mean
- max
virtual_node: none
layer_type: pyg:gps
layer_kwargs:
mpnn_type: pyg:gine
mpnn_kwargs: null
attn_type: full-attention
attn_kwargs: null
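The three pooling operators are applied per graph and concatenated, so with a hidden size of 32 the pooled graph vector has 3 × 32 = 96 features before the final linear layer; this matches the FCLayer(96 -> 32) visible in the model summary printed further down. A minimal torch sketch of that idea for a single graph (the tensors here are random placeholders):

import torch

h = torch.randn(9, 32)  # node features for one molecule: 9 atoms, 32 features each

# sum / mean / max pooling, concatenated into a single graph-level vector
graph_vec = torch.cat([h.sum(dim=0), h.mean(dim=0), h.max(dim=0).values])
print(graph_vec.shape)  # torch.Size([96])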
Parameters for the node post-processing NN (after the GNN)¶
print_config_with_key(yaml_config["architecture"], "graph_output_nn")
graph_output_nn:
  out_dim: 32
  hidden_dims: 32
  depth: 1
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
Parameters for the multi-task output heads¶
Here is an example of the task heads. Notice that task_name must match the tasks defined in the section datamodule: args: task_specific_args. Also note that the homo head has out_dim: 2 because that task predicts the two label columns homo and lumo.
print_config_with_key(yaml_config["architecture"], "task_heads")
task_heads:
- task_name: homo
  out_dim: 2
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
- task_name: alpha
  out_dim: 1
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
- task_name: cv
  out_dim: 1
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
  dropout: 0.1
  normalization: none
  last_normalization: none
  residual_type: none
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
Again, each of these arguments depends on the task, and the task_name should match the ones from datamodule: args: task_specific_args.
print_config_with_key(yaml_config, "predictor")
predictor:
metrics_on_progress_bar:
homo:
- mae
- pearsonr
alpha:
- mae
cv:
- mae
- pearsonr
loss_fun:
homo: mse_ipu
alpha: mse_ipu
cv: mse_ipu
random_seed: 42
optim_kwargs:
lr: 0.001
torch_scheduler_kwargs: null
scheduler_kwargs: null
target_nan_mask: null
flag_kwargs:
n_steps: 0
alpha: 0.0
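The loss_fun entries use mse_ipu, an IPU-friendly variant of the MSE loss (it appears as MSELossIPU in the logged metrics below), and target_nan_mask controls how missing (NaN) labels are treated when computing the loss. As a rough illustration of the NaN-masking idea only, not Graphium's actual implementation:

import torch

def masked_mse(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Rough sketch: MSE that ignores entries where the target is NaN."""
    mask = ~torch.isnan(targets)
    return torch.nn.functional.mse_loss(preds[mask], targets[mask])

preds = torch.tensor([0.2, 0.5, 0.1])
targets = torch.tensor([0.0, float("nan"), 0.3])
print(masked_mse(preds, targets))  # averages only over the two labelled entries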
Metrics¶
All the metrics can be defined here. If we want to use a classification metric, we can also define a threshold.
See class graphium.trainer.metrics.MetricWrapper for more details.
See graphium.trainer.metrics.METRICS_CLASSIFICATION and graphium.trainer.metrics.METRICS_REGRESSION for a dictionary of accepted metrics.
Again, the metrics are task-dependent and must match the names in datamodule: args: task_specific_args.
print_config_with_key(yaml_config, "metrics")
metrics:
homo:
- name: mae
metric: mae
threshold_kwargs: null
- name: pearsonr
metric: pearsonr
threshold_kwargs: null
target_nan_mask: ignore-mean-label
alpha:
- name: mae
metric: mae
threshold_kwargs: null
- name: pearsonr
metric: pearsonr
threshold_kwargs: null
cv:
- name: mae
metric: mae
threshold_kwargs: null
- name: pearsonr
metric: pearsonr
threshold_kwargs: null
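The mae and pearsonr entries correspond to the torchmetrics functions mean_absolute_error and pearson_corrcoef, visible in the logged metric dictionary further down. A quick standalone sketch of what they compute, on made-up values:

import torch
from torchmetrics.functional import mean_absolute_error, pearson_corrcoef

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0.0, 0.5, 0.30, 1.0])

print(mean_absolute_error(preds, target))  # mean of |pred - target|
print(pearson_corrcoef(preds, target))     # linear correlation between preds and targets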
Trainer¶
Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience.
print_config_with_key(yaml_config, "trainer")
trainer:
logger:
save_dir: logs/QM9
name: tutorial_model
model_checkpoint:
dirpath: models_checkpoints/QM9/
filename: tutorial_model
save_top_k: 1
every_n_epochs: 1
trainer:
precision: 32
max_epochs: 5
min_epochs: 1
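Graphium's load_trainer turns this block into a PyTorch Lightning Trainer together with the logger, the checkpoint callback, and the IPU strategy. A hand-written approximation of the model_checkpoint and trainer blocks is sketched below, purely as an illustration (the real loader also wires up the logger shown above and the IPU-specific plugins, so details may differ):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Rough equivalent of the `model_checkpoint` and `trainer` blocks above
checkpoint_cb = ModelCheckpoint(
    dirpath="models_checkpoints/QM9/",
    filename="tutorial_model",
    save_top_k=1,
    every_n_epochs=1,
)
trainer = Trainer(
    precision=32,
    max_epochs=5,
    min_epochs=1,
    callbacks=[checkpoint_cb],
)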
# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig
import timeit
from loguru import logger
from pytorch_lightning.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer
from graphium.utils.safe_run import SafeRun
# WandB
import wandb
Then, let's load the configuration file¶
# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
CONFIG_FILE = "docs/tutorials/model_training/config_ipu_tutorials.yaml"
os.chdir(MAIN_DIR)
with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f:
    cfg = yaml.safe_load(f)
Now let's process the data¶
# Load and initialize the dataset
datamodule = load_datamodule(cfg)
datamodule.prepare_data()
[17:16:35.918] [poptorch::python] [warning] trace_model=True is deprecated since version 3.0 and will be removed in a future release
2022-09-20 17:16:35.920 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'homo'
2022-09-20 17:16:36.462 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'alpha'
2022-09-20 17:16:36.585 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'cv'
2022-09-20 17:16:36.707 | INFO | graphium.data.datamodule:prepare_data:721 - Done reading datasets
2022-09-20 17:16:36.707 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'homo' with 1005 data points.
2022-09-20 17:16:36.708 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'alpha' with 1005 data points.
2022-09-20 17:16:36.709 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'cv' with 1005 data points.
mols to ids: 0%| | 0/3015 [00:00<?, ?it/s]
featurizing_smiles: 0%| | 0/1005 [00:00<?, ?it/s]
Let's build the architecture and metrics, and set up the trainer¶
# Initialize the network
model_class, model_kwargs = load_architecture(
    cfg,
    in_dims=datamodule.in_dims,
)
metrics = load_metrics(cfg)
logger.info(metrics)
predictor = load_predictor(cfg, model_class, model_kwargs, metrics)
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])
logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))
trainer = load_trainer(cfg, "tutorial-run")
2022-09-20 17:16:46.768 | INFO | __main__:<cell line: 8>:8 - {'homo': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'alpha': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'cv': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}}
2022-09-20 17:16:46.812 | INFO | __main__:<cell line: 12>:12 - Multitask_GNN
---------------
pre-NN(depth=1, ResidualConnectionNone)
[FCLayer[87 -> 32]
pre-NN-edges(depth=1, ResidualConnectionNone)
[FCLayer[13 -> 16]
GNN(depth=3, ResidualConnectionSimple(skip_steps=1))
GPSLayerPyg[32 -> 32 -> 32 -> 32]
-> Pooling(['sum', 'mean', 'max']) -> FCLayer(96 -> 32, activation=None)
post-NN(depth=1, ResidualConnectionNone)
[FCLayer[32 -> 32]
2022-09-20 17:16:46.813 | INFO | __main__:<cell line: 13>:13 - | Name | Type | Params
-----------------------------------------------------------------------------------
0 | model | FullGraphMultiTaskNetwork | 47.6 K
1 | model.pe_encoders | ModuleDict | 681
2 | model.pe_encoders.la_pos | LapPENodeEncoder | 137
3 | model.pe_encoders.la_pos.linear_A | Linear | 9
4 | model.pe_encoders.la_pos.pe_encoder | MLP | 128
5 | model.pe_encoders.rw_pos | MLPEncoder | 544
6 | model.pe_encoders.rw_pos.pe_encoder | MLP | 544
7 | model.pre_nn | FeedForwardNN | 2.8 K
8 | model.pre_nn.activation | ReLU | 0
9 | model.pre_nn.residual_layer | ResidualConnectionNone | 0
10 | model.pre_nn.layers | ModuleList | 2.8 K
11 | model.pre_nn.layers.0 | FCLayer | 2.8 K
12 | model.pre_nn_edges | FeedForwardNN | 224
13 | model.pre_nn_edges.activation | ReLU | 0
14 | model.pre_nn_edges.residual_layer | ResidualConnectionNone | 0
15 | model.pre_nn_edges.layers | ModuleList | 224
16 | model.pre_nn_edges.layers.0 | FCLayer | 224
17 | model.gnn | FeedForwardPyg | 39.5 K
18 | model.gnn.activation | ReLU | 0
19 | model.gnn.layers | ModuleList | 36.4 K
20 | model.gnn.layers.0 | GPSLayerPyg | 12.1 K
21 | model.gnn.layers.1 | GPSLayerPyg | 12.1 K
22 | model.gnn.layers.2 | GPSLayerPyg | 12.1 K
23 | model.gnn.virtual_node_layers | ModuleList | 0
24 | model.gnn.virtual_node_layers.0 | VirtualNodePyg | 0
25 | model.gnn.virtual_node_layers.1 | VirtualNodePyg | 0
26 | model.gnn.residual_layer | ResidualConnectionSimple | 0
27 | model.gnn.global_pool_layer | ModuleListConcat | 0
28 | model.gnn.global_pool_layer.0 | PoolingWrapperPyg | 0
29 | model.gnn.global_pool_layer.1 | PoolingWrapperPyg | 0
30 | model.gnn.global_pool_layer.2 | PoolingWrapperPyg | 0
31 | model.gnn.out_linear | FCLayer | 3.1 K
32 | model.gnn.out_linear.linear | Linear | 3.1 K
33 | model.gnn.out_linear.dropout | Dropout | 0
34 | model.graph_output_nn | FeedForwardNN | 1.1 K
35 | model.graph_output_nn.activation | ReLU | 0
36 | model.graph_output_nn.residual_layer | ResidualConnectionNone | 0
37 | model.graph_output_nn.layers | ModuleList | 1.1 K
38 | model.graph_output_nn.layers.0 | FCLayer | 1.1 K
39 | model.task_heads | TaskHeads | 3.3 K
40 | model.task_heads.task_heads | ModuleDict | 3.3 K
41 | model.task_heads.task_heads.homo | TaskHead | 1.1 K
42 | model.task_heads.task_heads.alpha | TaskHead | 1.1 K
43 | model.task_heads.task_heads.cv | TaskHead | 1.1 K
-----------------------------------------------------------------------------------
47.6 K Trainable params
0 Non-trainable params
47.6 K Total params
0.190 Total estimated model params size (MB)
[17:16:46.843] [poptorch::python] [warning] trace_model=True is deprecated since version 3.0 and will be removed in a future release
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ipu.py:20: LightningDeprecationWarning: The `pl.plugins.training_type.ipu.IPUPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `pl.strategies.ipu.IPUStrategy` instead.
rank_zero_deprecation(
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: dbeaini_mila (multitask-gnn). Use `wandb login --relogin` to force relogin
/home/dom/graphium/wandb/run-20220920_171648-1ss172i9
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:317: LightningDeprecationWarning: Passing <graphium.ipu.ipu_wrapper.IPUPluginGraphium object at 0x7f4431dfc220> `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=<graphium.ipu.ipu_wrapper.IPUPluginGraphium object at 0x7f4431dfc220>)` instead.
rank_zero_deprecation(
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:469: UserWarning: more than one device specific flag has been set
rank_zero_warn("more than one device specific flag has been set")
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: True, using: 1 IPUs
HPU available: False, using: 0 HPUs
Finally, let's run the model¶
# Run the model training
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
    trainer.fit(model=predictor, datamodule=datamodule)
# Exit WandB
wandb.finish()
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:326: LightningDeprecationWarning: Base `LightningModule.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7.
  rank_zero_deprecation(
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:154: LightningDeprecationWarning: The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7. Please use the `ProgressBarBase.get_metrics` instead.
  rank_zero_deprecation(
mols to ids: 0%| | 0/1809 [00:00<?, ?it/s]
mols to ids: 0%| | 0/603 [00:00<?, ?it/s]
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:611: UserWarning: Checkpoint directory /home/dom/graphium/models_checkpoints/QM9 exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
| Name | Type | Params
----------------------------------------------------
0 | model | FullGraphMultiTaskNetwork | 47.6 K
----------------------------------------------------
47.6 K Trainable params
0 Non-trainable params
47.6 K Total params
0.190 Total estimated model params size (MB)
-------------------
MultitaskDataset
about = training set
num_graphs_total = 603
num_nodes_total = 5285
max_num_nodes_per_graph = 9
min_num_nodes_per_graph = 1
std_num_nodes_per_graph = 0.6989218547197285
mean_num_nodes_per_graph = 8.764510779436153
num_edges_total = 11316
max_num_edges_per_graph = 26
min_num_edges_per_graph = 0
std_num_edges_per_graph = 2.6821780989470896
mean_num_edges_per_graph = 18.766169154228855
-------------------
-------------------
MultitaskDataset
about = validation set
num_graphs_total = 201
num_nodes_total = 1756
max_num_nodes_per_graph = 9
min_num_nodes_per_graph = 2
std_num_nodes_per_graph = 0.7949619835770694
mean_num_nodes_per_graph = 8.7363184079602
num_edges_total = 3736
max_num_edges_per_graph = 24
min_num_edges_per_graph = 2
std_num_edges_per_graph = 2.7704251592456193
mean_num_edges_per_graph = 18.587064676616915
-------------------
Sanity Checking: 0it [00:00, ?it/s]
2022-09-20 17:16:59.088 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-20 17:16:59.089 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72`
[17:17:02.052] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu)
[17:17:02.054] [poptorch:cpp] [warning] %value.3 : Long(2, 144, strides=[144, 1], requires_grad=0, device=cpu) = aten::to(%tensor.1, %111, %112, %113, %114) # /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/poptorch/_poplar_executor.py:1315:0: torch.int64 is not supported natively on IPU, loss of range/precision may occur. We will only warn on the first instance.
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
2022-09-20 17:17:07.541 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-20 17:17:07.542 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72`
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
Training: 0it [00:00, ?it/s]
/home/dom/graphium/graphium/ipu/ipu_wrapper.py:274: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if all([this_input.shape[0] == 1 for this_input in inputs]):
/home/dom/graphium/graphium/ipu/ipu_wrapper.py:307: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
batch_idx = batch.pop("_batch_idx").item()
/home/dom/graphium/graphium/nn/base_layers.py:169: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
if torch.prod(torch.as_tensor(h.shape[:-1])) == 0:
/home/dom/graphium/graphium/nn/base_layers.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if torch.prod(torch.as_tensor(h.shape[:-1])) == 0:
/home/dom/graphium/graphium/nn/architectures/pyg_architectures.py:139: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert node_feats.shape[0] == g.num_nodes
/home/dom/graphium/graphium/nn/architectures/pyg_architectures.py:154: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert edge_feats.shape[0] == g.num_edges
/home/dom/graphium/graphium/nn/architectures/global_architectures.py:1190: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
if torch.prod(torch.as_tensor(e.shape[:-1])) == 0:
/home/dom/graphium/graphium/nn/architectures/global_architectures.py:1190: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if torch.prod(torch.as_tensor(e.shape[:-1])) == 0:
/home/dom/graphium/graphium/nn/base_graph_layer.py:277: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert edge_index.min() >= 0
/home/dom/graphium/graphium/nn/base_graph_layer.py:278: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert edge_index.max() < torch.iinfo(edge_index.dtype).max
/home/dom/graphium/graphium/nn/base_graph_layer.py:281: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert edge_index.size(0) == 2
/home/dom/graphium/graphium/nn/pyg_layers/gps_pyg.py:144: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
on_ipu = ("graph_is_true" in batch.keys) and (not batch.graph_is_true.all())
/home/dom/graphium/graphium/nn/pyg_layers/gps_pyg.py:146: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
max_num_nodes_per_graph = batch.dataset_max_nodes_per_graph[0].item()
/home/dom/graphium/graphium/ipu/to_dense_batch.py:57: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
batch_size = int(batch.max()) + 1
/home/dom/graphium/graphium/ipu/to_dense_batch.py:80: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert (
/home/dom/graphium/graphium/ipu/ipu_wrapper.py:20: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if targets[task].shape == preds[task].shape:
/home/dom/graphium/graphium/trainer/predictor.py:422: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
total_norm = torch.tensor(0.0)
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/util.py:604: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
return obj.item(), True
[17:17:11.240] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu)
Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Run history:
| alpha/MSELossIPU/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/MSELossIPU/val | ▄█▃▄▁ |
| alpha/loss/val | ▄█▃▄▁ |
| alpha/mae/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/mae/val | ▅█▃▄▁ |
| alpha/mean_pred/train | █▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/mean_pred/val | ▄█▂▂▁ |
| alpha/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/mean_target/val | ▁▁▁▁▁ |
| alpha/median_pred/train | ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/median_pred/val | ▁█▃▃▁ |
| alpha/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/median_target/val | ▁▁▁▁▁ |
| alpha/pearsonr/train | ▁█▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆ |
| alpha/pearsonr/val | ▁▇▆▇█ |
| alpha/std_pred/train | █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/std_pred/val | █▃▂▁▃ |
| alpha/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| alpha/std_target/val | ▁▁▁▁▁ |
| cv/MSELossIPU/train | ▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/MSELossIPU/val | ▄█▄▃▁ |
| cv/loss/val | ▄█▄▃▁ |
| cv/mae/train | █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/mae/val | ▅█▄▃▁ |
| cv/mean_pred/train | █▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/mean_pred/val | ▁█▅▄▂ |
| cv/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/mean_target/val | ▁▁▁▁▁ |
| cv/median_pred/train | █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/median_pred/val | ▁█▄▃▂ |
| cv/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/median_target/val | ▁▁▁▁▁ |
| cv/pearsonr/train | █▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/pearsonr/val | ▁▇▆██ |
| cv/std_pred/train | █▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/std_pred/val | █▄▄▁▅ |
| cv/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| cv/std_target/val | ▁▁▁▁▁ |
| epoch | ▁▃▅▆█ |
| grad_norm | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/MSELossIPU/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/MSELossIPU/val | █▅▁▂▂ |
| homo/loss/val | █▅▁▂▂ |
| homo/mae/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/mae/val | █▅▁▃▃ |
| homo/mean_pred/train | ▁▄███████████████████████████████ |
| homo/mean_pred/val | ▁█▄▄▄ |
| homo/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/mean_target/val | ▁▁▁▁▁ |
| homo/median_pred/train | █▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅ |
| homo/median_pred/val | ▁▅▄▇█ |
| homo/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/median_target/val | ▁▁▁▁▁ |
| homo/pearsonr/train | █▁▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆ |
| homo/pearsonr/val | ███▅▁ |
| homo/std_pred/train | █▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/std_pred/val | █▃▂▁▂ |
| homo/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| homo/std_target/val | ▁▁▁▁▁ |
| loss | ▁▁▁ |
| loss/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| loss/val | ▅█▃▃▁ |
| train/grad_norm | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| trainer/global_step | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▆▇▇▇████ |
Run summary:
| alpha/MSELossIPU/train | 10.27925 |
| alpha/MSELossIPU/val | 0.90438 |
| alpha/loss/val | 0.90438 |
| alpha/mae/train | 1.92692 |
| alpha/mae/val | 0.5743 |
| alpha/mean_pred/train | -0.0 |
| alpha/mean_pred/val | 0.05032 |
| alpha/mean_target/train | -1.47591 |
| alpha/mean_target/val | -0.05393 |
| alpha/median_pred/train | -0.0 |
| alpha/median_pred/val | 0.09372 |
| alpha/median_target/train | -0.81494 |
| alpha/median_target/val | -0.04535 |
| alpha/pearsonr/train | 0.24189 |
| alpha/pearsonr/val | 0.70787 |
| alpha/std_pred/train | 0.0 |
| alpha/std_pred/val | 0.63476 |
| alpha/std_target/train | 3.11787 |
| alpha/std_target/val | 1.28429 |
| cv/MSELossIPU/train | 7.34846 |
| cv/MSELossIPU/val | 0.71549 |
| cv/loss/val | 0.71549 |
| cv/mae/train | 1.7832 |
| cv/mae/val | 0.59933 |
| cv/mean_pred/train | -0.0 |
| cv/mean_pred/val | 0.02737 |
| cv/mean_target/train | -1.33105 |
| cv/mean_target/val | -0.08344 |
| cv/median_pred/train | -0.0 |
| cv/median_pred/val | 0.08583 |
| cv/median_target/train | -0.73486 |
| cv/median_target/val | -0.0361 |
| cv/pearsonr/train | -0.84246 |
| cv/pearsonr/val | 0.61503 |
| cv/std_pred/train | 0.0 |
| cv/std_pred/val | 0.60702 |
| cv/std_target/train | 2.58691 |
| cv/std_target/val | 1.06456 |
| epoch | 4 |
| grad_norm | 0.0 |
| homo/MSELossIPU/train | 3.00724 |
| homo/MSELossIPU/val | 1.10813 |
| homo/loss/val | 1.10813 |
| homo/mae/train | 0.92035 |
| homo/mae/val | 0.81845 |
| homo/mean_pred/train | 0.0 |
| homo/mean_pred/val | 0.02063 |
| homo/mean_target/train | -0.43686 |
| homo/mean_target/val | -0.0837 |
| homo/median_pred/train | 0.0 |
| homo/median_pred/val | 0.00169 |
| homo/median_target/train | 0.04636 |
| homo/median_target/val | -0.10706 |
| homo/pearsonr/train | 0.01569 |
| homo/pearsonr/val | 0.11083 |
| homo/std_pred/train | 0.0 |
| homo/std_pred/val | 0.2037 |
| homo/std_target/train | 1.75284 |
| homo/std_target/val | 1.05716 |
| loss | 6.87832 |
| loss/train | 6.87832 |
| loss/val | 0.90933 |
| train/grad_norm | 0.0 |
| trainer/global_step | 29 |
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
./wandb/run-20220920_171648-1ss172i9/logs
Testing the model¶
Once the model is trained, we can use the same datamodule to get the results on the test set. Here, ckpt_path refers to the checkpoint where the model with the best validation score was saved, so the test results correspond to that early-stopping point.
All the metrics that were computed on the validation set are then computed on the test set, printed, and saved into the metrics.yaml file.
ckpt_path = trainer.checkpoint_callbacks[0].best_model_path
predictor.set_max_nodes_edges_per_graph(datamodule, stages=['test'])
trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path)
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:154: LightningDeprecationWarning: The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7. Please use the `ProgressBarBase.get_metrics` instead.
  rank_zero_deprecation(
mols to ids: 0%| | 0/603 [00:00<?, ?it/s]
Restoring states from the checkpoint path at /home/dom/graphium/models_checkpoints/QM9/tutorial_model.ckpt
Loaded model weights from checkpoint at /home/dom/graphium/models_checkpoints/QM9/tutorial_model.ckpt
2022-09-20 17:18:37.268 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0
2022-09-20 17:18:37.268 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72`
-------------------
MultitaskDataset
about = test set
num_graphs_total = 201
num_nodes_total = 1751
max_num_nodes_per_graph = 9
min_num_nodes_per_graph = 1
std_num_nodes_per_graph = 0.9443829267824616
mean_num_nodes_per_graph = 8.711442786069652
num_edges_total = 3748
max_num_edges_per_graph = 26
min_num_edges_per_graph = 0
std_num_edges_per_graph = 3.0647412282392468
mean_num_edges_per_graph = 18.64676616915423
-------------------
Testing: 0it [00:00, ?it/s]
[17:18:40.730] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu)
Graph compilation: 100%|██████████| 100/100 [00:01<00:00]
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
self.run()
File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 113, in run
shandler(sreq)
File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 172, in server_record_publish
iface = self._mux.get_stream(stream_id).interface
File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/streams.py", line 186, in get_stream
stream = self._streams[stream_id]
KeyError: '1ss172i9'
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
alpha/MSELossIPU/test 1.0208719968795776
alpha/loss/test 1.0208719968795776
alpha/mae/test 0.5369145274162292
alpha/mean_pred/test -0.043575093150138855
alpha/mean_target/test -0.1731886863708496
alpha/median_pred/test 0.020035069435834885
alpha/median_target/test -0.05999755859375
alpha/pearsonr/test 0.6431294679641724
alpha/std_pred/test 0.6613938808441162
alpha/std_target/test 1.2930139303207397
cv/MSELossIPU/test 0.7984611392021179
cv/loss/test 0.7984611392021179
cv/mae/test 0.5717160701751709
cv/mean_pred/test -0.06430188566446304
cv/mean_target/test -0.08974086493253708
cv/median_pred/test 0.014351803809404373
cv/median_target/test -0.06439208984375
cv/pearsonr/test 0.663591742515564
cv/std_pred/test 0.63597571849823
cv/std_target/test 1.180733561515808
homo/MSELossIPU/test 1.14460027217865
homo/loss/test 1.14460027217865
homo/mae/test 0.7968456149101257
homo/mean_pred/test 0.004068718757480383
homo/mean_target/test 0.1092241033911705
homo/median_pred/test 0.0017465166747570038
homo/median_target/test 0.1390380859375
homo/pearsonr/test 0.1387491226196289
homo/std_pred/test 0.2090565264225006
homo/std_target/test 1.07249116897583
loss/test 0.9879778027534485
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'homo/mean_pred/test': 0.004068718757480383,
'homo/std_pred/test': 0.2090565264225006,
'homo/median_pred/test': 0.0017465166747570038,
'homo/mean_target/test': 0.1092241033911705,
'homo/std_target/test': 1.07249116897583,
'homo/median_target/test': 0.1390380859375,
'homo/mae/test': 0.7968456149101257,
'homo/pearsonr/test': 0.1387491226196289,
'homo/MSELossIPU/test': 1.14460027217865,
'homo/loss/test': 1.14460027217865,
'alpha/mean_pred/test': -0.043575093150138855,
'alpha/std_pred/test': 0.6613938808441162,
'alpha/median_pred/test': 0.020035069435834885,
'alpha/mean_target/test': -0.1731886863708496,
'alpha/std_target/test': 1.2930139303207397,
'alpha/median_target/test': -0.05999755859375,
'alpha/mae/test': 0.5369145274162292,
'alpha/pearsonr/test': 0.6431294679641724,
'alpha/MSELossIPU/test': 1.0208719968795776,
'alpha/loss/test': 1.0208719968795776,
'cv/mean_pred/test': -0.06430188566446304,
'cv/std_pred/test': 0.63597571849823,
'cv/median_pred/test': 0.014351803809404373,
'cv/mean_target/test': -0.08974086493253708,
'cv/std_target/test': 1.180733561515808,
'cv/median_target/test': -0.06439208984375,
'cv/mae/test': 0.5717160701751709,
'cv/pearsonr/test': 0.663591742515564,
'cv/MSELossIPU/test': 0.7984611392021179,
'cv/loss/test': 0.7984611392021179,
'loss/test': 0.9879778027534485}]
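trainer.test returns the metrics as a list of dictionaries (one per test dataloader), as shown above. If you want to persist them yourself, a minimal sketch is given below; the results variable is ours (the cell above did not capture the return value), and Graphium also writes a metrics.yaml file as mentioned earlier:

import yaml

# Hypothetical: capture the return value of the `trainer.test(...)` call above
results = trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path)

# Dump the metrics of the first (and only) test dataloader to a YAML file
with open("metrics.yaml", "w") as f:
    yaml.dump(results[0], f)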