## Scaling inputs and outputs

In [1]:
import torch
from chemprop.models import MPNN
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN
from chemprop.nn.transforms import ScaleTransform, UnscaleTransform, GraphTransform

This is an example [dataset](./data/datasets.ipynb) with extra atom and bond features, extra atom descriptors, and extra [datapoint](./data/datapoints.ipynb) descriptors.

In [2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset

smis = ["CC", "CN", "CO", "CF", "CP", "CS", "CI"]
ys = np.random.rand(len(smis), 1) * 100

n_datapoints = len(smis)
n_atoms = 2
n_bonds = 1
n_extra_atom_features = 3
n_extra_bond_features = 4
n_extra_atom_descriptors = 5
n_extra_datapoint_descriptors = 6

extra_atom_features = np.random.rand(n_datapoints, n_atoms, n_extra_atom_features)
extra_bond_features = np.random.rand(n_datapoints, n_bonds, n_extra_bond_features)
extra_atom_descriptors = np.random.rand(n_datapoints, n_atoms, n_extra_atom_descriptors)
extra_datapoint_descriptors = np.random.rand(n_datapoints, n_extra_datapoint_descriptors)

datapoints = [
    MoleculeDatapoint.from_smi(smi, y, x_d=x_d, V_f=V_f, E_f=E_f, V_d=V_d)
    for smi, y, x_d, V_f, E_f, V_d in zip(
        smis,
        ys,
        extra_datapoint_descriptors,
        extra_atom_features,
        extra_bond_features,
        extra_atom_descriptors,
    )
]
train_dset = MoleculeDataset(datapoints[:3])
val_dset = MoleculeDataset(datapoints[3:5])
test_dset = MoleculeDataset(datapoints[5:])

### Scaling targets - FFN

Scaling the target values before training can improve model performance and speed up training. The target scaler should be fit to the training dataset and then applied to the validation dataset. It is *not* applied to the test dataset. Instead, the scaler is used to make an `UnscaleTransform`, which is given to the predictor (FFN) layer and applied automatically during inference.

Note that the `output_transform` is currently saved both in the model's `state_dict` and in the model's hyperparameters. This may change in the future to align with `lightning`'s recommendations. You can ignore any messages about this.

In [3]:
output_scaler = train_dset.normalize_targets()
val_dset.normalize_targets(output_scaler)
# test_dset targets not scaled

output_transform = UnscaleTransform.from_standard_scaler(output_scaler)

ffn = RegressionFFN(output_transform=output_transform)
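
The round trip described above can be sketched with plain NumPy. This is only an illustration of the arithmetic, assuming standard (z-score) scaling as the name `from_standard_scaler` suggests. The arrays and the `fake_scaled_preds` values are made up and do not come from Chemprop.

```python
import numpy as np

# Hypothetical training and validation targets, shape (n_datapoints, n_tasks).
train_y = np.array([[10.0], [20.0], [60.0]])
val_y = np.array([[30.0], [50.0]])

# Fit the statistics on the training targets only.
mean = train_y.mean(axis=0)
scale = train_y.std(axis=0)

# The same training statistics are applied to training and validation targets.
train_y_scaled = (train_y - mean) / scale
val_y_scaled = (val_y - mean) / scale

# At inference the model predicts in scaled space; unscaling maps the
# predictions back to the original units, so test targets can stay unscaled.
fake_scaled_preds = np.array([[0.0], [1.0]])  # stand-in for model outputs
preds = fake_scaled_preds * scale + mean
```

A scaled prediction of 0 maps back to the training mean, and a scaled prediction of 1 maps to one training standard deviation above it.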

### Scaling extra atom and bond features - Message Passing

The atom and bond features generated by Chemprop [featurizers](./featurizers/molgraph_molecule_featurizer.ipynb) are either multi-hot or on the order of 1. We recommend scaling extra atom and bond features to also be on the order of 1. Like the target scaler, these scalers are fit to the training data, applied to the validation data, and then saved to the model (in this case the message passing layer) so that they are applied automatically to the test dataset during inference.

In [4]:
V_f_scaler = train_dset.normalize_inputs("V_f")
E_f_scaler = train_dset.normalize_inputs("E_f")

val_dset.normalize_inputs("V_f", V_f_scaler)
val_dset.normalize_inputs("E_f", E_f_scaler)
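
As a rough picture of what fitting a scaler to per-atom features involves, the NumPy sketch below stacks the atoms of all training molecules into one array and computes one mean and one standard deviation per feature column. This is an assumption about the general approach, not a copy of Chemprop's implementation, and the names `train_V_fs`, `stacked`, and the molecule sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical extra atom features for 3 training molecules with 2, 3 and 1 atoms,
# each with 3 extra features per atom.
train_V_fs = [rng.random((n_atoms, 3)) for n_atoms in (2, 3, 1)]

# Stack the atoms of all training molecules so that every atom is one row.
stacked = np.concatenate(train_V_fs, axis=0)  # shape (6, 3)

# One mean and one scale per feature column, computed over all training atoms.
mean = stacked.mean(axis=0)
scale = stacked.std(axis=0)

# Apply the training statistics molecule by molecule.
train_V_fs_scaled = [(V_f - mean) / scale for V_f in train_V_fs]
```

The same `mean` and `scale` would then be reused for the validation molecules rather than being refit.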

The scalers are used to make `ScaleTransform`s. These are combined into a `GraphTransform`, which is given to the message passing module. Note that `ScaleTransform` acts on the whole feature vector, not just the extra features. The `ScaleTransform`'s mean and scale arrays are therefore padded with zeros and ones, respectively, so that only the extra features are actually scaled. The amount of padding required is the length of the default features of the featurizer.

In [5]:
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer

featurizer = SimpleMoleculeMolGraphFeaturizer(
    extra_atom_fdim=n_extra_atom_features, extra_bond_fdim=n_extra_bond_features
)
n_V_features = featurizer.atom_fdim - featurizer.extra_atom_fdim
n_E_features = featurizer.bond_fdim - featurizer.extra_bond_fdim

V_f_transform = ScaleTransform.from_standard_scaler(V_f_scaler, pad=n_V_features)
E_f_transform = ScaleTransform.from_standard_scaler(E_f_scaler, pad=n_E_features)

graph_transform = GraphTransform(V_f_transform, E_f_transform)

mp = BondMessagePassing(graph_transform=graph_transform)
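
The effect of the padding can be checked with a small NumPy sketch. It assumes that the default features come first in the feature vector and the extra features are appended after them, so the padding is prepended. The feature counts and values are hypothetical.

```python
import numpy as np

n_default = 4  # hypothetical number of default featurizer features

# Hypothetical statistics fitted on 3 extra features.
extra_mean = np.array([0.5, 2.0, -1.0])
extra_scale = np.array([0.1, 4.0, 2.0])

# Pad so that default features are shifted by 0 and divided by 1 (unchanged).
mean = np.concatenate([np.zeros(n_default), extra_mean])
scale = np.concatenate([np.ones(n_default), extra_scale])

# One atom: 4 default (one-hot style) features followed by 3 extra features.
x = np.array([0.0, 1.0, 0.0, 0.0, 0.6, 6.0, 1.0])
x_scaled = (x - mean) / scale
```

The first four entries of `x_scaled` are identical to those of `x`, while the last three are standardized.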

If you have only extra atom features or only extra bond features, you can set the transform for the unused option to `torch.nn.Identity`.

In [6]:
graph_transform = GraphTransform(V_transform=torch.nn.Identity(), E_transform=E_f_transform)

### Scaling extra atom descriptors - Message Passing

The atom descriptors from message passing (before aggregation) are also likely to be on the order of 1, so extra atom descriptors should be scaled as well. Unlike above, no padding is needed because this scaling is applied only to the extra atom descriptors. The `ScaleTransform` is given to the message passing module for use during inference.

In [7]:
V_d_scaler = train_dset.normalize_inputs("V_d")
val_dset.normalize_inputs("V_d", V_d_scaler)

V_d_transform = ScaleTransform.from_standard_scaler(V_d_scaler)

mp = BondMessagePassing(V_d_transform=V_d_transform)

A `GraphTransform` and a `ScaleTransform` can both be given to the message passing module.

In [8]:
mp = BondMessagePassing(graph_transform=graph_transform, V_d_transform=V_d_transform)

### Scaling extra datapoint descriptors - MPNN

The molecule/reaction descriptors from message passing (after aggregation) are batch normalized by default to be on the order of 1 (this can be turned off, see the [model notebook](./models/basic_mpnn_model.ipynb)). Therefore, we also recommend scaling the extra datapoint-level descriptors. The `ScaleTransform` for this is given to the `MPNN` or `MulticomponentMPNN` module.

In [9]:
X_d_scaler = train_dset.normalize_inputs("X_d")
val_dset.normalize_inputs("X_d", X_d_scaler)

X_d_transform = ScaleTransform.from_standard_scaler(X_d_scaler)

chemprop_model = MPNN(
    BondMessagePassing(), NormAggregation(), RegressionFFN(), X_d_transform=X_d_transform
)
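
The motivation for this scaling can be illustrated with NumPy. The sketch assumes that the scaled descriptors are concatenated to the learned molecule representation before the predictor. `H`, the dimensions, and the factor of 1000 are hypothetical stand-ins rather than Chemprop outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
n_mols, d_h, d_xd = 4, 8, 6  # hypothetical batch size and dimensions

# Stand-in for batch-normalized learned molecule representations (order of 1).
H = rng.standard_normal((n_mols, d_h))

# Raw extra datapoint descriptors on a much larger scale.
X_d = rng.random((n_mols, d_xd)) * 1000.0

# Standard scaling; in practice the statistics come from the training set.
X_d_scaled = (X_d - X_d.mean(axis=0)) / X_d.std(axis=0)

# Predictor input: learned representation joined with the scaled descriptors.
ffn_input = np.concatenate([H, X_d_scaled], axis=1)
```

Without the scaling step, the descriptor columns would be several orders of magnitude larger than the learned columns they sit beside.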