{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Ensembling" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from lightning import pytorch as pl\n", "import numpy as np\n", "import torch\n", "from chemprop import data, models, nn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example [dataloader](./data/dataloaders.ipynb)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "smis = [\"C\" * i for i in range(1, 4)]\n", "ys = np.random.rand(len(smis), 1)\n", "dset = data.MoleculeDataset([data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])\n", "dataloader = data.build_dataloader(dset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model ensembling" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A single model will sometimes give erroneous predictions for some molecules. These erroneous predictions can be mitigated by averaging the predictions of several models trained on the same data. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "ensemble = []\n", "n_models = 3\n", "for _ in range(n_models):\n", " ensemble.append(models.MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN()))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", "Loading `train_dataloader` to estimate number of stepping batches.\n", "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", "\n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------\n", "0 | message_passing | BondMessagePassing | 227 K | train\n", "1 | agg | MeanAggregation | 0 | train\n", "2 | bn | Identity | 0 | train\n", "3 | predictor | RegressionFFN | 90.6 K | train\n", "4 | X_d_transform | Identity | 0 | train\n", "5 | metrics | ModuleList | 0 | train\n", "---------------------------------------------------------------\n", "318 K Trainable params\n", "0 Non-trainable params\n", "318 K Total params\n", "1.273 Total estimated model params size (MB)\n", "24 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 14.38it/s, train_loss_step=0.234, train_loss_epoch=0.234]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=1` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 13.86it/s, train_loss_step=0.234, train_loss_epoch=0.234]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading `train_dataloader` to estimate number of stepping batches.\n", "\n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------\n", "0 | message_passing | BondMessagePassing | 227 K | train\n", "1 | agg | MeanAggregation | 0 | train\n", "2 | bn | Identity | 0 | train\n", "3 | predictor | RegressionFFN | 90.6 K | train\n", "4 | X_d_transform | Identity | 0 | train\n", "5 | metrics | ModuleList | 0 | train\n", "---------------------------------------------------------------\n", "318 K Trainable params\n", "0 Non-trainable params\n", "318 K Total params\n", "1.273 Total estimated model params size (MB)\n", "24 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 46.40it/s, train_loss_step=0.215, train_loss_epoch=0.215]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=1` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 23.79it/s, train_loss_step=0.215, train_loss_epoch=0.215]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading `train_dataloader` to estimate number of stepping batches.\n", "\n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------\n", "0 | message_passing | BondMessagePassing | 227 K | train\n", "1 | agg | MeanAggregation | 0 | train\n", "2 | bn | Identity | 0 | train\n", "3 | predictor | RegressionFFN | 90.6 K | train\n", "4 | X_d_transform | Identity | 0 | train\n", "5 | metrics | ModuleList | 0 | train\n", "---------------------------------------------------------------\n", "318 K Trainable 
params\n", "0 Non-trainable params\n", "318 K Total params\n", "1.273 Total estimated model params size (MB)\n", "24 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 42.51it/s, train_loss_step=0.239, train_loss_epoch=0.239]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=1` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 36.88it/s, train_loss_step=0.239, train_loss_epoch=0.239]\n" ] } ], "source": [ "for model in ensemble:\n", " trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_epochs=1)\n", " trainer.fit(model, dataloader)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 83.86it/s] \n", "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 82.63it/s]\n", "Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 68.94it/s] \n" ] } ], "source": [ "prediction_dataloader = data.build_dataloader(dset, shuffle=False)\n", "predictions = []\n", "for model in ensemble:\n", " predictions.append(torch.concat(trainer.predict(model, prediction_dataloader)))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[0.0096],\n", " [0.0008],\n", " [0.0082]]),\n", " tensor([[0.0318],\n", " [0.0260],\n", " [0.0254]]),\n", " tensor([[-0.0054],\n", " [ 0.0032],\n", " [-0.0035]])]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0120],\n", " [0.0100],\n", " [0.0100]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.concat(predictions, axis=1).mean(axis=1, keepdim=True)" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 4 }