Building and training on IPU from configurations¶
This tutorial will walk you through how to use a configuration file to define all the parameters of a model and of the trainer. This tutorial focuses on training from SMILES data in a CSV format.
There are multiple examples of YAML files located in the folder graphium/expts
that one can refer to when training a new model. The file docs/tutorials/model_training/config_ipu_tutorials.yaml
shows an example of multi-task regression from a CSV file provided by graphium.
If you are not familiar with PyTorch or PyTorch-Lightning, we highly recommend going through their tutorial first.
The ipu config file¶
The IPU config file can be found at expts/configs/ipu.config
. And is given below.
from os.path import dirname, join
import graphium
GRAPHIUM_PATH = dirname(dirname(graphium.__file__))
with open(join(GRAPHIUM_PATH, "expts/configs/ipu.config")) as f:
lines = f.readlines()
print("".join(lines))
deviceIterations(16) replicationFactor(1) Training.gradientAccumulation(1) # How many gradient steps to accumulate modelName("reproduce_mixed") # TODO: Use the name from the main file, and add date/time enableProfiling("pro_vision") # The folder where the profile will be stored Jit.traceModel(True) enableExecutableCaching("pop_compiler_cache")
Creating the yaml file¶
The first step is to create a YAML file containing all the required configurations, with an example given at graphium/docs/tutorials/model_training/config_ipu_tutorials.yaml
. We will go through each part of the configurations.
import yaml
import omegaconf
def print_config_with_key(config, key):
new_config = {key: config[key]}
print(omegaconf.OmegaConf.to_yaml(new_config))
# First, let's read the yaml configuration file
with open("config_ipu_tutorials.yaml", "r") as file:
yaml_config = yaml.load(file, Loader=yaml.FullLoader)
print("Yaml file loaded")
Yaml file loaded
Constants¶
First, we define the constants such as the random seed and whether the model should raise or ignore an error.
The name
here will be used to log the metrics into WandB and log the models.
The accelerator
is used to define the device. It supports cpu
, ipu
or gpu
. This is the only part that needs to change when working with IPUs.
print_config_with_key(yaml_config, "constants")
constants: name: tutorial_model seed: 42 raise_train_error: true accelerator: type: ipu
Datamodule¶
Here, we define all the parameters required by the datamodule to run correctly, such as the dataset path, whether to cache, the columns for the training, the molecular featurization to use, the train/val/test splits and the batch size.
The MultitaskFromSmilesDataModule
allows us to define a set of tasks within different CSV or parquet files, and use them to train the model simultaneously via the path datamodule: args: task_specific_args
For more details, see class graphium.data.datamodule.MultitaskFromSmilesDataModule
Reading a CSV file and train/val/test splits¶
Here is an example of configuration regarding the task named "homo", which reads the file located in https://storage.googleapis.com/graphium-public/datasets/QM9/norm_mini_qm9.csv
, selects the columns "homo" and "lumo", and splits into validation and test set with rations of 20% each
print_config_with_key(yaml_config["datamodule"]["args"]["task_specific_args"], "homo")
homo: df: null df_path: https://storage.googleapis.com/graphium-public/datasets/QM9/norm_micro_qm9.csv smiles_col: smiles label_cols: - homo - lumo split_val: 0.2 split_test: 0.2 split_seed: 42 splits_path: null sample_size: null idx_col: null weights_col: null weights_type: null
Featurizing a molecule¶
Molecules can be featurized using various properties and positional / structural encoding. A list of all features is available here:
graphium.features.featurizer.get_mol_atomic_features_onehot
graphium.features.featurizer.get_mol_atomic_features_float
graphium.features.featurizer.get_mol_edge_features
graphium.features.spectral.compute_laplacian_positional_eigvecs
graphium.features.rw.compute_rwse
Example of a configuration for featurization below. Notice the list of atomic and edge properties. Notice pos_encoding_as_features
that defines both the laplacian and random-walk positional encodings.
print_config_with_key(yaml_config["datamodule"]["args"], "featurization")
featurization: atom_property_list_onehot: - atomic-number - valence atom_property_list_float: - mass - electronegativity - in-ring edge_property_list: - bond-type-onehot - stereo - in-ring add_self_loop: false explicit_H: false use_bonds_weights: false pos_encoding_as_features: pos_types: la_pos: pos_type: laplacian_eigvec_eigval num_pos: 3 normalization: none disconnected_comp: true rw_pos: pos_type: rwse ksteps: 16
Architecture¶
In the architecture, we define all the layers for the model, including the layers for the pre-processing MLP (input layers pre-nn
), the post-processing MLP (output layers post-nn
), and the main GNN (graph neural network gnn
).
The parameters allow to chose the feature size, the depth, the skip connections, the pooling and the virtual node. It also support different GNN layers such as pyg:gcn
, pyg:gin
, pyg:gine
, pyg:gated-gcn
, pyg:pna-msgpass
, and pyg:gps
.
For more details, see the following classes:
graphium.nn.global_architecture.FullGraphMultiTaskNetwork
: Main class for the architecturegraphium.nn.global_architecture.FeedForwardNN
: Main class for the inputs and outputs MLPgraphium.nn.pyg_architecture.FeedForwardPyg
: Main class for the GNN layers
Parameters for the node pre-processing NN¶
print_config_with_key(yaml_config["architecture"], "pre_nn")
pre_nn: out_dim: 32 hidden_dims: 32 depth: 1 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none
Parameters for the edge pre-processing NN¶
print_config_with_key(yaml_config["architecture"], "pre_nn_edges")
pre_nn_edges: out_dim: 16 hidden_dims: 16 depth: 1 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none
Parameters for the GNN¶
Here is an example of a GraphGPS layer, with it's MPNN being a GINE model
print_config_with_key(yaml_config["architecture"], "gnn")
gnn: out_dim: 32 hidden_dims: 32 depth: 3 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: simple pooling: - sum - mean - max virtual_node: none layer_type: pyg:gps layer_kwargs: mpnn_type: pyg:gine mpnn_kwargs: null attn_type: full-attention attn_kwargs: null
Parameters for the node post-processing NN (after the GNN)¶
print_config_with_key(yaml_config["architecture"], "graph_output_nn")
graph_output_nn: out_dim: 32 hidden_dims: 32 depth: 1 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none
Parameters for the multi-task output heads¶
Here is the example for the task heads. Notice that "task_name"
should match the tasks in the section datamodule: args: task_specific_args
.
print_config_with_key(yaml_config["architecture"], "task_heads")
task_heads: - task_name: homo out_dim: 2 hidden_dims: 32 depth: 2 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none - task_name: alpha out_dim: 1 hidden_dims: 32 depth: 2 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none - task_name: cv out_dim: 1 hidden_dims: 32 depth: 2 activation: relu last_activation: none dropout: 0.1 normalization: none last_normalization: none residual_type: none
Predictor¶
In the predictor, we define the loss functions, the metrics to track on the progress bar, and all the parameters necessary for the optimizer.
Again, each of these arguments depend on the task. And the task_name
should match the ones from datamodule: args: task_specific_args
print_config_with_key(yaml_config, "predictor")
predictor: metrics_on_progress_bar: homo: - mae - pearsonr alpha: - mae cv: - mae - pearsonr loss_fun: homo: mse_ipu alpha: mse_ipu cv: mse_ipu random_seed: 42 optim_kwargs: lr: 0.001 torch_scheduler_kwargs: null scheduler_kwargs: null target_nan_mask: null flag_kwargs: n_steps: 0 alpha: 0.0
Metrics¶
All the metrics can be defined there. If we want to use a classification metric, we can also define a threshold.
See class graphium.trainer.metrics.MetricWrapper
for more details.
See graphium.trainer.metrics.METRICS_CLASSIFICATION
and graphium.trainer.metrics.METRICS_REGRESSION
for a dictionnary of accepted metrics.
Again, the metrics are task-dependant and must match the names in datamodule: args: task_specific_args
print_config_with_key(yaml_config, "metrics")
metrics: homo: - name: mae metric: mae threshold_kwargs: null - name: pearsonr metric: pearsonr threshold_kwargs: null target_nan_mask: ignore-mean-label alpha: - name: mae metric: mae threshold_kwargs: null - name: pearsonr metric: pearsonr threshold_kwargs: null cv: - name: mae metric: mae threshold_kwargs: null - name: pearsonr metric: pearsonr threshold_kwargs: null
Trainer¶
Finally, the Trainer defines the parameters for the number of epochs to train, the checkpoints, and the patience.
print_config_with_key(yaml_config, "trainer")
trainer: logger: save_dir: logs/QM9 name: tutorial_model model_checkpoint: dirpath: models_checkpoints/QM9/ filename: tutorial_model save_top_k: 1 every_n_epochs: 1 trainer: precision: 32 max_epochs: 5 min_epochs: 1
# General imports
import os
from os.path import dirname, abspath
import yaml
from copy import deepcopy
from omegaconf import DictConfig
import timeit
from loguru import logger
from pytorch_lightning.utilities.model_summary import ModelSummary
# Current project imports
import graphium
from graphium.config._loader import load_datamodule, load_metrics, load_architecture, load_predictor, load_trainer
from graphium.utils.safe_run import SafeRun
# WandB
import wandb
Then, let's load the configuration file¶
# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
CONFIG_FILE = "docs/tutorials/model_training/config_ipu_tutorials.yaml"
os.chdir(MAIN_DIR)
with open(os.path.join(MAIN_DIR, CONFIG_FILE), "r") as f:
cfg = yaml.safe_load(f)
Now let's process the data¶
# Load and initialize the dataset
datamodule = load_datamodule(cfg)
datamodule.prepare_data()
[17:16:35.918] [poptorch::python] [warning] trace_model=True is deprecated since version 3.0 and will be removed in a future release 2022-09-20 17:16:35.920 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'homo' 2022-09-20 17:16:36.462 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'alpha' 2022-09-20 17:16:36.585 | INFO | graphium.data.datamodule:prepare_data:699 - Reading data for task 'cv' 2022-09-20 17:16:36.707 | INFO | graphium.data.datamodule:prepare_data:721 - Done reading datasets 2022-09-20 17:16:36.707 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'homo' with 1005 data points. 2022-09-20 17:16:36.708 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'alpha' with 1005 data points. 2022-09-20 17:16:36.709 | INFO | graphium.data.datamodule:prepare_data:733 - Prepare single-task dataset for task 'cv' with 1005 data points.
mols to ids: 0%| | 0/3015 [00:00<?, ?it/s]
featurizing_smiles: 0%| | 0/1005 [00:00<?, ?it/s]
Let's build the architecture, metrics, and set-up the trainer¶
# Initialize the network
model_class, model_kwargs = load_architecture(
cfg,
in_dims=datamodule.in_dims,
)
metrics = load_metrics(cfg)
logger.info(metrics)
predictor = load_predictor(cfg, model_class, model_kwargs, metrics)
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])
logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))
trainer = load_trainer(cfg, "tutorial-run")
2022-09-20 17:16:46.768 | INFO | __main__:<cell line: 8>:8 - {'homo': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'alpha': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}, 'cv': {'mae': mean_absolute_error, 'pearsonr': pearson_corrcoef}} 2022-09-20 17:16:46.812 | INFO | __main__:<cell line: 12>:12 - Multitask_GNN --------------- pre-NN(depth=1, ResidualConnectionNone) [FCLayer[87 -> 32] pre-NN-edges(depth=1, ResidualConnectionNone) [FCLayer[13 -> 16] GNN(depth=3, ResidualConnectionSimple(skip_steps=1)) GPSLayerPyg[32 -> 32 -> 32 -> 32] -> Pooling(['sum', 'mean', 'max']) -> FCLayer(96 -> 32, activation=None) post-NN(depth=1, ResidualConnectionNone) [FCLayer[32 -> 32] 2022-09-20 17:16:46.813 | INFO | __main__:<cell line: 13>:13 - | Name | Type | Params ----------------------------------------------------------------------------------- 0 | model | FullGraphMultiTaskNetwork | 47.6 K 1 | model.pe_encoders | ModuleDict | 681 2 | model.pe_encoders.la_pos | LapPENodeEncoder | 137 3 | model.pe_encoders.la_pos.linear_A | Linear | 9 4 | model.pe_encoders.la_pos.pe_encoder | MLP | 128 5 | model.pe_encoders.rw_pos | MLPEncoder | 544 6 | model.pe_encoders.rw_pos.pe_encoder | MLP | 544 7 | model.pre_nn | FeedForwardNN | 2.8 K 8 | model.pre_nn.activation | ReLU | 0 9 | model.pre_nn.residual_layer | ResidualConnectionNone | 0 10 | model.pre_nn.layers | ModuleList | 2.8 K 11 | model.pre_nn.layers.0 | FCLayer | 2.8 K 12 | model.pre_nn_edges | FeedForwardNN | 224 13 | model.pre_nn_edges.activation | ReLU | 0 14 | model.pre_nn_edges.residual_layer | ResidualConnectionNone | 0 15 | model.pre_nn_edges.layers | ModuleList | 224 16 | model.pre_nn_edges.layers.0 | FCLayer | 224 17 | model.gnn | FeedForwardPyg | 39.5 K 18 | model.gnn.activation | ReLU | 0 19 | model.gnn.layers | ModuleList | 36.4 K 20 | model.gnn.layers.0 | GPSLayerPyg | 12.1 K 21 | model.gnn.layers.1 | GPSLayerPyg | 12.1 K 22 | model.gnn.layers.2 | GPSLayerPyg | 12.1 K 23 | model.gnn.virtual_node_layers | ModuleList | 0 24 | model.gnn.virtual_node_layers.0 | VirtualNodePyg | 0 25 | model.gnn.virtual_node_layers.1 | VirtualNodePyg | 0 26 | model.gnn.residual_layer | ResidualConnectionSimple | 0 27 | model.gnn.global_pool_layer | ModuleListConcat | 0 28 | model.gnn.global_pool_layer.0 | PoolingWrapperPyg | 0 29 | model.gnn.global_pool_layer.1 | PoolingWrapperPyg | 0 30 | model.gnn.global_pool_layer.2 | PoolingWrapperPyg | 0 31 | model.gnn.out_linear | FCLayer | 3.1 K 32 | model.gnn.out_linear.linear | Linear | 3.1 K 33 | model.gnn.out_linear.dropout | Dropout | 0 34 | model.graph_output_nn | FeedForwardNN | 1.1 K 35 | model.graph_output_nn.activation | ReLU | 0 36 | model.graph_output_nn.residual_layer | ResidualConnectionNone | 0 37 | model.graph_output_nn.layers | ModuleList | 1.1 K 38 | model.graph_output_nn.layers.0 | FCLayer | 1.1 K 39 | model.task_heads | TaskHeads | 3.3 K 40 | model.task_heads.task_heads | ModuleDict | 3.3 K 41 | model.task_heads.task_heads.homo | TaskHead | 1.1 K 42 | model.task_heads.task_heads.alpha | TaskHead | 1.1 K 43 | model.task_heads.task_heads.cv | TaskHead | 1.1 K ----------------------------------------------------------------------------------- 47.6 K Trainable params 0 Non-trainable params 47.6 K Total params 0.190 Total estimated model params size (MB) [17:16:46.843] [poptorch::python] [warning] trace_model=True is deprecated since version 3.0 and will be removed in a future release /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ipu.py:20: LightningDeprecationWarning: The `pl.plugins.training_type.ipu.IPUPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `pl.strategies.ipu.IPUStrategy` instead. rank_zero_deprecation( Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving. wandb: Currently logged in as: dbeaini_mila (multitask-gnn). Use `wandb login --relogin` to force relogin
/home/dom/graphium/wandb/run-20220920_171648-1ss172i9
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:317: LightningDeprecationWarning: Passing <graphium.ipu.ipu_wrapper.IPUPluginGraphium object at 0x7f4431dfc220> `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=<graphium.ipu.ipu_wrapper.IPUPluginGraphium object at 0x7f4431dfc220>)` instead. rank_zero_deprecation( /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:469: UserWarning: more than one device specific flag has been set rank_zero_warn("more than one device specific flag has been set") GPU available: False, used: False TPU available: False, using: 0 TPU cores IPU available: True, using: 1 IPUs HPU available: False, using: 0 HPUs
Finally, let's run the model¶
# Run the model training
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
trainer.fit(model=predictor, datamodule=datamodule)
# Exit WandB
wandb.finish()
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:326: LightningDeprecationWarning: Base `LightningModule.on_train_batch_end` hook signature has changed in v1.5. The `dataloader_idx` argument will be removed in v1.7. rank_zero_deprecation( /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:154: LightningDeprecationWarning: The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7. Please use the `ProgressBarBase.get_metrics` instead. rank_zero_deprecation(
mols to ids: 0%| | 0/1809 [00:00<?, ?it/s]
mols to ids: 0%| | 0/603 [00:00<?, ?it/s]
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:611: UserWarning: Checkpoint directory /home/dom/graphium/models_checkpoints/QM9 exists and is not empty. rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") | Name | Type | Params ---------------------------------------------------- 0 | model | FullGraphMultiTaskNetwork | 47.6 K ---------------------------------------------------- 47.6 K Trainable params 0 Non-trainable params 47.6 K Total params 0.190 Total estimated model params size (MB)
------------------- MultitaskDataset about = training set num_graphs_total = 603 num_nodes_total = 5285 max_num_nodes_per_graph = 9 min_num_nodes_per_graph = 1 std_num_nodes_per_graph = 0.6989218547197285 mean_num_nodes_per_graph = 8.764510779436153 num_edges_total = 11316 max_num_edges_per_graph = 26 min_num_edges_per_graph = 0 std_num_edges_per_graph = 2.6821780989470896 mean_num_edges_per_graph = 18.766169154228855 ------------------- ------------------- MultitaskDataset about = validation set num_graphs_total = 201 num_nodes_total = 1756 max_num_nodes_per_graph = 9 min_num_nodes_per_graph = 2 std_num_nodes_per_graph = 0.7949619835770694 mean_num_nodes_per_graph = 8.7363184079602 num_edges_total = 3736 max_num_edges_per_graph = 24 min_num_edges_per_graph = 2 std_num_edges_per_graph = 2.7704251592456193 mean_num_edges_per_graph = 18.587064676616915 -------------------
Sanity Checking: 0it [00:00, ?it/s]
2022-09-20 17:16:59.088 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0 2022-09-20 17:16:59.089 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72` [17:17:02.052] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu) [17:17:02.054] [poptorch:cpp] [warning] %value.3 : Long(2, 144, strides=[144, 1], requires_grad=0, device=cpu) = aten::to(%tensor.1, %111, %112, %113, %114) # /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/poptorch/_poplar_executor.py:1315:0: torch.int64 is not supported natively on IPU, loss of range/precision may occur. We will only warn on the first instance. Graph compilation: 100%|██████████| 100/100 [00:01<00:00] 2022-09-20 17:17:07.541 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0 2022-09-20 17:17:07.542 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72` /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: PossibleUserWarning: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch. rank_zero_warn(
Training: 0it [00:00, ?it/s]
/home/dom/graphium/graphium/ipu/ipu_wrapper.py:274: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if all([this_input.shape[0] == 1 for this_input in inputs]): /home/dom/graphium/graphium/ipu/ipu_wrapper.py:307: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! batch_idx = batch.pop("_batch_idx").item() /home/dom/graphium/graphium/nn/base_layers.py:169: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect. if torch.prod(torch.as_tensor(h.shape[:-1])) == 0: /home/dom/graphium/graphium/nn/base_layers.py:169: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if torch.prod(torch.as_tensor(h.shape[:-1])) == 0: /home/dom/graphium/graphium/nn/architectures/pyg_architectures.py:139: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert node_feats.shape[0] == g.num_nodes /home/dom/graphium/graphium/nn/architectures/pyg_architectures.py:154: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert edge_feats.shape[0] == g.num_edges /home/dom/graphium/graphium/nn/architectures/global_architectures.py:1190: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect. if torch.prod(torch.as_tensor(e.shape[:-1])) == 0: /home/dom/graphium/graphium/nn/architectures/global_architectures.py:1190: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if torch.prod(torch.as_tensor(e.shape[:-1])) == 0: /home/dom/graphium/graphium/nn/base_graph_layer.py:277: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert edge_index.min() >= 0 /home/dom/graphium/graphium/nn/base_graph_layer.py:278: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert edge_index.max() < torch.iinfo(edge_index.dtype).max /home/dom/graphium/graphium/nn/base_graph_layer.py:281: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert edge_index.size(0) == 2 /home/dom/graphium/graphium/nn/pyg_layers/gps_pyg.py:144: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! on_ipu = ("graph_is_true" in batch.keys) and (not batch.graph_is_true.all()) /home/dom/graphium/graphium/nn/pyg_layers/gps_pyg.py:146: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! max_num_nodes_per_graph = batch.dataset_max_nodes_per_graph[0].item() /home/dom/graphium/graphium/ipu/to_dense_batch.py:57: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! batch_size = int(batch.max()) + 1 /home/dom/graphium/graphium/ipu/to_dense_batch.py:80: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert ( /home/dom/graphium/graphium/ipu/ipu_wrapper.py:20: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if targets[task].shape == preds[task].shape: /home/dom/graphium/graphium/trainer/predictor.py:422: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect. total_norm = torch.tensor(0.0) /home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/util.py:604: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! return obj.item(), True [17:17:11.240] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu) Graph compilation: 100%|██████████| 100/100 [00:02<00:00]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…
Run history:
alpha/MSELossIPU/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/MSELossIPU/val | ▄█▃▄▁ |
alpha/loss/val | ▄█▃▄▁ |
alpha/mae/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/mae/val | ▅█▃▄▁ |
alpha/mean_pred/train | █▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/mean_pred/val | ▄█▂▂▁ |
alpha/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/mean_target/val | ▁▁▁▁▁ |
alpha/median_pred/train | ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/median_pred/val | ▁█▃▃▁ |
alpha/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/median_target/val | ▁▁▁▁▁ |
alpha/pearsonr/train | ▁█▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆ |
alpha/pearsonr/val | ▁▇▆▇█ |
alpha/std_pred/train | █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/std_pred/val | █▃▂▁▃ |
alpha/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
alpha/std_target/val | ▁▁▁▁▁ |
cv/MSELossIPU/train | ▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/MSELossIPU/val | ▄█▄▃▁ |
cv/loss/val | ▄█▄▃▁ |
cv/mae/train | █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/mae/val | ▅█▄▃▁ |
cv/mean_pred/train | █▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/mean_pred/val | ▁█▅▄▂ |
cv/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/mean_target/val | ▁▁▁▁▁ |
cv/median_pred/train | █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/median_pred/val | ▁█▄▃▂ |
cv/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/median_target/val | ▁▁▁▁▁ |
cv/pearsonr/train | █▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/pearsonr/val | ▁▇▆██ |
cv/std_pred/train | █▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/std_pred/val | █▄▄▁▅ |
cv/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
cv/std_target/val | ▁▁▁▁▁ |
epoch | ▁▃▅▆█ |
grad_norm | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/MSELossIPU/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/MSELossIPU/val | █▅▁▂▂ |
homo/loss/val | █▅▁▂▂ |
homo/mae/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/mae/val | █▅▁▃▃ |
homo/mean_pred/train | ▁▄███████████████████████████████ |
homo/mean_pred/val | ▁█▄▄▄ |
homo/mean_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/mean_target/val | ▁▁▁▁▁ |
homo/median_pred/train | █▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅ |
homo/median_pred/val | ▁▅▄▇█ |
homo/median_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/median_target/val | ▁▁▁▁▁ |
homo/pearsonr/train | █▁▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆ |
homo/pearsonr/val | ███▅▁ |
homo/std_pred/train | █▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/std_pred/val | █▃▂▁▂ |
homo/std_target/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
homo/std_target/val | ▁▁▁▁▁ |
loss | ▁▁▁ |
loss/train | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
loss/val | ▅█▃▃▁ |
train/grad_norm | ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
trainer/global_step | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▆▇▇▇████ |
Run summary:
alpha/MSELossIPU/train | 10.27925 |
alpha/MSELossIPU/val | 0.90438 |
alpha/loss/val | 0.90438 |
alpha/mae/train | 1.92692 |
alpha/mae/val | 0.5743 |
alpha/mean_pred/train | -0.0 |
alpha/mean_pred/val | 0.05032 |
alpha/mean_target/train | -1.47591 |
alpha/mean_target/val | -0.05393 |
alpha/median_pred/train | -0.0 |
alpha/median_pred/val | 0.09372 |
alpha/median_target/train | -0.81494 |
alpha/median_target/val | -0.04535 |
alpha/pearsonr/train | 0.24189 |
alpha/pearsonr/val | 0.70787 |
alpha/std_pred/train | 0.0 |
alpha/std_pred/val | 0.63476 |
alpha/std_target/train | 3.11787 |
alpha/std_target/val | 1.28429 |
cv/MSELossIPU/train | 7.34846 |
cv/MSELossIPU/val | 0.71549 |
cv/loss/val | 0.71549 |
cv/mae/train | 1.7832 |
cv/mae/val | 0.59933 |
cv/mean_pred/train | -0.0 |
cv/mean_pred/val | 0.02737 |
cv/mean_target/train | -1.33105 |
cv/mean_target/val | -0.08344 |
cv/median_pred/train | -0.0 |
cv/median_pred/val | 0.08583 |
cv/median_target/train | -0.73486 |
cv/median_target/val | -0.0361 |
cv/pearsonr/train | -0.84246 |
cv/pearsonr/val | 0.61503 |
cv/std_pred/train | 0.0 |
cv/std_pred/val | 0.60702 |
cv/std_target/train | 2.58691 |
cv/std_target/val | 1.06456 |
epoch | 4 |
grad_norm | 0.0 |
homo/MSELossIPU/train | 3.00724 |
homo/MSELossIPU/val | 1.10813 |
homo/loss/val | 1.10813 |
homo/mae/train | 0.92035 |
homo/mae/val | 0.81845 |
homo/mean_pred/train | 0.0 |
homo/mean_pred/val | 0.02063 |
homo/mean_target/train | -0.43686 |
homo/mean_target/val | -0.0837 |
homo/median_pred/train | 0.0 |
homo/median_pred/val | 0.00169 |
homo/median_target/train | 0.04636 |
homo/median_target/val | -0.10706 |
homo/pearsonr/train | 0.01569 |
homo/pearsonr/val | 0.11083 |
homo/std_pred/train | 0.0 |
homo/std_pred/val | 0.2037 |
homo/std_target/train | 1.75284 |
homo/std_target/val | 1.05716 |
loss | 6.87832 |
loss/train | 6.87832 |
loss/val | 0.90933 |
train/grad_norm | 0.0 |
trainer/global_step | 29 |
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
./wandb/run-20220920_171648-1ss172i9/logs
Testing the model¶
Once the model is trained, we can use the same datamodule to get the results on the test set. Here, ckpt_path
refers to the checkpoint path where the model at the best validation step was saved. Thus, the results on the test set represent the early stopping.
All the metrics that were computed on the validation set are then computed on the test set, printed, and saved into the metrics.yaml
file.
ckpt_path = trainer.checkpoint_callbacks[0].best_model_path
predictor.set_max_nodes_edges_per_graph(datamodule, stages=['test'])
trainer.test(model=predictor, datamodule=datamodule, ckpt_path=ckpt_path)
/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:154: LightningDeprecationWarning: The `LightningModule.get_progress_bar_dict` method was deprecated in v1.5 and will be removed in v1.7. Please use the `ProgressBarBase.get_metrics` instead. rank_zero_deprecation(
mols to ids: 0%| | 0/603 [00:00<?, ?it/s]
Restoring states from the checkpoint path at /home/dom/graphium/models_checkpoints/QM9/tutorial_model.ckpt Loaded model weights from checkpoint at /home/dom/graphium/models_checkpoints/QM9/tutorial_model.ckpt 2022-09-20 17:18:37.268 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:230 - Estimating pack max_pack_size=54 or max_pack_size_per_graph=9.0 2022-09-20 17:18:37.268 | INFO | graphium.ipu.ipu_dataloader:create_ipu_dataloader:233 - Provided `max_num_nodes=72`
------------------- MultitaskDataset about = test set num_graphs_total = 201 num_nodes_total = 1751 max_num_nodes_per_graph = 9 min_num_nodes_per_graph = 1 std_num_nodes_per_graph = 0.9443829267824616 mean_num_nodes_per_graph = 8.711442786069652 num_edges_total = 3748 max_num_edges_per_graph = 26 min_num_edges_per_graph = 0 std_num_edges_per_graph = 3.0647412282392468 mean_num_edges_per_graph = 18.64676616915423 -------------------
Testing: 0it [00:00, ?it/s]
[17:18:40.730] [poptorch:cpp] [warning] Graph contains an unused input %this_input : Long(1, strides=[1], requires_grad=0, device=cpu) Graph compilation: 100%|██████████| 100/100 [00:01<00:00] Exception in thread SockSrvRdThr: Traceback (most recent call last): File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner self.run() File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 113, in run shandler(sreq) File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 172, in server_record_publish iface = self._mux.get_stream(stream_id).interface File "/home/dom/.venv/graphium_ipu/lib/python3.8/site-packages/wandb/sdk/service/streams.py", line 186, in get_stream stream = self._streams[stream_id] KeyError: '1ss172i9'
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── alpha/MSELossIPU/test 1.0208719968795776 alpha/loss/test 1.0208719968795776 alpha/mae/test 0.5369145274162292 alpha/mean_pred/test -0.043575093150138855 alpha/mean_target/test -0.1731886863708496 alpha/median_pred/test 0.020035069435834885 alpha/median_target/test -0.05999755859375 alpha/pearsonr/test 0.6431294679641724 alpha/std_pred/test 0.6613938808441162 alpha/std_target/test 1.2930139303207397 cv/MSELossIPU/test 0.7984611392021179 cv/loss/test 0.7984611392021179 cv/mae/test 0.5717160701751709 cv/mean_pred/test -0.06430188566446304 cv/mean_target/test -0.08974086493253708 cv/median_pred/test 0.014351803809404373 cv/median_target/test -0.06439208984375 cv/pearsonr/test 0.663591742515564 cv/std_pred/test 0.63597571849823 cv/std_target/test 1.180733561515808 homo/MSELossIPU/test 1.14460027217865 homo/loss/test 1.14460027217865 homo/mae/test 0.7968456149101257 homo/mean_pred/test 0.004068718757480383 homo/mean_target/test 0.1092241033911705 homo/median_pred/test 0.0017465166747570038 homo/median_target/test 0.1390380859375 homo/pearsonr/test 0.1387491226196289 homo/std_pred/test 0.2090565264225006 homo/std_target/test 1.07249116897583 loss/test 0.9879778027534485 ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'homo/mean_pred/test': 0.004068718757480383, 'homo/std_pred/test': 0.2090565264225006, 'homo/median_pred/test': 0.0017465166747570038, 'homo/mean_target/test': 0.1092241033911705, 'homo/std_target/test': 1.07249116897583, 'homo/median_target/test': 0.1390380859375, 'homo/mae/test': 0.7968456149101257, 'homo/pearsonr/test': 0.1387491226196289, 'homo/MSELossIPU/test': 1.14460027217865, 'homo/loss/test': 1.14460027217865, 'alpha/mean_pred/test': -0.043575093150138855, 'alpha/std_pred/test': 0.6613938808441162, 'alpha/median_pred/test': 0.020035069435834885, 'alpha/mean_target/test': -0.1731886863708496, 'alpha/std_target/test': 1.2930139303207397, 'alpha/median_target/test': -0.05999755859375, 'alpha/mae/test': 0.5369145274162292, 'alpha/pearsonr/test': 0.6431294679641724, 'alpha/MSELossIPU/test': 1.0208719968795776, 'alpha/loss/test': 1.0208719968795776, 'cv/mean_pred/test': -0.06430188566446304, 'cv/std_pred/test': 0.63597571849823, 'cv/median_pred/test': 0.014351803809404373, 'cv/mean_target/test': -0.08974086493253708, 'cv/std_target/test': 1.180733561515808, 'cv/median_target/test': -0.06439208984375, 'cv/mae/test': 0.5717160701751709, 'cv/pearsonr/test': 0.663591742515564, 'cv/MSELossIPU/test': 0.7984611392021179, 'cv/loss/test': 0.7984611392021179, 'loss/test': 0.9879778027534485}]