{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Scaling inputs and outputs" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from chemprop.models import MPNN\n", "from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN\n", "from chemprop.nn.transforms import ScaleTransform, UnscaleTransform, GraphTransform" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example [dataset](./data/datasets.ipynb) with extra atom and bond features, extra atom descriptors, and extra [datapoint](./data/datapoints.ipynb) descriptors." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from chemprop.data import MoleculeDatapoint, MoleculeDataset\n", "\n", "smis = [\"CC\", \"CN\", \"CO\", \"CF\", \"CP\", \"CS\", \"CI\"]\n", "ys = np.random.rand(len(smis), 1) * 100\n", "\n", "n_datapoints = len(smis)\n", "n_atoms = 2\n", "n_bonds = 1\n", "n_extra_atom_features = 3\n", "n_extra_bond_features = 4\n", "n_extra_atom_descriptors = 5\n", "n_extra_datapoint_descriptors = 6\n", "\n", "extra_atom_features = np.random.rand(n_datapoints, n_atoms, n_extra_atom_features)\n", "extra_bond_features = np.random.rand(n_datapoints, n_bonds, n_extra_bond_features)\n", "extra_atom_descriptors = np.random.rand(n_datapoints, n_atoms, n_extra_atom_descriptors)\n", "extra_datapoint_descriptors = np.random.rand(n_datapoints, n_extra_datapoint_descriptors)\n", "\n", "datapoints = [\n", " MoleculeDatapoint.from_smi(smi, y, x_d=x_d, V_f=V_f, E_f=E_f, V_d=V_d)\n", " for smi, y, x_d, V_f, E_f, V_d in zip(\n", " smis,\n", " ys,\n", " extra_datapoint_descriptors,\n", " extra_atom_features,\n", " extra_bond_features,\n", " extra_atom_descriptors,\n", " )\n", "]\n", "train_dset = MoleculeDataset(datapoints[:3])\n", "val_dset = MoleculeDataset(datapoints[3:5])\n", "test_dset = MoleculeDataset(datapoints[5:])" ] }, { "cell_type": 
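"markdown", "metadata": {}, "source": [ "All of the scalers below follow the same recipe. Statistics are fit on the training split only. They are then reused, unchanged, for the validation split and, through a transform saved in the model, for new data at inference. The next cell is a minimal plain-NumPy sketch of that recipe for the targets. It assumes standard scaling, which the `from_standard_scaler` calls and the `StandardScaler()` output further down indicate. It does not call Chemprop, and its variable names are hypothetical." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical sketch of standard scaling; not Chemprop's implementation.\n", "rng = np.random.default_rng(0)\n", "y_train = rng.random((3, 1)) * 100  # stands in for the 3-datapoint training split\n", "y_val = rng.random((2, 1)) * 100  # stands in for the 2-datapoint validation split\n", "\n", "# Fit: statistics come from the training split only (population std, ddof=0)\n", "mean = y_train.mean(axis=0)\n", "scale = y_train.std(axis=0)\n", "\n", "# Apply the same training statistics to both splits\n", "y_train_scaled = (y_train - mean) / scale\n", "y_val_scaled = (y_val - mean) / scale\n", "\n", "# Unscale (what an output transform does at inference) to recover the original units\n", "y_val_recovered = y_val_scaled * scale + mean\n", "print(y_train_scaled.mean(axis=0), y_train_scaled.std(axis=0))" ] }, { "cell_type": 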
"markdown", "metadata": {}, "source": [ "### Scaling targets - FFN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scaling the target values before training can improve model performance and speed up training. The target scaler should be fit to the training dataset and then applied to the validation dataset. It is *not* applied to the test dataset. Instead, the scaler is used to make an `UnscaleTransform`, which is given to the predictor (FFN) layer and applied automatically during inference to return predictions in the original target units.\n", "\n", "Note that the `output_transform` is currently saved both in the model's `state_dict` and in the model's hyperparameters. This may change in the future to align with `lightning`'s recommendations. You can ignore any messages about this." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "output_scaler = train_dset.normalize_targets()\n", "val_dset.normalize_targets(output_scaler)\n", "# test_dset targets not scaled\n", "\n", "output_transform = UnscaleTransform.from_standard_scaler(output_scaler)\n", "\n", "ffn = RegressionFFN(output_transform=output_transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scaling extra atom and bond features - Message Passing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The atom and bond features generated by Chemprop [featurizers](./featurizers/molgraph_molecule_featurizer.ipynb) are either multi-hot or on the order of 1, so we recommend scaling extra atom and bond features to the same order of magnitude. Like the target scaler, these scalers are fit to the training data, applied to the validation data, and then saved to the model (in this case, the message passing layer) so that they are applied automatically to the test dataset during inference." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "StandardScaler()" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "V_f_scaler = train_dset.normalize_inputs(\"V_f\")\n", "E_f_scaler = train_dset.normalize_inputs(\"E_f\")\n", "\n", "val_dset.normalize_inputs(\"V_f\", V_f_scaler)\n", "val_dset.normalize_inputs(\"E_f\", E_f_scaler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The scalers are used to make `ScaleTransform`s, which are combined into a `GraphTransform` and given to the message passing module. Note that `ScaleTransform` acts on the whole feature vector, not just the extra features. Its mean and scale arrays are therefore padded with zeros and ones, respectively, for the default features, so that only the extra features are actually scaled. The amount of padding required is the number of default features produced by the featurizer." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer\n", "\n", "featurizer = SimpleMoleculeMolGraphFeaturizer(\n", " extra_atom_fdim=n_extra_atom_features, extra_bond_fdim=n_extra_bond_features\n", ")\n", "n_V_features = featurizer.atom_fdim - featurizer.extra_atom_fdim\n", "n_E_features = featurizer.bond_fdim - featurizer.extra_bond_fdim\n", "\n", "V_f_transform = ScaleTransform.from_standard_scaler(V_f_scaler, pad=n_V_features)\n", "E_f_transform = ScaleTransform.from_standard_scaler(E_f_scaler, pad=n_E_features)\n", "\n", "graph_transform = GraphTransform(V_f_transform, E_f_transform)\n", "\n", "mp = BondMessagePassing(graph_transform=graph_transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have only extra atom features or only extra bond features, you can set the transform for the unused one to `torch.nn.Identity`." 
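] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell gives a concrete look at the padding and the identity option. It is a minimal `torch` sketch that assumes the transform computes `(x - mean) / scale` column-wise. It is not Chemprop's `ScaleTransform`, and the feature counts and statistics are made up for illustration. Padding the mean with zeros and the scale with ones leaves the default featurizer columns untouched, and `torch.nn.Identity()` returns its input unchanged." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical sketch of a padded scale transform; not Chemprop's class.\n", "n_default, n_extra = 4, 3  # assumed: 4 default features, 3 extra features\n", "extra_mean = torch.tensor([1.0, 2.0, 3.0])\n", "extra_scale = torch.tensor([2.0, 4.0, 0.5])\n", "\n", "# Pad with zeros (mean) and ones (scale) for the default features\n", "mean = torch.cat([torch.zeros(n_default), extra_mean])\n", "scale = torch.cat([torch.ones(n_default), extra_scale])\n", "\n", "V = torch.rand(5, n_default + n_extra)  # 5 atoms, full feature vectors\n", "V_scaled = (V - mean) / scale  # broadcasts over the atom dimension\n", "\n", "# Default columns pass through unchanged; only the extra columns are scaled\n", "assert torch.allclose(V_scaled[:, :n_default], V[:, :n_default])\n", "assert torch.allclose(V_scaled[:, n_default:], (V[:, n_default:] - extra_mean) / extra_scale)\n", "\n", "# An identity module returns its input, so it can stand in for an unused transform\n", "identity = torch.nn.Identity()\n", "assert torch.equal(identity(V), V)" 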
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "graph_transform = GraphTransform(V_transform=torch.nn.Identity(), E_transform=E_f_transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scaling extra atom descriptors - Message Passing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The atom descriptors produced by message passing (before aggregation) are also likely to be on the order of 1, so extra atom descriptors should be scaled as well. Unlike above, no padding is needed because this scaling is applied only to the extra atom descriptors. The `ScaleTransform` is given to the message passing module for use during inference." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "V_d_scaler = train_dset.normalize_inputs(\"V_d\")\n", "val_dset.normalize_inputs(\"V_d\", V_d_scaler)\n", "\n", "V_d_transform = ScaleTransform.from_standard_scaler(V_d_scaler)\n", "\n", "mp = BondMessagePassing(V_d_transform=V_d_transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `GraphTransform` and a `ScaleTransform` can both be given to the message passing module." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "mp = BondMessagePassing(graph_transform=graph_transform, V_d_transform=V_d_transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scaling extra datapoint descriptors - MPNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The molecule/reaction descriptors produced by message passing (after aggregation) are batch normalized by default to be on the order of 1 (this can be turned off; see the [model notebook](./models/basic_mpnn_model.ipynb)). Therefore, we also recommend scaling the extra datapoint-level descriptors. The `ScaleTransform` for these is given to the `MPNN` or `MulticomponentMPNN` module." 
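] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell gives a rough illustration of why this matters. It assumes that the extra datapoint descriptors are concatenated onto the aggregated, batch-normalized fingerprint before the FFN. It uses plain `torch` with made-up sizes and is not Chemprop code. Without scaling, descriptors in the range 0 to 100 dwarf fingerprint values of order 1. After column-wise standard scaling, both parts are on a similar scale." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Hypothetical sizes: 7 molecules, 300-dim fingerprint, 6 extra descriptors\n", "torch.manual_seed(0)\n", "bn = torch.nn.BatchNorm1d(300)  # a fresh module is in training mode and uses batch statistics\n", "H = bn(torch.randn(7, 300) * 5 + 3).detach()  # per-column mean ~0, std ~1\n", "X_d_raw = torch.rand(7, 6) * 100  # unscaled descriptors in [0, 100)\n", "\n", "# Column-wise standard scaling with the population std (as sklearn's StandardScaler uses)\n", "mu = X_d_raw.mean(dim=0)\n", "sigma = ((X_d_raw - mu) ** 2).mean(dim=0).sqrt()\n", "X_d_scaled = (X_d_raw - mu) / sigma\n", "\n", "# Assumed FFN input: fingerprint and descriptors concatenated per molecule\n", "Z_raw = torch.cat([H, X_d_raw], dim=1)\n", "Z_scaled = torch.cat([H, X_d_scaled], dim=1)\n", "print('max |value| without scaling:', Z_raw.abs().max().item())\n", "print('max |value| with scaling:', Z_scaled.abs().max().item())" 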
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "X_d_scaler = train_dset.normalize_inputs(\"X_d\")\n", "val_dset.normalize_inputs(\"X_d\", X_d_scaler)\n", "\n", "X_d_transform = ScaleTransform.from_standard_scaler(X_d_scaler)\n", "\n", "chemprop_model = MPNN(\n", " BondMessagePassing(), NormAggregation(), RegressionFFN(), X_d_transform=X_d_transform\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }