# Training Models
## Dataset Specification
From the repository top level directory, run the following command to get the datasets reported in the paper:
```bash
cd src
python get_studies_datasets.py
```
For training on custom datasets, please refer to the class `PROTAC_Dataset` in the file [`protac_dataset.py`](../protac_degradation_predictor/protac_dataset.py). The class expects a Pandas DataFrame, so please assemble a file that can be parsed into a DataFrame with the following columns:
| Column Name | Type | Description |
| --- | --- | --- |
| Smiles | str | The SMILES representation of the PROTAC molecule. |
| Uniprot | str | The Uniprot ID of the target protein. |
| E3 Ligase Uniprot | str | The Uniprot ID of the E3 ligase. |
| Cell Line Identifier | str | The cell line identifier, as reported in Cellosaurus. |
| `<active_label>` | bool | The activity label of the PROTAC molecule to be predicted by the model. |
The column name `<active_label>` defaults to _"Active"_ in both the `PROTAC_Dataset` class and the `hyperparameter_tuning_and_training` function (see below for how to use it).
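As an illustration of the expected layout, the sketch below builds a small DataFrame with the columns from the table and checks it before use. The helper `check_custom_dataframe` is hypothetical (it is not part of `protac_degradation_predictor`), and every value in `example_df` is a placeholder rather than real PROTAC data.

```python
import pandas as pd

# Column names taken from the table above; "Active" is the default label column.
FEATURE_COLUMNS = ["Smiles", "Uniprot", "E3 Ligase Uniprot", "Cell Line Identifier"]


def check_custom_dataframe(df: pd.DataFrame, active_label: str = "Active") -> pd.DataFrame:
    """Hypothetical helper: verify the required columns and a boolean label."""
    required = FEATURE_COLUMNS + [active_label]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    if not pd.api.types.is_bool_dtype(df[active_label]):
        raise TypeError(f"Column '{active_label}' must have a boolean dtype")
    return df[required]


# Placeholder rows: these are NOT real PROTACs, proteins, or measurements.
example_df = pd.DataFrame({
    "Smiles": ["CCO", "CCN"],
    "Uniprot": ["P00001", "P00002"],
    "E3 Ligase Uniprot": ["Q00001", "Q00001"],
    "Cell Line Identifier": ["HeLa", "HeLa"],
    "Active": [True, False],
})

checked_df = check_custom_dataframe(example_df)
```

If the label column has a different name, pass it through the `active_label` argument of this hypothetical helper, mirroring the `<active_label>` placeholder in the table.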
## Training on Custom Data
To train on custom datasets, use the function `hyperparameter_tuning_and_training` in [`optuna_utils.py`](../protac_degradation_predictor/optuna_utils.py). The file [`run_experiments.py`](../src/run_experiments.py) shows how the function is used in practice.
A skeleton implementation looks as follows:
```python
import protac_degradation_predictor as pdp
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
# Load train/val and test dataframes
train_val_df = pd.read_csv('path/to/custom_dataset.csv')
test_df = pd.read_csv('path/to/test_dataset.csv') # Load one of our test datasets
# NOTE: Avoid data leakage by removing the leaking entries from the train/val
# dataframe. Do NOT remove or alter entries in the test set, as that would
# impair comparison with our work. Data leakage can occur if the test set
# contains any combination of SMILES, Uniprot, E3 Ligase Uniprot, or Cell Line
# Identifier that is also present in the train/val set.
# Precompute Morgan fingerprints
unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
# Load embedding dictionaries
protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
# Setup Cross-Validation object
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Name used for the log and study files (placeholder, choose your own)
experiment_name = 'custom_experiment'
# Run hyperparameter tuning, then train the final models
pdp.hyperparameter_tuning_and_training(
    protein2embedding=protein2embedding,
    cell2embedding=cell2embedding,
    smiles2fp=smiles2fp,
    train_val_df=train_val_df,
    test_df=test_df,
    kf=kf,
    n_models_for_test=3,
    n_trials=100,
    max_epochs=20,
    logger_save_dir='../logs',
    logger_name=f'logs_{experiment_name}',
    study_filename=f'../reports/study_{experiment_name}.pkl',
)
```
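The leakage note in the skeleton above can be sketched as a filtering step on the train/val dataframe. The helper `drop_leaking_rows` is hypothetical and not part of the package. It assumes that "leaking" means a train/val row whose full combination of the four identifying columns also occurs in the test set; a stricter criterion (for example, dropping any row that shares a SMILES with the test set) may be more appropriate for your use case. The test set is left untouched.

```python
import pandas as pd

# Identifying columns named in the leakage note above.
KEY_COLUMNS = ["Smiles", "Uniprot", "E3 Ligase Uniprot", "Cell Line Identifier"]


def drop_leaking_rows(train_val_df: pd.DataFrame,
                      test_df: pd.DataFrame,
                      key_columns=KEY_COLUMNS) -> pd.DataFrame:
    """Hypothetical helper: drop train/val rows whose key combination is in the test set."""
    # Deduplicate test keys so the left join keeps one output row per train/val row.
    test_keys = test_df[key_columns].drop_duplicates()
    merged = train_val_df.merge(test_keys, on=key_columns, how="left", indicator=True)
    # "left_only" marks rows whose key combination does not appear in the test set.
    keep = (merged["_merge"] == "left_only").to_numpy()
    return train_val_df.loc[keep].reset_index(drop=True)


# Placeholder data (not real PROTAC entries): the second train/val row leaks.
demo_train_val = pd.DataFrame({
    "Smiles": ["CCO", "CCN", "CCC"],
    "Uniprot": ["P00001", "P00002", "P00003"],
    "E3 Ligase Uniprot": ["Q00001", "Q00001", "Q00001"],
    "Cell Line Identifier": ["HeLa", "HeLa", "HeLa"],
    "Active": [True, False, True],
})
demo_test = pd.DataFrame({
    "Smiles": ["CCN"],
    "Uniprot": ["P00002"],
    "E3 Ligase Uniprot": ["Q00001"],
    "Cell Line Identifier": ["HeLa"],
    "Active": [False],
})

clean_train_val = drop_leaking_rows(demo_train_val, demo_test)
```

Apply a step like this before precomputing fingerprints and before calling `hyperparameter_tuning_and_training`, so that only the cleaned train/val dataframe is passed on.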