ribesstefano committed
Commit f9a730b • 1 Parent(s): 40d8875
Started working on optuna-based sklearn models training
notebooks/protac_degradation_predictor.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
notebooks/protac_degradation_predictor.py
CHANGED
@@ -19,11 +19,13 @@ from rdkit import DataStructs
 from jsonargparse import CLI
 from tqdm.auto import tqdm
 from imblearn.over_sampling import SMOTE, ADASYN
+
 from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
 from sklearn.model_selection import (
     StratifiedKFold,
     StratifiedGroupKFold,
 )
+from sklearn.base import ClassifierMixin
 
 import torch
 import torch.nn as nn
@@ -37,8 +39,8 @@ from torchmetrics import (
     Precision,
     Recall,
     F1Score,
+    MetricCollection,
 )
-from torchmetrics import MetricCollection
 
 
 # Ignore UserWarning from Matplotlib
@@ -311,6 +313,16 @@ class PROTAC_Dataset(Dataset):
         self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
         self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
 
+    def get_numpy_arrays(self):
+        X = np.hstack([
+            np.array(self.data['Smiles'].tolist()),
+            np.array(self.data['Uniprot'].tolist()),
+            np.array(self.data['E3 Ligase Uniprot'].tolist()),
+            np.array(self.data['Cell Line Identifier'].tolist()),
+        ]).copy()
+        y = self.data[self.active_label].values.copy()
+        return X, y
+
     def __len__(self):
         return len(self.data)
 
@@ -324,6 +336,97 @@ class PROTAC_Dataset(Dataset):
         }
         return elem
 
+def train_sklearn_model(
+        clf: ClassifierMixin,
+        train_df: pd.DataFrame,
+        val_df: pd.DataFrame,
+        test_df: Optional[pd.DataFrame] = None,
+        active_label: str = 'Active',
+        use_single_scaler: bool = True,
+) -> Tuple[ClassifierMixin, nn.ModuleDict]:
+    """ Train a classifier model on train and val sets and evaluate it on a test set.
+
+    Args:
+        clf: The classifier model to train and evaluate.
+        train_df (pd.DataFrame): The training set.
+        val_df (pd.DataFrame): The validation set.
+        test_df (Optional[pd.DataFrame]): The test set.
+
+    Returns:
+        Tuple[ClassifierMixin, nn.ModuleDict]: The trained model and the metrics.
+    """
+    # Initialize the datasets
+    train_ds = PROTAC_Dataset(
+        train_df,
+        protein_embeddings,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
+    train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    val_ds = PROTAC_Dataset(
+        val_df,
+        protein_embeddings,
+        cell2embedding,
+        smiles2fp,
+        active_label=active_label,
+        use_smote=False,
+    )
+    val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+    if test_df is not None:
+        test_ds = PROTAC_Dataset(
+            test_df,
+            protein_embeddings,
+            cell2embedding,
+            smiles2fp,
+            active_label=active_label,
+            use_smote=False,
+        )
+        test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
+
+    # Get the numpy arrays
+    X_train, y_train = train_ds.get_numpy_arrays()
+    X_val, y_val = val_ds.get_numpy_arrays()
+    if test_df is not None:
+        X_test, y_test = test_ds.get_numpy_arrays()
+
+    # Train the model
+    clf.fit(X_train, y_train)
+    # Define the metrics as a module dict
+    stages = ['train_metrics', 'val_metrics', 'test_metrics']
+    metrics = nn.ModuleDict({s: MetricCollection({
+        'acc': Accuracy(task='binary'),
+        'roc_auc': AUROC(task='binary'),
+        'precision': Precision(task='binary'),
+        'recall': Recall(task='binary'),
+        'f1_score': F1Score(task='binary'),
+        'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
+        'hp_metric': Accuracy(task='binary'),
+    }, prefix=s.replace('metrics', '')) for s in stages})
+
+    # Get the predictions
+    metrics_out = {}
+
+    y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
+    y_true = torch.tensor(y_train)
+    metrics['train_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['train_metrics'].compute())
+
+    y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
+    y_true = torch.tensor(y_val)
+    metrics['val_metrics'].update(y_pred, y_true)
+    metrics_out.update(metrics['val_metrics'].compute())
+
+    if test_df is not None:
+        y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
+        y_true = torch.tensor(y_test)
+        metrics['test_metrics'].update(y_pred, y_true)
+        metrics_out.update(metrics['test_metrics'].compute())
+
+    return clf, metrics_out
+
 
 class PROTAC_Model(pl.LightningModule):
 
@@ -573,7 +676,7 @@ def train_model(
     Args:
         train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
-        test_df (pd.DataFrame): The test set.
+        test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance.
         hidden_dim (int): The hidden dimension of the model.
         batch_size (int): The batch size.
         learning_rate (float): The learning rate.
@@ -882,6 +985,7 @@ def main(
     encoder = OrdinalEncoder()
     protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
     active_df = protac_df[protac_df[active_col].notna()].copy()
+    tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
 
     test_df = []
     # For each group, get the number of active and inactive entries. Then, add those
@@ -889,7 +993,7 @@ def main(
     # 20% of the active_df lenght, and 2) the percentage of True and False entries
     # in the active_col in test_df is roughly 50%.
     # Start the loop from the groups containing the smallest number of entries.
-    for group in
+    for group in tanimoto_groups:
         group_df = active_df[active_df['Tanimoto Group'] == group]
         if test_df == []:
             test_df.append(group_df)
@@ -969,6 +1073,8 @@ def main(
 
     report = []
     for split_type, indeces in test_indeces.items():
+        if split_type != 'tanimoto':
+            continue
        active_df = protac_df[protac_df[active_col].notna()].copy()
         test_df = active_df.loc[indeces]
         train_val_df = active_df[~active_df.index.isin(test_df.index)]
@@ -1060,7 +1166,7 @@ def main(
 
     report_df = pd.DataFrame(report)
     report_df.to_csv(
-        f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}.csv',
+        f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}_tanimoto.csv',
        index=False,
     )
 
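For review context, a minimal sketch of how the new train_sklearn_model helper could be called, assuming the module-level protein_embeddings, cell2embedding, and smiles2fp dictionaries are already populated and train_df/val_df/test_df come from the existing splitting logic; the RandomForestClassifier choice is illustrative only, not part of this commit:

    from sklearn.ensemble import RandomForestClassifier

    # Illustrative classifier: anything exposing fit() and predict_proba() should work.
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf, metrics = train_sklearn_model(
        clf,
        train_df,
        val_df,
        test_df=test_df,
        active_label='Active',
        use_single_scaler=True,
    )
    # The MetricCollection prefixes ('train_', 'val_', 'test_') produce keys such as
    # 'train_acc', 'val_roc_auc', 'test_f1_score' in the returned metrics dict.
    print(float(metrics['val_acc']))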
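The commit message mentions Optuna-based training, but the Optuna integration itself is not in this diff yet. A hedged sketch of how train_sklearn_model might be wrapped in an Optuna objective; the search space, metric choice, and trial count below are assumptions for illustration, not the author's code:

    import optuna
    from sklearn.ensemble import RandomForestClassifier

    def objective(trial: optuna.Trial) -> float:
        # Hypothetical search space; adjust to the classifiers actually being tuned.
        clf = RandomForestClassifier(
            n_estimators=trial.suggest_int('n_estimators', 50, 500),
            max_depth=trial.suggest_int('max_depth', 3, 20),
            random_state=42,
        )
        _, metrics = train_sklearn_model(clf, train_df, val_df, active_label='Active')
        # Maximize validation accuracy ('val_acc' per the MetricCollection prefix).
        return float(metrics['val_acc'])

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=50)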