ribesstefano committed on
Commit
f9a730b
1 Parent(s): 40d8875

Started working on optuna-based sklearn models training

Browse files
notebooks/protac_degradation_predictor.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/protac_degradation_predictor.py CHANGED
@@ -19,11 +19,13 @@ from rdkit import DataStructs
19
  from jsonargparse import CLI
20
  from tqdm.auto import tqdm
21
  from imblearn.over_sampling import SMOTE, ADASYN
 
22
  from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
23
  from sklearn.model_selection import (
24
  StratifiedKFold,
25
  StratifiedGroupKFold,
26
  )
 
27
 
28
  import torch
29
  import torch.nn as nn
@@ -37,8 +39,8 @@ from torchmetrics import (
37
  Precision,
38
  Recall,
39
  F1Score,
 
40
  )
41
- from torchmetrics import MetricCollection
42
 
43
 
44
  # Ignore UserWarning from Matplotlib
@@ -311,6 +313,16 @@ class PROTAC_Dataset(Dataset):
311
  self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
312
  self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
313
 
 
 
 
 
 
 
 
 
 
 
314
  def __len__(self):
315
  return len(self.data)
316
 
@@ -324,6 +336,97 @@ class PROTAC_Dataset(Dataset):
324
  }
325
  return elem
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
  class PROTAC_Model(pl.LightningModule):
329
 
@@ -573,7 +676,7 @@ def train_model(
573
  Args:
574
  train_df (pd.DataFrame): The training set.
575
  val_df (pd.DataFrame): The validation set.
576
- test_df (pd.DataFrame): The test set.
577
  hidden_dim (int): The hidden dimension of the model.
578
  batch_size (int): The batch size.
579
  learning_rate (float): The learning rate.
@@ -882,6 +985,7 @@ def main(
882
  encoder = OrdinalEncoder()
883
  protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
884
  active_df = protac_df[protac_df[active_col].notna()].copy()
 
885
 
886
  test_df = []
887
  # For each group, get the number of active and inactive entries. Then, add those
@@ -889,7 +993,7 @@ def main(
889
  # 20% of the active_df length, and 2) the percentage of True and False entries
890
  # in the active_col in test_df is roughly 50%.
891
  # Start the loop from the groups containing the smallest number of entries.
892
- for group in reversed(active_df['Tanimoto Group'].value_counts().index):
893
  group_df = active_df[active_df['Tanimoto Group'] == group]
894
  if test_df == []:
895
  test_df.append(group_df)
@@ -969,6 +1073,8 @@ def main(
969
 
970
  report = []
971
  for split_type, indeces in test_indeces.items():
 
 
972
  active_df = protac_df[protac_df[active_col].notna()].copy()
973
  test_df = active_df.loc[indeces]
974
  train_val_df = active_df[~active_df.index.isin(test_df.index)]
@@ -1060,7 +1166,7 @@ def main(
1060
 
1061
  report_df = pd.DataFrame(report)
1062
  report_df.to_csv(
1063
- f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}.csv',
1064
  index=False,
1065
  )
1066
 
 
19
  from jsonargparse import CLI
20
  from tqdm.auto import tqdm
21
  from imblearn.over_sampling import SMOTE, ADASYN
22
+
23
  from sklearn.preprocessing import OrdinalEncoder, StandardScaler, LabelEncoder
24
  from sklearn.model_selection import (
25
  StratifiedKFold,
26
  StratifiedGroupKFold,
27
  )
28
+ from sklearn.base import ClassifierMixin
29
 
30
  import torch
31
  import torch.nn as nn
 
39
  Precision,
40
  Recall,
41
  F1Score,
42
+ MetricCollection,
43
  )
 
44
 
45
 
46
  # Ignore UserWarning from Matplotlib
 
313
  self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
314
  self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
315
 
316
def get_numpy_arrays(self):
    """Return the dataset as a feature matrix and a label vector.

    The entries of the 'Smiles', 'Uniprot', 'E3 Ligase Uniprot' and
    'Cell Line Identifier' columns of ``self.data`` are converted to
    arrays and stacked horizontally, in that order. The labels are read
    from the ``self.active_label`` column.

    NOTE(review): assumes every cell of the four feature columns holds a
    fixed-length numeric vector, so each column becomes a 2D array --
    TODO confirm against the dataset constructor.

    Returns:
        Tuple of (X, y): the stacked features and the label values, both
        returned as copies.
    """
    feature_columns = (
        'Smiles',
        'Uniprot',
        'E3 Ligase Uniprot',
        'Cell Line Identifier',
    )
    # One array per feature column, kept in the declared order.
    column_blocks = [
        np.array(self.data[column].tolist()) for column in feature_columns
    ]
    features = np.hstack(column_blocks).copy()
    labels = self.data[self.active_label].values.copy()
    return features, labels
325
+
326
  def __len__(self):
327
  return len(self.data)
328
 
 
336
  }
337
  return elem
338
 
339
def train_sklearn_model(
    clf: ClassifierMixin,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: Optional[pd.DataFrame] = None,
    active_label: str = 'Active',
    use_single_scaler: bool = True,
) -> Tuple[ClassifierMixin, dict]:
    """ Train a classifier on the train set and evaluate it on every split.

    Each dataframe is wrapped in a `PROTAC_Dataset` (without SMOTE). The
    scaler is fitted on the training dataset only and then applied to all
    splits. The classifier is fitted on the training arrays and scored on
    each split from `clf.predict_proba(X)[:, 1]`.

    The names `protein_embeddings`, `cell2embedding` and `smiles2fp` are not
    parameters: they are read from the enclosing module scope.

    Args:
        clf: The classifier to train and evaluate. It must provide `fit`
            and `predict_proba`.
        train_df (pd.DataFrame): The training set.
        val_df (pd.DataFrame): The validation set.
        test_df (Optional[pd.DataFrame]): The test set. If None, no test
            metrics are computed.
        active_label (str): Name of the label column, forwarded to
            `PROTAC_Dataset`.
        use_single_scaler (bool): Forwarded to `fit_scaling` and
            `apply_scaling` of `PROTAC_Dataset`.

    Returns:
        Tuple[ClassifierMixin, dict]: The fitted classifier and a plain
            dictionary with the computed metrics of every evaluated split.
            Keys are the metric names prefixed by the split name (e.g.
            'train_acc', 'val_f1_score').
    """
    # Insertion order matters: 'train' must come first so that the scaler is
    # fitted before any other split is scaled.
    stage_dfs = {'train': train_df, 'val': val_df}
    if test_df is not None:
        stage_dfs['test'] = test_df

    # Build and scale one dataset per split, then extract its numpy arrays
    scaler = None
    arrays = {}
    for stage, stage_df in stage_dfs.items():
        dataset = PROTAC_Dataset(
            stage_df,
            protein_embeddings,
            cell2embedding,
            smiles2fp,
            active_label=active_label,
            use_smote=False,
        )
        if stage == 'train':
            scaler = dataset.fit_scaling(use_single_scaler=use_single_scaler)
        dataset.apply_scaling(scaler, use_single_scaler=use_single_scaler)
        arrays[stage] = dataset.get_numpy_arrays()

    # Train the model
    X_train, y_train = arrays['train']
    clf.fit(X_train, y_train)

    # Evaluate the fitted model on every available split
    metrics_out = {}
    for stage, (X, y) in arrays.items():
        stage_metrics = MetricCollection({
            'acc': Accuracy(task='binary'),
            'roc_auc': AUROC(task='binary'),
            'precision': Precision(task='binary'),
            'recall': Recall(task='binary'),
            'f1_score': F1Score(task='binary'),
            # Combined score: sum of accuracy and F1 score
            'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
            'hp_metric': Accuracy(task='binary'),
        }, prefix=f'{stage}_')
        # Column 1 of predict_proba is used as the positive-class score
        y_pred = torch.tensor(clf.predict_proba(X)[:, 1])
        y_true = torch.tensor(y)
        stage_metrics.update(y_pred, y_true)
        metrics_out.update(stage_metrics.compute())

    return clf, metrics_out
429
+
430
 
431
  class PROTAC_Model(pl.LightningModule):
432
 
 
676
  Args:
677
  train_df (pd.DataFrame): The training set.
678
  val_df (pd.DataFrame): The validation set.
679
+ test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance.
680
  hidden_dim (int): The hidden dimension of the model.
681
  batch_size (int): The batch size.
682
  learning_rate (float): The learning rate.
 
985
  encoder = OrdinalEncoder()
986
  protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
987
  active_df = protac_df[protac_df[active_col].notna()].copy()
988
+ tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
989
 
990
  test_df = []
991
  # For each group, get the number of active and inactive entries. Then, add those
 
993
  # 20% of the active_df length, and 2) the percentage of True and False entries
994
  # in the active_col in test_df is roughly 50%.
995
  # Start the loop from the groups with the highest mean 'Avg Tanimoto'.
996
+ for group in tanimoto_groups:
997
  group_df = active_df[active_df['Tanimoto Group'] == group]
998
  if test_df == []:
999
  test_df.append(group_df)
 
1073
 
1074
  report = []
1075
  for split_type, indeces in test_indeces.items():
1076
+ if split_type != 'tanimoto':
1077
+ continue
1078
  active_df = protac_df[protac_df[active_col].notna()].copy()
1079
  test_df = active_df.loc[indeces]
1080
  train_val_df = active_df[~active_df.index.isin(test_df.index)]
 
1166
 
1167
  report_df = pd.DataFrame(report)
1168
  report_df.to_csv(
1169
+ f'../reports/cv_report_hparam_search_{cv_n_splits}-splits_{active_name}_test_split_{test_split}_tanimoto.csv',
1170
  index=False,
1171
  )
1172