ribesstefano committed
Commit 5e01175 • 1 Parent(s): 6101de8

Started working on packaging the repository
Browse files
- notebooks/plotting_dragradation_activity_performance.ipynb +0 -0
- notebooks/protac_degradation_predictor.ipynb +10 -2
- notebooks/protac_degradation_predictor.py +3 -2
- protac_degradation_predictor/__init__.py +7 -0
- protac_degradation_predictor/config.py +37 -0
- protac_degradation_predictor/data/PROTAC-DB.csv +0 -0
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl +2 -2
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5 +2 -2
- protac_degradation_predictor/data_utils.py +46 -0
- protac_degradation_predictor/optuna_utils.py +318 -0
- protac_degradation_predictor/protac_dataset.py +193 -0
- protac_degradation_predictor/protac_degradation_predictor.py +88 -0
- protac_degradation_predictor/pytorch_models.py +471 -0
- protac_degradation_predictor/sklearn_models.py +243 -0
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl +0 -3
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl +0 -3
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl +1 -1
- reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl +0 -3
- setup.py +21 -0
notebooks/plotting_dragradation_activity_performance.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.
notebooks/protac_degradation_predictor.ipynb
CHANGED
@@ -1719,8 +1719,16 @@
 }
 ],
 "source": [
-"
-"
+"import torch\n",
+"import torch.nn as nn\n",
+"from torchmetrics import (\n",
+"    Accuracy,\n",
+"    AUROC,\n",
+"    Precision,\n",
+"    Recall,\n",
+"    F1Score,\n",
+"    MetricCollection,\n",
+")\n",
 "\n",
 "# Generic function to fit and evaluate a classifier model (given as argument),\n",
 "# on train and val sets (and optionally a test set) given as dataframes\n",
notebooks/protac_degradation_predictor.py
CHANGED
@@ -680,7 +680,7 @@ def train_model(
         hidden_dim (int): The hidden dimension of the model.
         batch_size (int): The batch size.
         learning_rate (float): The learning rate.
-        max_epochs (int):
+        max_epochs (int): The maximum number of epochs.
         smiles_emb_dim (int): The dimension of the SMILES embeddings.
         smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
         fast_dev_run (bool): Whether to run a fast development run.
@@ -985,6 +985,8 @@ def main(
     encoder = OrdinalEncoder()
     protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
     active_df = protac_df[protac_df[active_col].notna()].copy()
+    # Sort the groups so that samples with the highest tanimoto similarity,
+    # i.e., the "less similar" ones, are placed in the test set first
     tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
 
     test_df = []
@@ -992,7 +994,6 @@
     # entries to the test_df if: 1) the test_df length + the group entries is less
     # than 20% of the active_df length, and 2) the percentage of True and False entries
     # in the active_col in test_df is roughly 50%.
-    # Start the loop from the groups containing the smallest number of entries.
     for group in tanimoto_groups:
         group_df = active_df[active_df['Tanimoto Group'] == group]
         if test_df == []:
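
For context, the hunks above implement a Tanimoto-based test split: groups are sorted by descending average Tanimoto similarity and greedily moved into the test set. A minimal sketch of that loop follows; the exact size and class-balance acceptance tests live outside this hunk, so the bounds below (20% cap, ±10% balance window) are hypothetical.

# Hypothetical sketch of the greedy group-wise split described above.
test_groups = []
max_test_len = int(0.2 * len(active_df))  # assumed 20% cap from the comment
for group in tanimoto_groups:
    group_df = active_df[active_df['Tanimoto Group'] == group]
    candidate = pd.concat(test_groups + [group_df])
    balance = candidate[active_col].mean()
    # Accept the group if the test set stays small enough and roughly balanced
    if len(candidate) <= max_test_len and abs(balance - 0.5) < 0.1:
        test_groups.append(group_df)
test_df = pd.concat(test_groups)
train_val_df = active_df.drop(test_df.index)
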
protac_degradation_predictor/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .protac_degradation_predictor import (
+    PROTAC_Model,
+    train_model,
+)
+
+__version__ = "0.0.1"
+__author__ = "Stefano Ribes"
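
One packaging caveat: `train_model` is defined in pytorch_models.py, and protac_degradation_predictor.py only re-exports `PROTAC_Model` and `load_model` from it, so the import above would likely raise an ImportError once the package is installed. A sketch of an `__init__.py` matching the modules added in this commit — an assumption about the intended public surface, not the author's final layout:

# Hypothetical __init__.py matching the modules added in this commit.
from .pytorch_models import PROTAC_Model, train_model, load_model
from .protac_degradation_predictor import (
    get_protac_active_proba,
    is_protac_active,
)

__version__ = "0.0.1"
__author__ = "Stefano Ribes"

Relatedly, the submodules themselves use flat imports (e.g. `from config import config`), which resolve only when the package directory is on sys.path; package-relative imports (`from .config import config`) would be needed for an installed distribution.
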
protac_degradation_predictor/config.py
ADDED
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class Config:
+    # Embeddings information
+    morgan_radius: int = 15
+    fingerprint_size: int = 224
+    protein_embedding_size: int = 1024
+    cell_embedding_size: int = 768
+
+    # Data information
+    dmax_threshold: float = 0.6
+    pdc50_threshold: float = 6.0
+    e3_ligase2uniprot: dict = {
+        'VHL': 'P40337',
+        'CRBN': 'Q96SW2',
+        'DCAF11': 'Q8TEB1',
+        'DCAF15': 'Q66K64',
+        'DCAF16': 'Q9NXF7',
+        'MDM2': 'Q00987',
+        'Mdm2': 'Q00987',
+        'XIAP': 'P98170',
+        'cIAP1': 'Q7Z460',
+        'IAP': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'Iap': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'AhR': 'P35869',
+        'RNF4': 'P78317',
+        'RNF114': 'Q9Y508',
+        'FEM1B': 'Q9UK73',
+        'Ubr1': 'Q8IWV7',
+    }
+
+    def __post_init__(self):
+        self.active_label: str = f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})'
+
+
+config = Config()
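
As committed, this dataclass would likely fail at import time: dataclasses reject mutable defaults such as a plain dict (raising ValueError and pointing to default_factory), and a frozen dataclass cannot assign attributes in `__post_init__` (raising FrozenInstanceError). A minimal corrected sketch with the same fields, abbreviated to one dict entry:

from dataclasses import dataclass, field


@dataclass(frozen=True)
class Config:
    morgan_radius: int = 15
    fingerprint_size: int = 224
    protein_embedding_size: int = 1024
    cell_embedding_size: int = 768
    dmax_threshold: float = 0.6
    pdc50_threshold: float = 6.0
    # Mutable defaults must be wrapped in a default_factory
    e3_ligase2uniprot: dict = field(
        default_factory=lambda: {'VHL': 'P40337'}  # ... remaining entries
    )
    active_label: str = field(init=False)

    def __post_init__(self):
        # Frozen instances can only be mutated through object.__setattr__
        object.__setattr__(
            self,
            'active_label',
            f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})',
        )


config = Config()
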
protac_degradation_predictor/data/PROTAC-DB.csv
ADDED
The diff for this file is too large to render. See raw diff.
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:627e8ce3842afeb6bb7d5caa5ec1ba034c36dc77fab70734e15dca340a7fd718
+size 3550864
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:19f4b8c73652392db7840962d1a7817c7e899716e2bb758e4947c8c2bb265336
+size 51089512
protac_degradation_predictor/data_utils.py
ADDED
@@ -0,0 +1,46 @@
+import os
+import pkg_resources
+import pickle
+from typing import Dict
+
+from config import config
+
+import h5py
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from joblib import Memory
+
+
+home_dir = os.path.expanduser('~')
+cachedir = os.path.join(home_dir, '.cache', 'protac_degradation_predictor')
+memory = Memory(cachedir, verbose=0)
+
+
+@memory.cache
+def load_protein2embedding() -> Dict[str, np.ndarray]:
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
+    protein2embedding = {}
+    with h5py.File(embeddings_path, "r") as file:
+        for sequence_id in file.keys():
+            embedding = file[sequence_id][:]
+            protein2embedding[sequence_id] = np.array(embedding)
+    return protein2embedding
+
+
+@memory.cache
+def load_cell2embedding() -> Dict[str, np.ndarray]:
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
+    with open(embeddings_path, 'rb') as f:
+        cell2embedding = pickle.load(f)
+    return cell2embedding
+
+
+def get_fingerprint(smiles: str) -> np.ndarray:
+    morgan_fpgen = AllChem.GetMorganGenerator(
+        radius=config.morgan_radius,
+        fpSize=config.fingerprint_size,
+        includeChirality=True,
+    )
+    return morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
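
One caveat in the loaders above: `pkg_resources.resource_stream` returns an open binary stream, not a filesystem path, so `open(embeddings_path, 'rb')` in load_cell2embedding would raise a TypeError. `h5py.File` accepts file-like objects, which is why the HDF5 loader works, but the pickle loader needs to consume the stream directly. A minimal sketch under that assumption:

import pickle
import pkg_resources
from typing import Dict

import numpy as np


def load_cell2embedding() -> Dict[str, np.ndarray]:
    # resource_stream already yields an open binary stream,
    # so it can be handed straight to pickle.load
    with pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl') as f:
        return pickle.load(f)
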
protac_degradation_predictor/optuna_utils.py
ADDED
@@ -0,0 +1,318 @@
+import os
+from typing import Literal, List, Tuple, Optional, Dict
+
+from pytorch_models import train_model
+from sklearn_models import (
+    train_sklearn_model,
+    suggest_random_forest,
+    suggest_logistic_regression,
+    suggest_svc,
+    suggest_gradient_boosting,
+)
+
+import optuna
+from optuna.samplers import TPESampler
+import joblib
+import pandas as pd
+from sklearn.ensemble import (
+    RandomForestClassifier,
+    GradientBoostingClassifier,
+)
+from sklearn.linear_model import LogisticRegression
+from sklearn.svm import SVC
+
+
+def pytorch_model_objective(
+    trial: optuna.Trial,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    hidden_dim_options: List[int] = [256, 512, 768],
+    batch_size_options: List[int] = [8, 16, 32],
+    learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
+    smote_k_neighbors_options: List[int] = list(range(3, 16)),
+    dropout_options: Tuple[float, float] = (0.1, 0.5),
+    fast_dev_run: bool = False,
+    active_label: str = 'Active',
+    disabled_embeddings: List[str] = [],
+    max_epochs: int = 100,
+) -> float:
+    """ Objective function for hyperparameter optimization.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        hidden_dim_options (List[int]): The hidden dimension options.
+        batch_size_options (List[int]): The batch size options.
+        learning_rate_options (Tuple[float, float]): The learning rate options.
+        smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
+        dropout_options (Tuple[float, float]): The dropout options.
+        fast_dev_run (bool): Whether to run a fast development run.
+        active_label (str): The active label column.
+        disabled_embeddings (List[str]): The list of disabled embeddings.
+    """
+    # Generate the hyperparameters
+    hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
+    batch_size = trial.suggest_categorical('batch_size', batch_size_options)
+    learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
+    join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
+    smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
+    use_smote = trial.suggest_categorical('use_smote', [True, False])
+    apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
+    dropout = trial.suggest_float('dropout', *dropout_options)
+
+    # Train the model with the current set of hyperparameters
+    _, _, metrics = train_model(
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        train_df,
+        val_df,
+        hidden_dim=hidden_dim,
+        batch_size=batch_size,
+        join_embeddings=join_embeddings,
+        learning_rate=learning_rate,
+        dropout=dropout,
+        max_epochs=max_epochs,
+        smote_k_neighbors=smote_k_neighbors,
+        apply_scaling=apply_scaling,
+        use_smote=use_smote,
+        use_logger=False,
+        fast_dev_run=fast_dev_run,
+        active_label=active_label,
+        disabled_embeddings=disabled_embeddings,
+    )
+
+    # Metrics is a dictionary containing at least the validation loss
+    val_loss = metrics['val_loss']
+    val_acc = metrics['val_acc']
+    val_roc_auc = metrics['val_roc_auc']
+
+    # Optuna aims to minimize the pytorch_model_objective
+    return val_loss - val_acc - val_roc_auc
+
+
+def hyperparameter_tuning_and_training(
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    fast_dev_run: bool = False,
+    n_trials: int = 50,
+    logger_name: str = 'protac_hparam_search',
+    active_label: str = 'Active',
+    disabled_embeddings: List[str] = [],
+    study_filename: Optional[str] = None,
+) -> tuple:
+    """ Hyperparameter tuning and training of a PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (pd.DataFrame): The test set.
+        fast_dev_run (bool): Whether to run a fast development run.
+        n_trials (int): The number of hyperparameter optimization trials.
+        logger_name (str): The name of the logger.
+        active_label (str): The active label column.
+        disabled_embeddings (List[str]): The list of disabled embeddings.
+
+    Returns:
+        tuple: The trained model, the trainer, and the best metrics.
+    """
+    # Define the search space
+    hidden_dim_options = [256, 512, 768]
+    batch_size_options = [8, 16, 32]
+    learning_rate_options = (1e-5, 1e-3)  # min and max values for loguniform distribution
+    smote_k_neighbors_options = list(range(3, 16))
+
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+    # Create an Optuna study object
+    sampler = TPESampler(seed=42, multivariate=True)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+
+    study_loaded = False
+    if study_filename:
+        if os.path.exists(study_filename):
+            study = joblib.load(study_filename)
+            study_loaded = True
+            print(f'Loaded study from {study_filename}')
+
+    if not study_loaded:
+        study.optimize(
+            lambda trial: pytorch_model_objective(
+                trial=trial,
+                protein2embedding=protein2embedding,
+                cell2embedding=cell2embedding,
+                smiles2fp=smiles2fp,
+                train_df=train_df,
+                val_df=val_df,
+                hidden_dim_options=hidden_dim_options,
+                batch_size_options=batch_size_options,
+                learning_rate_options=learning_rate_options,
+                smote_k_neighbors_options=smote_k_neighbors_options,
+                fast_dev_run=fast_dev_run,
+                active_label=active_label,
+                disabled_embeddings=disabled_embeddings,
+            ),
+            n_trials=n_trials,
+        )
+        if study_filename:
+            joblib.dump(study, study_filename)
+
+    # Retrain the model with the best hyperparameters
+    model, trainer, metrics = train_model(
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        test_df=test_df,
+        use_logger=True,
+        logger_name=logger_name,
+        fast_dev_run=fast_dev_run,
+        active_label=active_label,
+        disabled_embeddings=disabled_embeddings,
+        **study.best_params,
+    )
+
+    # Report the best hyperparameters found
+    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
+
+    # Return the best metrics
+    return model, trainer, metrics
+
+
+def sklearn_model_objective(
+    trial: optuna.Trial,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
+    active_label: str = 'Active',
+) -> float:
+    """ Objective function for hyperparameter optimization.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        model_type (str): The model type.
+        hyperparameters (Dict): The hyperparameters for the model.
+        fast_dev_run (bool): Whether to run a fast development run.
+        active_label (str): The active label column.
+    """
+
+    # Generate the hyperparameters
+    use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
+    if model_type == 'RandomForest':
+        clf = suggest_random_forest(trial)
+    elif model_type == 'SVC':
+        clf = suggest_svc(trial)
+    elif model_type == 'LogisticRegression':
+        clf = suggest_logistic_regression(trial)
+    elif model_type == 'GradientBoosting':
+        clf = suggest_gradient_boosting(trial)
+    else:
+        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
+
+    # Train the model with the current set of hyperparameters
+    _, metrics = train_sklearn_model(
+        clf=clf,
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        active_label=active_label,
+        use_single_scaler=use_single_scaler,
+    )
+
+    # Metrics is a dictionary containing at least the validation loss
+    val_acc = metrics['val_acc']
+    val_roc_auc = metrics['val_roc_auc']
+
+    # Optuna aims to minimize the sklearn_model_objective
+    return - val_acc - val_roc_auc
+
+
+def hyperparameter_tuning_and_training_sklearn(
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
+    active_label: str = 'Active',
+    n_trials: int = 50,
+    logger_name: str = 'protac_hparam_search',
+    study_filename: Optional[str] = None,
+) -> Tuple:
+    # Set the verbosity of Optuna
+    optuna.logging.set_verbosity(optuna.logging.WARNING)
+    # Create an Optuna study object
+    sampler = TPESampler(seed=42, multivariate=True)
+    study = optuna.create_study(direction='minimize', sampler=sampler)
+
+    study_loaded = False
+    if study_filename:
+        if os.path.exists(study_filename):
+            study = joblib.load(study_filename)
+            study_loaded = True
+            print(f'Loaded study from {study_filename}')
+
+    if not study_loaded:
+        study.optimize(
+            lambda trial: sklearn_model_objective(
+                trial=trial,
+                protein2embedding=protein2embedding,
+                cell2embedding=cell2embedding,
+                smiles2fp=smiles2fp,
+                train_df=train_df,
+                val_df=val_df,
+                model_type=model_type,
+                active_label=active_label,
+            ),
+            n_trials=n_trials,
+        )
+        if study_filename:
+            joblib.dump(study, study_filename)
+
+    # Retrain the model with the best hyperparameters
+    best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
+    if model_type == 'RandomForest':
+        clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
+    elif model_type == 'SVC':
+        clf = SVC(random_state=42, probability=True, **best_hyperparameters)
+    elif model_type == 'LogisticRegression':
+        clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
+    elif model_type == 'GradientBoosting':
+        clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
+    else:
+        raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
+
+    model, metrics = train_sklearn_model(
+        clf=clf,
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_df=train_df,
+        val_df=val_df,
+        test_df=test_df,
+        active_label=active_label,
+        use_single_scaler=study.best_params['use_single_scaler'],
+    )
+
+    # Report the best hyperparameters found
+    metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
+
+    # Return the best metrics
+    return model, metrics
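
A hedged usage sketch of the tuning entry point above. The dataframes and embedding dictionaries are hypothetical placeholders; they must follow the column layout expected by train_model ('Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', plus the active label):

# Hypothetical driver; protein2embedding, cell2embedding, smiles2fp and the
# dataframes are assumed to be prepared elsewhere (see data_utils.py).
model, trainer, metrics = hyperparameter_tuning_and_training(
    protein2embedding=protein2embedding,
    cell2embedding=cell2embedding,
    smiles2fp=smiles2fp,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    n_trials=50,
    active_label='Active (Dmax 0.6, pDC50 6.0)',
    study_filename='reports/study.pkl',  # reloaded instead of re-run if present
)
# Best hyperparameters are folded into the metrics under an 'hparam_' prefix
print({k: v for k, v in metrics.items() if k.startswith('hparam_')})
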
protac_degradation_predictor/protac_dataset.py
ADDED
@@ -0,0 +1,193 @@
+from typing import Literal, List, Tuple, Optional, Dict
+
+from torch.utils.data import Dataset
+import numpy as np
+from imblearn.over_sampling import SMOTE, ADASYN
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+
+
+class PROTAC_Dataset(Dataset):
+    def __init__(
+        self,
+        protac_df: pd.DataFrame,
+        protein2embedding: Dict,
+        cell2embedding: Dict,
+        smiles2fp: Dict,
+        use_smote: bool = False,
+        oversampler: Optional[SMOTE | ADASYN] = None,
+        active_label: str = 'Active',
+    ):
+        """ Initialize the PROTAC dataset
+
+        Args:
+            protac_df (pd.DataFrame): The PROTAC dataframe
+            protein2embedding (dict): Dictionary of protein embeddings
+            cell2embedding (dict): Dictionary of cell line embeddings
+            smiles2fp (dict): Dictionary of SMILES to fingerprint
+            use_smote (bool): Whether to use SMOTE for oversampling
+            use_ored_activity (bool): Whether to use the 'Active - OR' column
+        """
+        # Filter out examples with NaN in active_col column
+        self.data = protac_df  # [~protac_df[active_col].isna()]
+        self.protein2embedding = protein2embedding
+        self.cell2embedding = cell2embedding
+        self.smiles2fp = smiles2fp
+        self.active_label = active_label
+        self.use_single_scaler = None
+
+        self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
+        self.protein_emb_dim = protein2embedding[list(
+            protein2embedding.keys())[0]].shape[0]
+        self.cell_emb_dim = cell2embedding[list(
+            cell2embedding.keys())[0]].shape[0]
+
+        # Look up the embeddings
+        self.data = pd.DataFrame({
+            'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
+            'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
+            'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
+            'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
+            self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
+        })
+
+        # Apply SMOTE
+        self.use_smote = use_smote
+        self.oversampler = oversampler
+        if self.use_smote:
+            self.apply_smote()
+
+    def apply_smote(self):
+        # Prepare the dataset for SMOTE
+        features = []
+        labels = []
+        for _, row in self.data.iterrows():
+            features.append(np.hstack([
+                row['Smiles'],
+                row['Uniprot'],
+                row['E3 Ligase Uniprot'],
+                row['Cell Line Identifier'],
+            ]))
+            labels.append(row[self.active_label])
+
+        # Convert to numpy array
+        features = np.array(features).astype(np.float32)
+        labels = np.array(labels).astype(np.float32)
+
+        # Initialize SMOTE and fit
+        if self.oversampler is None:
+            oversampler = SMOTE(random_state=42)
+        else:
+            oversampler = self.oversampler
+        features_smote, labels_smote = oversampler.fit_resample(features, labels)
+
+        # Separate the features back into their respective embeddings
+        smiles_embs = features_smote[:, :self.smiles_emb_dim]
+        poi_embs = features_smote[:,
+                                  self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]
+        e3_embs = features_smote[:, self.smiles_emb_dim +
+                                 self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]
+        cell_embs = features_smote[:, -self.cell_emb_dim:]
+
+        # Reconstruct the dataframe with oversampled data
+        df_smote = pd.DataFrame({
+            'Smiles': list(smiles_embs),
+            'Uniprot': list(poi_embs),
+            'E3 Ligase Uniprot': list(e3_embs),
+            'Cell Line Identifier': list(cell_embs),
+            self.active_label: labels_smote
+        })
+        self.data = df_smote
+
+    def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
+        """ Fit the scalers for the data.
+
+        Args:
+            use_single_scaler (bool): Whether to use a single scaler for all features.
+            scaler_kwargs: Keyword arguments for the StandardScaler.
+
+        Returns:
+            dict: The fitted scalers.
+        """
+        if use_single_scaler:
+            self.use_single_scaler = True
+            scaler = StandardScaler(**scaler_kwargs)
+            embeddings = np.hstack([
+                np.array(self.data['Smiles'].tolist()),
+                np.array(self.data['Uniprot'].tolist()),
+                np.array(self.data['E3 Ligase Uniprot'].tolist()),
+                np.array(self.data['Cell Line Identifier'].tolist()),
+            ])
+            scaler.fit(embeddings)
+            return scaler
+        else:
+            self.use_single_scaler = False
+            scalers = {}
+            scalers['Smiles'] = StandardScaler(**scaler_kwargs)
+            scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
+            scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
+            scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
+
+            scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
+            scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
+            scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
+            scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
+
+            return scalers
+
+    def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
+        """ Apply scaling to the data.
+
+        Args:
+            scalers (dict): The scalers for each feature.
+            use_single_scaler (bool): Whether to use a single scaler for all features.
+        """
+        if self.use_single_scaler is None:
+            raise ValueError(
+                "The fit_scaling method must be called before apply_scaling.")
+        if use_single_scaler != self.use_single_scaler:
+            raise ValueError(
+                f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")
+        if use_single_scaler:
+            embeddings = np.hstack([
+                np.array(self.data['Smiles'].tolist()),
+                np.array(self.data['Uniprot'].tolist()),
+                np.array(self.data['E3 Ligase Uniprot'].tolist()),
+                np.array(self.data['Cell Line Identifier'].tolist()),
+            ])
+            scaled_embeddings = scalers.transform(embeddings)
+            self.data = pd.DataFrame({
+                'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
+                'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
+                'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
+                'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
+                self.active_label: self.data[self.active_label]
+            })
+        else:
+            self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
+            self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
+            self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
+            self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
+
+    def get_numpy_arrays(self):
+        X = np.hstack([
+            np.array(self.data['Smiles'].tolist()),
+            np.array(self.data['Uniprot'].tolist()),
+            np.array(self.data['E3 Ligase Uniprot'].tolist()),
+            np.array(self.data['Cell Line Identifier'].tolist()),
+        ]).copy()
+        y = self.data[self.active_label].values.copy()
+        return X, y
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        elem = {
+            'smiles_emb': self.data['Smiles'].iloc[idx],
+            'poi_emb': self.data['Uniprot'].iloc[idx],
+            'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
+            'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
+            'active': self.data[self.active_label].iloc[idx],
+        }
+        return elem
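
To make the flattened feature layout concrete: SMOTE and the single-scaler path both operate on hstacked rows ordered [smiles | poi | e3 | cell]. A short sketch, with hypothetical inputs, that builds a dataset and checks that layout:

# Hypothetical inputs; the dicts must cover every value in train_df's columns.
train_ds = PROTAC_Dataset(
    train_df,
    protein2embedding,
    cell2embedding,
    smiles2fp,
    use_smote=True,        # oversample the minority class with SMOTE
    active_label='Active',
)
X, y = train_ds.get_numpy_arrays()
# Feature width is the sum of the four embedding blocks, in this fixed order
assert X.shape[1] == (train_ds.smiles_emb_dim
                      + 2 * train_ds.protein_emb_dim
                      + train_ds.cell_emb_dim)
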
protac_degradation_predictor/protac_degradation_predictor.py
ADDED
@@ -0,0 +1,88 @@
+import pkg_resources
+import logging
+
+from pytorch_models import PROTAC_Model, load_model
+from data_utils import (
+    load_protein2embedding,
+    load_cell2embedding,
+    get_fingerprint,
+)
+from config import config
+
+import numpy as np
+import torch
+from torch import sigmoid
+
+package_name = 'protac_degradation_predictor'
+
+def get_protac_active_proba(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+) -> bool:
+    ckpt_path = pkg_resources.resource_stream(__name__, 'data/model.ckpt')
+    model = load_model(ckpt_path).to(device)
+    protein2embedding = load_protein2embedding()
+    cell2embedding = load_cell2embedding()
+
+    # Setup default embeddings
+    if e3_ligase not in config.e3_ligase2uniprot:
+        available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+    if target_uniprot not in protein2embedding:
+        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+    if cell_line not in load_cell2embedding():
+        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+
+    default_protein_emb = np.zeros(config.protein_embedding_size)
+    default_cell_emb = np.zeros(config.cell_embedding_size)
+
+    # Convert the E3 ligase to Uniprot ID
+    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+
+    # Get the embeddings
+    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
+    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
+    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
+    smiles_emb = get_fingerprint(protac_smiles)
+
+    # Convert to torch tensors
+    poi_emb = torch.tensor(poi_emb).to(device)
+    e3_emb = torch.tensor(e3_emb).to(device)
+    cell_emb = torch.tensor(cell_emb).to(device)
+    smiles_emb = torch.tensor(smiles_emb).to(device)
+
+    return model(poi_emb, e3_emb, cell_emb, smiles_emb).item()
+
+
+def is_protac_active(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+    proba_threshold: float = 0.5,
+) -> bool:
+    """ Predict whether a PROTAC is active or not.
+
+    Args:
+        protac_smiles (str): The SMILES of the PROTAC.
+        e3_ligase (str): The Uniprot ID of the E3 ligase.
+        target_uniprot (str): The Uniprot ID of the target protein.
+        cell_line (str): The cell line identifier.
+        device (str): The device to run the model on.
+        proba_threshold (float): The probability threshold.
+
+    Returns:
+        bool: Whether the PROTAC is active or not.
+    """
+    pred = get_protac_active_proba(
+        protac_smiles,
+        e3_ligase,
+        target_uniprot,
+        cell_line,
+        device,
+    )
+    return sigmoid(pred) > proba_threshold
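
A hypothetical end-to-end call of the public helpers above, assuming the package installs cleanly and a trained checkpoint ships at data/model.ckpt (the checkpoint is not added in this commit, and the SMILES below is a placeholder, not a real PROTAC):

from protac_degradation_predictor.protac_degradation_predictor import (
    is_protac_active,
)

# Placeholder inputs: P10275 is the androgen receptor Uniprot ID; the
# SMILES and cell line are illustrative only.
active = is_protac_active(
    protac_smiles='CC(=O)Oc1ccccc1C(=O)O',
    e3_ligase='VHL',
    target_uniprot='P10275',
    cell_line='HeLa',
    device='cpu',
    proba_threshold=0.5,
)
print('Predicted active:', active)

Note that get_protac_active_proba returns a plain Python float via .item(), while torch's sigmoid expects a tensor, so is_protac_active would likely need sigmoid(torch.tensor(pred)) (or the model's own predict_step, which already applies the sigmoid) before thresholding.
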
protac_degradation_predictor/pytorch_models.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from typing import Literal, List, Tuple, Optional, Dict
|
3 |
+
|
4 |
+
from protac_dataset import PROTAC_Dataset
|
5 |
+
from config import Config
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.optim as optim
|
13 |
+
import pytorch_lightning as pl
|
14 |
+
from torch.utils.data import Dataset, DataLoader
|
15 |
+
from torchmetrics import (
|
16 |
+
Accuracy,
|
17 |
+
AUROC,
|
18 |
+
Precision,
|
19 |
+
Recall,
|
20 |
+
F1Score,
|
21 |
+
MetricCollection,
|
22 |
+
)
|
23 |
+
from imblearn.over_sampling import SMOTE
|
24 |
+
|
25 |
+
|
26 |
+
class PROTAC_Predictor(nn.Module):
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
hidden_dim: int,
|
31 |
+
smiles_emb_dim: int = Config.fingerprint_size,
|
32 |
+
poi_emb_dim: int = Config.protein_embedding_size,
|
33 |
+
e3_emb_dim: int = Config.protein_embedding_size,
|
34 |
+
cell_emb_dim: int = Config.cell_embedding_size,
|
35 |
+
dropout: float = 0.2,
|
36 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
|
37 |
+
disabled_embeddings: list = [],
|
38 |
+
):
|
39 |
+
""" Initialize the PROTAC model.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
hidden_dim (int): The hidden dimension of the model
|
43 |
+
smiles_emb_dim (int): The dimension of the SMILES embeddings
|
44 |
+
poi_emb_dim (int): The dimension of the POI embeddings
|
45 |
+
e3_emb_dim (int): The dimension of the E3 Ligase embeddings
|
46 |
+
cell_emb_dim (int): The dimension of the cell line embeddings
|
47 |
+
dropout (float): The dropout rate
|
48 |
+
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
49 |
+
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
50 |
+
"""
|
51 |
+
super().__init__()
|
52 |
+
self.poi_emb_dim = poi_emb_dim
|
53 |
+
self.e3_emb_dim = e3_emb_dim
|
54 |
+
self.cell_emb_dim = cell_emb_dim
|
55 |
+
self.smiles_emb_dim = smiles_emb_dim
|
56 |
+
self.hidden_dim = hidden_dim
|
57 |
+
self.join_embeddings = join_embeddings
|
58 |
+
self.disabled_embeddings = disabled_embeddings
|
59 |
+
# Set our init args as class attributes
|
60 |
+
self.__dict__.update(locals())
|
61 |
+
|
62 |
+
# Define "surrogate models" branches
|
63 |
+
if self.join_embeddings != 'beginning':
|
64 |
+
if 'poi' not in self.disabled_embeddings:
|
65 |
+
self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
|
66 |
+
if 'e3' not in self.disabled_embeddings:
|
67 |
+
self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
|
68 |
+
if 'cell' not in self.disabled_embeddings:
|
69 |
+
self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
|
70 |
+
if 'smiles' not in self.disabled_embeddings:
|
71 |
+
self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
|
72 |
+
|
73 |
+
# Define hidden dimension for joining layer
|
74 |
+
if self.join_embeddings == 'beginning':
|
75 |
+
joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
|
76 |
+
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
|
77 |
+
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
|
78 |
+
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
|
79 |
+
elif self.join_embeddings == 'concat':
|
80 |
+
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
|
81 |
+
elif self.join_embeddings == 'sum':
|
82 |
+
joint_dim = hidden_dim
|
83 |
+
|
84 |
+
self.fc0 = nn.Linear(joint_dim, joint_dim)
|
85 |
+
self.fc1 = nn.Linear(joint_dim, hidden_dim)
|
86 |
+
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
87 |
+
self.fc3 = nn.Linear(hidden_dim, 1)
|
88 |
+
|
89 |
+
self.dropout = nn.Dropout(p=dropout)
|
90 |
+
|
91 |
+
|
92 |
+
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
|
93 |
+
embeddings = []
|
94 |
+
if self.join_embeddings == 'beginning':
|
95 |
+
if 'poi' not in self.disabled_embeddings:
|
96 |
+
embeddings.append(poi_emb)
|
97 |
+
if 'e3' not in self.disabled_embeddings:
|
98 |
+
embeddings.append(e3_emb)
|
99 |
+
if 'cell' not in self.disabled_embeddings:
|
100 |
+
embeddings.append(cell_emb)
|
101 |
+
if 'smiles' not in self.disabled_embeddings:
|
102 |
+
embeddings.append(smiles_emb)
|
103 |
+
x = torch.cat(embeddings, dim=1)
|
104 |
+
x = self.dropout(F.relu(self.fc0(x)))
|
105 |
+
else:
|
106 |
+
if 'poi' not in self.disabled_embeddings:
|
107 |
+
embeddings.append(self.poi_emb(poi_emb))
|
108 |
+
if 'e3' not in self.disabled_embeddings:
|
109 |
+
embeddings.append(self.e3_emb(e3_emb))
|
110 |
+
if 'cell' not in self.disabled_embeddings:
|
111 |
+
embeddings.append(self.cell_emb(cell_emb))
|
112 |
+
if 'smiles' not in self.disabled_embeddings:
|
113 |
+
embeddings.append(self.smiles_emb(smiles_emb))
|
114 |
+
if self.join_embeddings == 'concat':
|
115 |
+
x = torch.cat(embeddings, dim=1)
|
116 |
+
elif self.join_embeddings == 'sum':
|
117 |
+
if len(embeddings) > 1:
|
118 |
+
embeddings = torch.stack(embeddings, dim=1)
|
119 |
+
x = torch.sum(embeddings, dim=1)
|
120 |
+
else:
|
121 |
+
x = embeddings[0]
|
122 |
+
x = self.dropout(F.relu(self.fc1(x)))
|
123 |
+
x = self.dropout(F.relu(self.fc2(x)))
|
124 |
+
x = self.fc3(x)
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
class PROTAC_Model(pl.LightningModule):
|
130 |
+
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
hidden_dim: int,
|
134 |
+
smiles_emb_dim: int = 224,
|
135 |
+
poi_emb_dim: int = 1024,
|
136 |
+
e3_emb_dim: int = 1024,
|
137 |
+
cell_emb_dim: int = 768,
|
138 |
+
batch_size: int = 32,
|
139 |
+
learning_rate: float = 1e-3,
|
140 |
+
dropout: float = 0.2,
|
141 |
+
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
|
142 |
+
train_dataset: PROTAC_Dataset = None,
|
143 |
+
val_dataset: PROTAC_Dataset = None,
|
144 |
+
test_dataset: PROTAC_Dataset = None,
|
145 |
+
disabled_embeddings: list = [],
|
146 |
+
apply_scaling: bool = False,
|
147 |
+
):
|
148 |
+
""" Initialize the PROTAC Pytorch Lightning model.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
hidden_dim (int): The hidden dimension of the model
|
152 |
+
smiles_emb_dim (int): The dimension of the SMILES embeddings
|
153 |
+
poi_emb_dim (int): The dimension of the POI embeddings
|
154 |
+
e3_emb_dim (int): The dimension of the E3 Ligase embeddings
|
155 |
+
cell_emb_dim (int): The dimension of the cell line embeddings
|
156 |
+
batch_size (int): The batch size
|
157 |
+
learning_rate (float): The learning rate
|
158 |
+
dropout (float): The dropout rate
|
159 |
+
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
160 |
+
train_dataset (PROTAC_Dataset): The training dataset
|
161 |
+
val_dataset (PROTAC_Dataset): The validation dataset
|
162 |
+
test_dataset (PROTAC_Dataset): The test dataset
|
163 |
+
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
164 |
+
apply_scaling (bool): Whether to apply scaling to the embeddings
|
165 |
+
"""
|
166 |
+
super().__init__()
|
167 |
+
self.poi_emb_dim = poi_emb_dim
|
168 |
+
self.e3_emb_dim = e3_emb_dim
|
169 |
+
self.cell_emb_dim = cell_emb_dim
|
170 |
+
self.smiles_emb_dim = smiles_emb_dim
|
171 |
+
self.hidden_dim = hidden_dim
|
172 |
+
self.batch_size = batch_size
|
173 |
+
self.learning_rate = learning_rate
|
174 |
+
self.join_embeddings = join_embeddings
|
175 |
+
self.train_dataset = train_dataset
|
176 |
+
self.val_dataset = val_dataset
|
177 |
+
self.test_dataset = test_dataset
|
178 |
+
self.disabled_embeddings = disabled_embeddings
|
179 |
+
self.apply_scaling = apply_scaling
|
180 |
+
# Set our init args as class attributes
|
181 |
+
self.__dict__.update(locals()) # Add arguments as attributes
|
182 |
+
# Save the arguments passed to init
|
183 |
+
ignore_args_as_hyperparams = [
|
184 |
+
'train_dataset',
|
185 |
+
'test_dataset',
|
186 |
+
'val_dataset',
|
187 |
+
]
|
188 |
+
self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
|
189 |
+
|
190 |
+
self.model = PROTAC_Predictor(
|
191 |
+
hidden_dim=hidden_dim,
|
192 |
+
smiles_emb_dim=smiles_emb_dim,
|
193 |
+
poi_emb_dim=poi_emb_dim,
|
194 |
+
e3_emb_dim=e3_emb_dim,
|
195 |
+
cell_emb_dim=cell_emb_dim,
|
196 |
+
dropout=dropout,
|
197 |
+
join_embeddings=join_embeddings,
|
198 |
+
disabled_embeddings=disabled_embeddings,
|
199 |
+
)
|
200 |
+
|
201 |
+
stages = ['train_metrics', 'val_metrics', 'test_metrics']
|
202 |
+
self.metrics = nn.ModuleDict({s: MetricCollection({
|
203 |
+
'acc': Accuracy(task='binary'),
|
204 |
+
'roc_auc': AUROC(task='binary'),
|
205 |
+
'precision': Precision(task='binary'),
|
206 |
+
'recall': Recall(task='binary'),
|
207 |
+
'f1_score': F1Score(task='binary'),
|
208 |
+
'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
|
209 |
+
'hp_metric': Accuracy(task='binary'),
|
210 |
+
}, prefix=s.replace('metrics', '')) for s in stages})
|
211 |
+
|
212 |
+
# Misc settings
|
213 |
+
self.missing_dataset_error = \
|
214 |
+
'''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
|
215 |
+
|
216 |
+
model = {1}.load_from_checkpoint('checkpoint.ckpt')
|
217 |
+
model.{0} = my_{0}
|
218 |
+
'''
|
219 |
+
|
220 |
+
# Apply scaling in datasets
|
221 |
+
if self.apply_scaling:
|
222 |
+
use_single_scaler = True if self.join_embeddings == 'beginning' else False
|
223 |
+
self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
|
224 |
+
self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
|
225 |
+
self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
|
226 |
+
if self.test_dataset:
|
227 |
+
self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
|
228 |
+
|
229 |
+
def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
|
230 |
+
return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
|
231 |
+
|
232 |
+
def step(self, batch, batch_idx, stage):
|
233 |
+
poi_emb = batch['poi_emb']
|
234 |
+
e3_emb = batch['e3_emb']
|
235 |
+
cell_emb = batch['cell_emb']
|
236 |
+
smiles_emb = batch['smiles_emb']
|
237 |
+
y = batch['active'].float().unsqueeze(1)
|
238 |
+
|
239 |
+
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
|
240 |
+
loss = F.binary_cross_entropy_with_logits(y_hat, y)
|
241 |
+
|
242 |
+
self.metrics[f'{stage}_metrics'].update(y_hat, y)
|
243 |
+
self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
|
244 |
+
self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
|
245 |
+
|
246 |
+
return loss
|
247 |
+
|
248 |
+
def training_step(self, batch, batch_idx):
|
249 |
+
return self.step(batch, batch_idx, 'train')
|
250 |
+
|
251 |
+
def validation_step(self, batch, batch_idx):
|
252 |
+
return self.step(batch, batch_idx, 'val')
|
253 |
+
|
254 |
+
def test_step(self, batch, batch_idx):
|
255 |
+
return self.step(batch, batch_idx, 'test')
|
256 |
+
|
257 |
+
def configure_optimizers(self):
|
258 |
+
return optim.Adam(self.parameters(), lr=self.learning_rate)
|
259 |
+
|
260 |
+
def predict_step(self, batch, batch_idx):
|
261 |
+
poi_emb = batch['poi_emb']
|
262 |
+
e3_emb = batch['e3_emb']
|
263 |
+
cell_emb = batch['cell_emb']
|
264 |
+
smiles_emb = batch['smiles_emb']
|
265 |
+
|
266 |
+
if self.apply_scaling:
|
267 |
+
if self.join_embeddings == 'beginning':
|
268 |
+
embeddings = np.hstack([
|
269 |
+
np.array(smiles_emb.tolist()),
|
270 |
+
np.array(poi_emb.tolist()),
|
271 |
+
np.array(e3_emb.tolist()),
|
272 |
+
np.array(cell_emb.tolist()),
|
273 |
+
])
|
274 |
+
embeddings = self.scalers.transform(embeddings)
|
275 |
+
smiles_emb = embeddings[:, :self.smiles_emb_dim]
|
276 |
+
poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
|
277 |
+
e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
|
278 |
+
cell_emb = embeddings[:, -self.cell_emb_dim:]
|
279 |
+
else:
|
280 |
+
poi_emb = self.scalers['Uniprot'].transform(poi_emb)
|
281 |
+
e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
|
282 |
+
cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
|
283 |
+
smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
|
284 |
+
|
285 |
+
y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
|
286 |
+
return torch.sigmoid(y_hat)
|
287 |
+
|
288 |
+
def train_dataloader(self):
|
289 |
+
if self.train_dataset is None:
|
290 |
+
format = 'train_dataset', self.__class__.__name__
|
291 |
+
raise ValueError(self.missing_dataset_error.format(*format))
|
292 |
+
|
293 |
+
return DataLoader(
|
294 |
+
self.train_dataset,
|
295 |
+
batch_size=self.batch_size,
|
296 |
+
shuffle=True,
|
297 |
+
# drop_last=True,
|
298 |
+
)
|
299 |
+
|
300 |
+
def val_dataloader(self):
|
301 |
+
if self.val_dataset is None:
|
302 |
+
format = 'val_dataset', self.__class__.__name__
|
303 |
+
raise ValueError(self.missing_dataset_error.format(*format))
|
304 |
+
return DataLoader(
|
305 |
+
self.val_dataset,
|
306 |
+
batch_size=self.batch_size,
|
307 |
+
shuffle=False,
|
308 |
+
)
|
309 |
+
|
310 |
+
def test_dataloader(self):
|
311 |
+
if self.test_dataset is None:
|
312 |
+
format = 'test_dataset', self.__class__.__name__
|
313 |
+
raise ValueError(self.missing_dataset_error.format(*format))
|
314 |
+
return DataLoader(
|
315 |
+
self.test_dataset,
|
316 |
+
batch_size=self.batch_size,
|
317 |
+
shuffle=False,
|
318 |
+
)
|
319 |
+
|
320 |
+
|
321 |
+
def train_model(
|
322 |
+
protein2embedding: Dict,
|
323 |
+
cell2embedding: Dict,
|
324 |
+
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    hidden_dim: int = 768,
+    batch_size: int = 8,
+    learning_rate: float = 2e-5,
+    dropout: float = 0.2,
+    max_epochs: int = 50,
+    smiles_emb_dim: int = 224,
+    join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
+    smote_k_neighbors: int = 5,
+    use_smote: bool = True,
+    apply_scaling: bool = False,
+    active_label: str = 'Active',
+    fast_dev_run: bool = False,
+    use_logger: bool = True,
+    logger_name: str = 'protac',
+    disabled_embeddings: List[str] = [],
+) -> tuple:
+    """ Train a PROTAC model using the given datasets and hyperparameters.
+
+    Args:
+        protein2embedding (dict): Dictionary of protein embeddings.
+        cell2embedding (dict): Dictionary of cell line embeddings.
+        smiles2fp (dict): Dictionary of SMILES to fingerprint.
+        train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
+        val_df (pd.DataFrame): The validation set. It must include the same columns as the training set.
+        test_df (Optional[pd.DataFrame]): The test set. If provided, the returned metrics will include test performance. It must include the same columns as the training set.
+        hidden_dim (int): The hidden dimension of the model.
+        batch_size (int): The batch size.
+        learning_rate (float): The learning rate.
+        dropout (float): The dropout probability.
+        max_epochs (int): The maximum number of epochs.
+        smiles_emb_dim (int): The dimension of the SMILES embeddings.
+        join_embeddings (Literal['beginning', 'concat', 'sum']): How to combine the input embeddings.
+        smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
+        use_smote (bool): Whether to apply SMOTE oversampling to the training set.
+        apply_scaling (bool): Whether to apply scaling to the input features.
+        active_label (str): The name of the column containing the activity label.
+        fast_dev_run (bool): Whether to run a fast development run.
+        use_logger (bool): Whether to log to TensorBoard.
+        logger_name (str): The name of the TensorBoard run.
+        disabled_embeddings (list): The list of disabled embeddings.
+
+    Returns:
+        tuple: The trained model, the trainer, and the metrics.
+    """
+    oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
+    train_ds = PROTAC_Dataset(
+        train_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        use_smote=use_smote,
+        oversampler=oversampler if use_smote else None,
+        active_label=active_label,
+    )
+    val_ds = PROTAC_Dataset(
+        val_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+    )
+    if test_df is not None:
+        test_ds = PROTAC_Dataset(
+            test_df,
+            protein2embedding,
+            cell2embedding,
+            smiles2fp,
+            active_label=active_label,
+        )
+    logger = pl.loggers.TensorBoardLogger(
+        save_dir='../logs',
+        name=logger_name,
+    )
+    callbacks = [
+        pl.callbacks.EarlyStopping(
+            monitor='train_loss',
+            patience=10,
+            mode='min',
+            verbose=False,
+        ),
+        pl.callbacks.EarlyStopping(
+            monitor='val_loss',
+            patience=5,
+            mode='min',
+            verbose=False,
+        ),
+        pl.callbacks.EarlyStopping(
+            monitor='val_acc',
+            patience=10,
+            mode='max',
+            verbose=False,
+        ),
+        # pl.callbacks.ModelCheckpoint(
+        #     monitor='val_acc',
+        #     mode='max',
+        #     verbose=True,
+        #     filename='{epoch}-{val_metrics_opt_score:.4f}',
+        # ),
+    ]
+    # Define the Trainer
+    trainer = pl.Trainer(
+        logger=logger if use_logger else False,
+        callbacks=callbacks,
+        max_epochs=max_epochs,
+        fast_dev_run=fast_dev_run,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        devices=1,
+        num_nodes=1,
+    )
+    model = PROTAC_Model(
+        hidden_dim=hidden_dim,
+        smiles_emb_dim=smiles_emb_dim,
+        poi_emb_dim=1024,
+        e3_emb_dim=1024,
+        cell_emb_dim=768,
+        batch_size=batch_size,
+        join_embeddings=join_embeddings,
+        dropout=dropout,
+        learning_rate=learning_rate,
+        apply_scaling=apply_scaling,
+        train_dataset=train_ds,
+        val_dataset=val_ds,
+        test_dataset=test_ds if test_df is not None else None,
+        disabled_embeddings=disabled_embeddings,
+    )
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        trainer.fit(model)
+    metrics = trainer.validate(model, verbose=False)[0]
+    if test_df is not None:
+        test_metrics = trainer.test(model, verbose=False)[0]
+        metrics.update(test_metrics)
+    return model, trainer, metrics
+
+
+def load_model(
+    ckpt_path: str,
+) -> PROTAC_Model:
+    """ Load a PROTAC model from a checkpoint.
+
+    Args:
+        ckpt_path (str): The path to the checkpoint.
+
+    Returns:
+        PROTAC_Model: The loaded model.
+    """
+    model = PROTAC_Model.load_from_checkpoint(ckpt_path)
+    model.eval()
+    return model
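For reference, a minimal usage sketch of the train_model() and load_model() functions added above (not part of the commit; it assumes the embedding dictionaries and the train/val dataframes have already been loaded, e.g. from the files under protac_degradation_predictor/data/, and that the module is importable under the package path shown):

    from protac_degradation_predictor.pytorch_models import train_model, load_model

    # protein2embedding: Uniprot ID -> embedding; cell2embedding: cell line -> embedding;
    # smiles2fp: SMILES -> fingerprint. All three are assumed to be loaded beforehand.
    model, trainer, metrics = train_model(
        protein2embedding,
        cell2embedding,
        smiles2fp,
        train_df,
        val_df,
        max_epochs=50,
        use_logger=False,  # skip TensorBoard logging for a quick run
    )
    print(metrics['val_acc'])

    # Reload a previously saved checkpoint (path is illustrative):
    model = load_model('checkpoints/protac_model.ckpt')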
protac_degradation_predictor/sklearn_models.py
ADDED
@@ -0,0 +1,243 @@
+from typing import Literal, List, Tuple, Optional, Dict
+
+from .protac_dataset import PROTAC_Dataset
+
+import pandas as pd
+from sklearn.base import ClassifierMixin
+from sklearn.ensemble import (
+    RandomForestClassifier,
+    GradientBoostingClassifier,
+)
+from sklearn.linear_model import LogisticRegression
+from sklearn.svm import SVC
+
+import torch
+import torch.nn as nn
+from torchmetrics import (
+    Accuracy,
+    AUROC,
+    Precision,
+    Recall,
+    F1Score,
+    MetricCollection,
+)
+import optuna
+
+
+def train_sklearn_model(
+    clf: ClassifierMixin,
+    protein2embedding: Dict,
+    cell2embedding: Dict,
+    smiles2fp: Dict,
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    test_df: Optional[pd.DataFrame] = None,
+    active_label: str = 'Active',
+    use_single_scaler: bool = True,
+) -> Tuple[ClassifierMixin, Dict]:
+    """ Train a classifier model on the train and val sets and optionally evaluate it on a test set.
+
+    Args:
+        clf: The classifier model to train and evaluate.
+        protein2embedding (Dict): Dictionary of protein embeddings.
+        cell2embedding (Dict): Dictionary of cell line embeddings.
+        smiles2fp (Dict): Dictionary of SMILES to fingerprint.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (Optional[pd.DataFrame]): The test set.
+        active_label (str): The name of the column containing the activity label.
+        use_single_scaler (bool): Whether to fit a single scaler over all features.
+
+    Returns:
+        Tuple[ClassifierMixin, Dict]: The trained model and the computed metrics.
+    """
+    # Initialize the datasets
+    train_ds = PROTAC_Dataset(
+        train_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
+    train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    val_ds = PROTAC_Dataset(
+        val_df,
+        protein2embedding,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    if test_df is not None:
+        test_ds = PROTAC_Dataset(
+            test_df,
+            protein2embedding,
+            cell2embedding,
+            smiles2fp,
+            active_label=active_label,
+            use_smote=False,
+        )
+        test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+
+    # Get the numpy arrays
+    X_train, y_train = train_ds.get_numpy_arrays()
+    X_val, y_val = val_ds.get_numpy_arrays()
+    if test_df is not None:
+        X_test, y_test = test_ds.get_numpy_arrays()
+
+    # Train the model
+    clf.fit(X_train, y_train)
+    # Define the metrics as a module dict
+    stages = ['train_metrics', 'val_metrics', 'test_metrics']
+    metrics = nn.ModuleDict({s: MetricCollection({
+        'acc': Accuracy(task='binary'),
+        'roc_auc': AUROC(task='binary'),
+        'precision': Precision(task='binary'),
+        'recall': Recall(task='binary'),
+        'f1_score': F1Score(task='binary'),
+        'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
+        'hp_metric': Accuracy(task='binary'),
+    }, prefix=s.replace('metrics', '')) for s in stages})
+
+    # Get the predictions and compute the metrics per stage
+    metrics_out = {}
+
+    y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
+    y_true = torch.tensor(y_train)
+    metrics['train_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['train_metrics'].compute())
+
+    y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
+    y_true = torch.tensor(y_val)
+    metrics['val_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['val_metrics'].compute())
+
+    if test_df is not None:
+        y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
+        y_true = torch.tensor(y_test)
+        metrics['test_metrics'].update(y_pred, y_true)
+        metrics_out.update(metrics['test_metrics'].compute())
+
+    return clf, metrics_out
+
+
+def suggest_random_forest(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Random Forest classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Random Forest classifier with the suggested hyperparameters.
+    """
+    n_estimators = trial.suggest_int('model_n_estimators', 10, 1000)
+    max_depth = trial.suggest_int('model_max_depth', 2, 100)
+    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
+    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
+    max_features = trial.suggest_categorical('model_max_features', [None, 'sqrt', 'log2'])
+    criterion = trial.suggest_categorical('model_criterion', ['gini', 'entropy'])
+
+    clf = RandomForestClassifier(
+        n_estimators=n_estimators,
+        max_depth=max_depth,
+        min_samples_split=min_samples_split,
+        min_samples_leaf=min_samples_leaf,
+        max_features=max_features,
+        criterion=criterion,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_logistic_regression(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Logistic Regression classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Logistic Regression classifier with the suggested hyperparameters.
+    """
+    # Suggest values for the logistic regression hyperparameters
+    C = trial.suggest_float('model_C', 1e-4, 1e2, log=True)
+    penalty = trial.suggest_categorical('model_penalty', ['l1', 'l2', 'elasticnet', None])
+    solver = trial.suggest_categorical('model_solver', ['newton-cholesky', 'lbfgs', 'liblinear', 'sag', 'saga'])
+
+    # Prune trials with incompatible penalty-solver combinations
+    if penalty == 'l1' and solver not in ['liblinear', 'saga']:
+        raise optuna.exceptions.TrialPruned()
+    if penalty == 'elasticnet' and solver != 'saga':
+        raise optuna.exceptions.TrialPruned()
+    if penalty is None and solver not in ['newton-cholesky', 'lbfgs', 'sag']:
+        raise optuna.exceptions.TrialPruned()
+
+    # The elasticnet penalty additionally requires an l1_ratio
+    l1_ratio = trial.suggest_float('model_l1_ratio', 0.0, 1.0) if penalty == 'elasticnet' else None
+
+    # Configure the classifier with the trial's suggested parameters
+    clf = LogisticRegression(
+        C=C,
+        penalty=penalty,
+        solver=solver,
+        l1_ratio=l1_ratio,
+        max_iter=1000,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_svc(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for an SVC classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The SVC classifier with the suggested hyperparameters.
+    """
+    C = trial.suggest_float('model_C', 1e-4, 1e2, log=True)
+    kernel = trial.suggest_categorical('model_kernel', ['linear', 'poly', 'rbf', 'sigmoid'])
+    gamma = trial.suggest_categorical('model_gamma', ['scale', 'auto'])
+    degree = trial.suggest_int('model_degree', 2, 5) if kernel == 'poly' else 3
+
+    clf = SVC(
+        C=C,
+        kernel=kernel,
+        gamma=gamma,
+        degree=degree,
+        probability=True,
+        random_state=42,
+    )
+
+    return clf
+
+
+def suggest_gradient_boosting(
+    trial: optuna.Trial,
+) -> ClassifierMixin:
+    """ Suggest hyperparameters for a Gradient Boosting classifier.
+
+    Args:
+        trial (optuna.Trial): The Optuna trial object.
+
+    Returns:
+        ClassifierMixin: The Gradient Boosting classifier with the suggested hyperparameters.
+    """
+    n_estimators = trial.suggest_int('model_n_estimators', 50, 500)
+    learning_rate = trial.suggest_float('model_learning_rate', 0.01, 1, log=True)
+    max_depth = trial.suggest_int('model_max_depth', 3, 10)
+    min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
+    min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
+    max_features = trial.suggest_categorical('model_max_features', ['sqrt', 'log2', None])
+
+    clf = GradientBoostingClassifier(
+        n_estimators=n_estimators,
+        learning_rate=learning_rate,
+        max_depth=max_depth,
+        min_samples_split=min_samples_split,
+        min_samples_leaf=min_samples_leaf,
+        max_features=max_features,
+        random_state=42,
+    )
+
+    return clf
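To see how these helpers compose, here is a hedged sketch of an Optuna objective built around them (illustrative only, not part of the commit; the embedding dictionaries and dataframes are assumed to be loaded beforehand):

    import optuna
    from protac_degradation_predictor.sklearn_models import (
        suggest_random_forest,
        train_sklearn_model,
    )

    def objective(trial: optuna.Trial) -> float:
        clf = suggest_random_forest(trial)  # sample RF hyperparameters from the trial
        _, metrics = train_sklearn_model(
            clf,
            protein2embedding,
            cell2embedding,
            smiles2fp,
            train_df,
            val_df,
        )
        return metrics['val_acc'].item()  # torchmetrics compute() returns tensors

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=50)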
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9d2f036ce141fbeb81930cc9ce49dbd6effc76221b26b92ae0498af1c34289f3
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f8ce36b5f52f8f88105c3ec0c5b60f865e1b054aff8f9e96c21f1e037eaa65af
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d22f4f54d46ca72b8585645fdfac43683a23dcc00d80fb8bd1f785d4eb4a9594
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f7298a04d0888f4de87041efd6b78c42e13c3f1630c43567d582bc7710a40847
-size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:54f869b328af6667567bd5cc805ce63fc5434ed1b77afc1e66d95b8f02e40642
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d2a7a9f6ed11e1b5b6f876dd927612d03f4780f9db3e65b9f1ebb8fbd853677f
-size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:27f0e76e7f89950199c843699c000eaad8628441c84aec394e20c23f701b1609
 size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fd4c40033da16a1cee16fd998c82e0403a31db4d14ba0604160ea143bae03668
-size 45164
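All of the reports/*.pkl changes above touch Git LFS pointer files rather than the pickled studies themselves: each pointer is three lines (spec version, sha256 oid, and size in bytes), so a regenerated study only swaps its oid, while the fold 2-4 test_split_0.2 studies are removed outright.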
setup.py
ADDED
@@ -0,0 +1,21 @@
+import setuptools
+
+setuptools.setup(
+    name="protac_degradation_predictor",
+    version="0.0.1",
+    author="Stefano Ribes",
+    url="https://github.com/ribesstefano/PROTAC-Degradation-Predictor",
+    author_email="ribes.stefano@gmail.com",
+    description="A package to predict PROTAC-induced protein degradation.",
+    long_description=open("README.md").read(),
+    packages=setuptools.find_packages(),
+    install_requires=["torch", "pytorch_lightning", "scikit-learn", "imbalanced-learn", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.8",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    include_package_data=True,
+    package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv"]},
+)
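With setup.py in place, the repository becomes pip-installable: running `pip install -e .` from the repository root installs the package in editable mode, and the `package_data` entry ships the bundled `data/*.h5`, `data/*.pkl`, and `data/*.csv` files (the protein and cell line embeddings plus the PROTAC-DB table) alongside the code.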