ribesstefano
committed on
Commit
•
165d38a
1
Parent(s):
74be897
refactored package code and fine running experimental code
Browse files
protac_degradation_predictor/__init__.py
CHANGED
@@ -1,14 +1,18 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
)
|
13 |
|
14 |
__version__ = "0.0.1"
|
|
|
1 |
+
from .data_utils import (
|
2 |
+
load_protein2embedding,
|
3 |
+
load_cell2embedding,
|
4 |
+
get_fingerprint,
|
5 |
+
is_active,
|
6 |
+
)
|
7 |
+
from .pytorch_models import (
|
8 |
+
train_model,
|
9 |
+
)
|
10 |
+
from .sklearn_models import (
|
11 |
+
train_sklearn_model,
|
12 |
+
)
|
13 |
+
from .optuna_utils import (
|
14 |
+
hyperparameter_tuning_and_training,
|
15 |
+
hyperparameter_tuning_and_training_sklearn,
|
16 |
)
|
17 |
|
18 |
__version__ = "0.0.1"
|
protac_degradation_predictor/protac_degradation_predictor.py
CHANGED
@@ -14,9 +14,6 @@ import torch
|
|
14 |
from torch import sigmoid
|
15 |
|
16 |
|
17 |
-
package_name = 'protac_degradation_predictor'
|
18 |
-
|
19 |
-
|
20 |
def get_protac_active_proba(
|
21 |
protac_smiles: str,
|
22 |
e3_ligase: str,
|
|
|
14 |
from torch import sigmoid
|
15 |
|
16 |
|
|
|
|
|
|
|
17 |
def get_protac_active_proba(
|
18 |
protac_smiles: str,
|
19 |
e3_ligase: str,
|
setup.py
CHANGED
@@ -9,7 +9,7 @@ setuptools.setup(
|
|
9 |
description="A package to predict PROTAC-induced protein degradation.",
|
10 |
long_description=open("README.md").read(),
|
11 |
packages=setuptools.find_packages(),
|
12 |
-
install_requires=["torch", "pytorch_lightning", "scikit-learn", "
|
13 |
classifiers=[
|
14 |
"Programming Language :: Python :: 3",
|
15 |
"Programming Language :: Python :: 3.6",
|
|
|
9 |
description="A package to predict PROTAC-induced protein degradation.",
|
10 |
long_description=open("README.md").read(),
|
11 |
packages=setuptools.find_packages(),
|
12 |
+
install_requires=["torch", "pytorch_lightning", "scikit-learn", "imbalanced-learn", "rdkit-pypi", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
|
13 |
classifiers=[
|
14 |
"Programming Language :: Python :: 3",
|
15 |
"Programming Language :: Python :: 3.6",
|
src/{main.py → run_experiments.py}
RENAMED
@@ -1,19 +1,12 @@
|
|
1 |
import os
|
|
|
2 |
from collections import defaultdict
|
3 |
import warnings
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
is_active,
|
10 |
-
)
|
11 |
-
from protac_degradation_predictor.pytorch_models import (
|
12 |
-
train_model,
|
13 |
-
)
|
14 |
-
from protac_degradation_predictor.optuna_utils import (
|
15 |
-
hyperparameter_tuning_and_training,
|
16 |
-
)
|
17 |
|
18 |
from rdkit import Chem
|
19 |
from rdkit.Chem import AllChem
|
@@ -28,78 +21,53 @@ from sklearn.model_selection import (
|
|
28 |
StratifiedGroupKFold,
|
29 |
)
|
30 |
|
31 |
-
|
32 |
# Ignore UserWarning from Matplotlib
|
33 |
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
34 |
# Ignore UserWarning from PyTorch Lightning
|
35 |
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
36 |
|
37 |
|
38 |
-
def
|
39 |
-
|
40 |
-
n_trials: int = 50,
|
41 |
-
fast_dev_run: bool = False,
|
42 |
-
test_split: float = 0.2,
|
43 |
-
cv_n_splits: int = 5,
|
44 |
-
):
|
45 |
-
""" Train a PROTAC model using the given datasets and hyperparameters.
|
46 |
|
47 |
Args:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
52 |
"""
|
53 |
-
## Set the Column to Predict
|
54 |
-
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
|
55 |
-
|
56 |
-
# Get Dmax_threshold from the active_col
|
57 |
-
Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
|
58 |
-
pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
|
59 |
-
|
60 |
-
## Load the Data
|
61 |
-
protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
|
62 |
-
|
63 |
-
# Map E3 Ligase Iap to IAP
|
64 |
-
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
|
65 |
-
|
66 |
-
protac_df[active_col] = protac_df.apply(
|
67 |
-
lambda x: is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
|
68 |
-
)
|
69 |
-
|
70 |
-
## Test Sets
|
71 |
-
|
72 |
-
test_indeces = {}
|
73 |
-
|
74 |
-
### Random Split
|
75 |
-
|
76 |
-
# Randomly select 20% of the active PROTACs as the test set
|
77 |
-
active_df = protac_df[protac_df[active_col].notna()].copy()
|
78 |
test_df = active_df.sample(frac=test_split, random_state=42)
|
79 |
-
|
80 |
|
81 |
-
### E3-based Split
|
82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
encoder = OrdinalEncoder()
|
84 |
-
|
85 |
-
active_df = protac_df[protac_df[active_col].notna()].copy()
|
86 |
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
|
87 |
-
|
88 |
|
89 |
-
### Tanimoto-based Split
|
90 |
-
|
91 |
-
#### Precompute fingerprints
|
92 |
-
morgan_fpgen = AllChem.GetMorganGenerator(
|
93 |
-
radius=config.morgan_radius,
|
94 |
-
fpSize=config.fingerprint_size,
|
95 |
-
includeChirality=True,
|
96 |
-
)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
smiles2fp = {}
|
99 |
for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
|
100 |
-
|
101 |
-
morgan_fp = morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
|
102 |
-
smiles2fp[smiles] = morgan_fp
|
103 |
|
104 |
# Get the pair-wise tanimoto similarity between the PROTAC fingerprints
|
105 |
tanimoto_matrix = defaultdict(list)
|
@@ -117,12 +85,27 @@ def main(
|
|
117 |
|
118 |
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
encoder = OrdinalEncoder()
|
124 |
-
|
125 |
-
active_df = protac_df[protac_df[active_col].notna()].copy()
|
126 |
# Sort the groups so that samples with the highest tanimoto similarity,
|
127 |
# i.e., the "less similar" ones, are placed in the test set first
|
128 |
tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
|
@@ -159,14 +142,22 @@ def main(
|
|
159 |
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
|
160 |
test_df.append(group_df)
|
161 |
test_df = pd.concat(test_df)
|
162 |
-
|
163 |
-
|
164 |
|
165 |
-
|
|
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
encoder = OrdinalEncoder()
|
168 |
-
|
169 |
-
active_df = protac_df[protac_df[active_col].notna()].copy()
|
170 |
|
171 |
test_df = []
|
172 |
# For each group, get the number of active and inactive entries. Then, add those
|
@@ -201,25 +192,64 @@ def main(
|
|
201 |
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
|
202 |
test_df.append(group_df)
|
203 |
test_df = pd.concat(test_df)
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
|
209 |
# Make directory ../reports if it does not exist
|
210 |
if not os.path.exists('../reports'):
|
211 |
os.makedirs('../reports')
|
212 |
|
213 |
# Load embedding dictionaries
|
214 |
-
protein2embedding = load_protein2embedding('../data/uniprot2embedding.h5')
|
215 |
-
cell2embedding = load_cell2embedding('../data/cell2embedding.pkl')
|
216 |
|
|
|
217 |
report = []
|
218 |
for split_type, indeces in test_indeces.items():
|
219 |
-
active_df = protac_df[protac_df[active_col].notna()].copy()
|
220 |
test_df = active_df.loc[indeces]
|
221 |
train_val_df = active_df[~active_df.index.isin(test_df.index)]
|
222 |
|
|
|
223 |
if split_type == 'random':
|
224 |
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
225 |
group = None
|
@@ -232,6 +262,7 @@ def main(
|
|
232 |
elif split_type == 'uniprot':
|
233 |
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
234 |
group = train_val_df['Uniprot Group'].to_numpy()
|
|
|
235 |
# Start the CV over the folds
|
236 |
X = train_val_df.drop(columns=active_col)
|
237 |
y = train_val_df[active_col].tolist()
|
@@ -269,7 +300,7 @@ def main(
|
|
269 |
|
270 |
print(stats)
|
271 |
# # Train and evaluate the model
|
272 |
-
# model, trainer, metrics = hyperparameter_tuning_and_training(
|
273 |
# protein2embedding,
|
274 |
# cell2embedding,
|
275 |
# smiles2fp,
|
@@ -294,7 +325,7 @@ def main(
|
|
294 |
# print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
|
295 |
# print('-' * 100)
|
296 |
# stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
|
297 |
-
# model, trainer, metrics = train_model(
|
298 |
# protein2embedding,
|
299 |
# cell2embedding,
|
300 |
# smiles2fp,
|
|
|
1 |
import os
|
2 |
+
import sys
|
3 |
from collections import defaultdict
|
4 |
import warnings
|
5 |
|
6 |
+
|
7 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
8 |
+
|
9 |
+
import protac_degradation_predictor as pdp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
from rdkit import Chem
|
12 |
from rdkit.Chem import AllChem
|
|
|
21 |
StratifiedGroupKFold,
|
22 |
)
|
23 |
|
|
|
24 |
# Ignore UserWarning from Matplotlib
|
25 |
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
26 |
# Ignore UserWarning from PyTorch Lightning
|
27 |
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
28 |
|
29 |
|
30 |
+
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
|
31 |
+
""" Get the indices of the test set using a random split.
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
Args:
|
34 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
35 |
+
test_split (float): The fraction of the rows in active_df to sample as the test set.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
pd.Index: The indices of the test set.
|
39 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
test_df = active_df.sample(frac=test_split, random_state=42)
|
41 |
+
return test_df.index
|
42 |
|
|
|
43 |
|
44 |
+
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
|
45 |
+
""" Get the indices of the test set using the E3 ligase split.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
pd.Index: The indices of the test set.
|
52 |
+
"""
|
53 |
encoder = OrdinalEncoder()
|
54 |
+
active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
|
|
|
55 |
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
|
56 |
+
return test_df.index
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
|
60 |
+
""" Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
tuple: The SMILES-to-fingerprint dictionary and the PROTAC DataFrame (expected to carry the 'Avg Tanimoto' column used by the Tanimoto-based split).
|
67 |
+
"""
|
68 |
smiles2fp = {}
|
69 |
for smiles in tqdm(protac_df['Smiles'].unique().tolist(), desc='Precomputing fingerprints'):
|
70 |
+
smiles2fp[smiles] = pdp.get_fingerprint(smiles)
|
|
|
|
|
71 |
|
72 |
# Get the pair-wise tanimoto similarity between the PROTAC fingerprints
|
73 |
tanimoto_matrix = defaultdict(list)
|
|
|
85 |
|
86 |
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
|
87 |
|
88 |
+
return smiles2fp, protac_df
|
89 |
+
|
90 |
+
|
91 |
+
def get_tanimoto_split_indices(
|
92 |
+
active_df: pd.DataFrame,
|
93 |
+
active_col: str,
|
94 |
+
test_split: float,
|
95 |
+
n_bins_tanimoto: int = 200,
|
96 |
+
) -> pd.Index:
|
97 |
+
""" Get the indices of the test set using the Tanimoto-based split.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
101 |
+
n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
pd.Index: The indices of the test set.
|
105 |
+
"""
|
106 |
+
tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
|
107 |
encoder = OrdinalEncoder()
|
108 |
+
active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
|
|
|
109 |
# Sort the groups so that samples with the highest tanimoto similarity,
|
110 |
# i.e., the "less similar" ones, are placed in the test set first
|
111 |
tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
|
|
|
142 |
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
|
143 |
test_df.append(group_df)
|
144 |
test_df = pd.concat(test_df)
|
145 |
+
return test_df.index
|
146 |
+
|
147 |
|
148 |
+
def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
|
149 |
+
""" Get the indices of the test set using the target-based split.
|
150 |
|
151 |
+
Args:
|
152 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
153 |
+
active_col (str): The column containing the active/inactive information.
|
154 |
+
test_split (float): The percentage of the active PROTACs to use as the test set.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
pd.Index: The indices of the test set.
|
158 |
+
"""
|
159 |
encoder = OrdinalEncoder()
|
160 |
+
active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
|
|
|
161 |
|
162 |
test_df = []
|
163 |
# For each group, get the number of active and inactive entries. Then, add those
|
|
|
192 |
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
|
193 |
test_df.append(group_df)
|
194 |
test_df = pd.concat(test_df)
|
195 |
+
return test_df.index
|
196 |
+
|
197 |
+
|
198 |
+
def main(
|
199 |
+
active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
|
200 |
+
n_trials: int = 50,
|
201 |
+
fast_dev_run: bool = False,
|
202 |
+
test_split: float = 0.2,
|
203 |
+
cv_n_splits: int = 5,
|
204 |
+
):
|
205 |
+
""" Train a PROTAC model using the given datasets and hyperparameters.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
active_col (str): The name of the activity column to predict. The Dmax and pDC50 thresholds are parsed from it, e.g., 'Active (Dmax 0.6, pDC50 6.0)'.
|
209 |
+
n_trials (int): The number of hyperparameter optimization trials.
|
210 |
+
cv_n_splits (int): The number of cross-validation splits.
|
211 |
+
fast_dev_run (bool): Whether to run a fast development run.
|
212 |
+
"""
|
213 |
+
# Set the Column to Predict
|
214 |
+
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
|
215 |
|
216 |
+
# Get Dmax_threshold from the active_col
|
217 |
+
Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
|
218 |
+
pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
|
219 |
+
|
220 |
+
# Load the PROTAC dataset
|
221 |
+
protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
|
222 |
+
# Map E3 Ligase Iap to IAP
|
223 |
+
protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
|
224 |
+
protac_df[active_col] = protac_df.apply(
|
225 |
+
lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
|
226 |
+
)
|
227 |
+
smiles2fp, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)
|
228 |
+
|
229 |
+
## Get the test sets
|
230 |
+
test_indeces = {}
|
231 |
+
active_df = protac_df[protac_df[active_col].notna()].copy()
|
232 |
+
test_indeces['random'] = get_random_split_indices(active_df, test_split)
|
233 |
+
test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
|
234 |
+
test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
|
235 |
+
test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
|
236 |
|
237 |
# Make directory ../reports if it does not exist
|
238 |
if not os.path.exists('../reports'):
|
239 |
os.makedirs('../reports')
|
240 |
|
241 |
# Load embedding dictionaries
|
242 |
+
protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
|
243 |
+
cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
|
244 |
|
245 |
+
# Cross-Validation Training
|
246 |
report = []
|
247 |
for split_type, indeces in test_indeces.items():
|
248 |
+
# active_df = protac_df[protac_df[active_col].notna()].copy()
|
249 |
test_df = active_df.loc[indeces]
|
250 |
train_val_df = active_df[~active_df.index.isin(test_df.index)]
|
251 |
|
252 |
+
# Get the CV object
|
253 |
if split_type == 'random':
|
254 |
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
255 |
group = None
|
|
|
262 |
elif split_type == 'uniprot':
|
263 |
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
264 |
group = train_val_df['Uniprot Group'].to_numpy()
|
265 |
+
|
266 |
# Start the CV over the folds
|
267 |
X = train_val_df.drop(columns=active_col)
|
268 |
y = train_val_df[active_col].tolist()
|
|
|
300 |
|
301 |
print(stats)
|
302 |
# # Train and evaluate the model
|
303 |
+
# model, trainer, metrics = pdp.hyperparameter_tuning_and_training(
|
304 |
# protein2embedding,
|
305 |
# cell2embedding,
|
306 |
# smiles2fp,
|
|
|
325 |
# print(f'Ablation study with disabled embeddings: {disabled_embeddings}')
|
326 |
# print('-' * 100)
|
327 |
# stats['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
|
328 |
+
# model, trainer, metrics = pdp.train_model(
|
329 |
# protein2embedding,
|
330 |
# cell2embedding,
|
331 |
# smiles2fp,
|