ribesstefano committed on
Commit
5e01175
1 Parent(s): 6101de8

Started working on packaging the repository

Files changed (23)
  1. notebooks/plotting_dragradation_activity_performance.ipynb +0 -0
  2. notebooks/protac_degradation_predictor.ipynb +10 -2
  3. notebooks/protac_degradation_predictor.py +3 -2
  4. protac_degradation_predictor/__init__.py +7 -0
  5. protac_degradation_predictor/config.py +37 -0
  6. protac_degradation_predictor/data/PROTAC-DB.csv +0 -0
  7. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl +2 -2
  8. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5 +2 -2
  9. protac_degradation_predictor/data_utils.py +46 -0
  10. protac_degradation_predictor/optuna_utils.py +318 -0
  11. protac_degradation_predictor/protac_dataset.py +193 -0
  12. protac_degradation_predictor/protac_degradation_predictor.py +88 -0
  13. protac_degradation_predictor/pytorch_models.py +471 -0
  14. protac_degradation_predictor/sklearn_models.py +243 -0
  15. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl +1 -1
  16. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl +1 -1
  17. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl +1 -1
  18. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl +0 -3
  19. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl +1 -1
  20. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl +0 -3
  21. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl +1 -1
  22. reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl +0 -3
  23. setup.py +21 -0
notebooks/plotting_dragradation_activity_performance.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks/protac_degradation_predictor.ipynb CHANGED
@@ -1719,8 +1719,16 @@
     }
    ],
    "source": [
-    "from typing import Literal, List, Tuple, Optional\n",
-    "from sklearn.base import ClassifierMixin\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torchmetrics import (\n",
+    "    Accuracy,\n",
+    "    AUROC,\n",
+    "    Precision,\n",
+    "    Recall,\n",
+    "    F1Score,\n",
+    "    MetricCollection,\n",
+    ")\n",
     "\n",
     "# Generic function to fit and evaluate a classifier model (given as argument),\n",
     "# on train and val sets (and optionally a test set) given as dataframes\n",
notebooks/protac_degradation_predictor.py CHANGED
@@ -680,7 +680,7 @@ def train_model(
         hidden_dim (int): The hidden dimension of the model.
         batch_size (int): The batch size.
         learning_rate (float): The learning rate.
-        max_epochs (int): The maximum number of epochs.
+        max_epochs (int): Th e maximum number of epochs.
         smiles_emb_dim (int): The dimension of the SMILES embeddings.
         smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
         fast_dev_run (bool): Whether to run a fast development run.
@@ -985,6 +985,8 @@ def main(
     encoder = OrdinalEncoder()
     protac_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
     active_df = protac_df[protac_df[active_col].notna()].copy()
+    # Sort the groups so that samples with the highest tanimoto similarity,
+    # i.e., the "less similar" ones, are placed in the test set first
     tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index

     test_df = []
@@ -992,7 +994,6 @@
     # entries to the test_df if: 1) the test_df lenght + the group entries is less
     # 20% of the active_df lenght, and 2) the percentage of True and False entries
     # in the active_col in test_df is roughly 50%.
-    # Start the loop from the groups containing the smallest number of entries.
     for group in tanimoto_groups:
         group_df = active_df[active_df['Tanimoto Group'] == group]
         if test_df == []:
protac_degradation_predictor/__init__.py ADDED
@@ -0,0 +1,7 @@
+from .protac_degradation_predictor import (
+    PROTAC_Model,
+    train_model,
+)
+
+__version__ = "0.0.1"
+__author__ = "Stefano Ribes"
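With the new setup.py in place, the package is presumably meant to be installed (e.g. pip install -e .) and used through the names re-exported here. A minimal usage sketch, not part of the commit, with the DataFrames and embedding dictionaries as placeholders the caller must prepare. Note that, as added later in this commit, protac_degradation_predictor.py imports PROTAC_Model but not train_model, and the sub-modules use absolute imports (from pytorch_models import ...), so these re-exports will likely need a follow-up to import cleanly as a package.

# Hypothetical usage of the public API re-exported by __init__.py (sketch only).
# train_df, val_df and the three dictionaries below are placeholders.
import protac_degradation_predictor as pdp

model, trainer, metrics = pdp.train_model(
    protein2embedding,  # Dict[str, np.ndarray]: Uniprot ID -> protein embedding
    cell2embedding,     # Dict[str, np.ndarray]: cell line -> embedding
    smiles2fp,          # Dict[str, np.ndarray]: SMILES -> fingerprint
    train_df,           # pd.DataFrame with 'Smiles', 'Uniprot', 'E3 Ligase Uniprot',
    val_df,             # 'Cell Line Identifier' and the active-label column
    active_label='Active (Dmax 0.6, pDC50 6.0)',
)
print(metrics['val_acc'])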
protac_degradation_predictor/config.py ADDED
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class Config:
+    # Embeddings information
+    morgan_radius: int = 15
+    fingerprint_size: int = 224
+    protein_embedding_size: int = 1024
+    cell_embedding_size: int = 768
+
+    # Data information
+    dmax_threshold: float = 0.6
+    pdc50_threshold: float = 6.0
+    e3_ligase2uniprot: dict = {
+        'VHL': 'P40337',
+        'CRBN': 'Q96SW2',
+        'DCAF11': 'Q8TEB1',
+        'DCAF15': 'Q66K64',
+        'DCAF16': 'Q9NXF7',
+        'MDM2': 'Q00987',
+        'Mdm2': 'Q00987',
+        'XIAP': 'P98170',
+        'cIAP1': 'Q7Z460',
+        'IAP': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'Iap': 'P98170',  # I couldn't find the Uniprot ID for IAP, so it's XIAP instead
+        'AhR': 'P35869',
+        'RNF4': 'P78317',
+        'RNF114': 'Q9Y508',
+        'FEM1B': 'Q9UK73',
+        'Ubr1': 'Q8IWV7',
+    }
+
+    def __post_init__(self):
+        self.active_label: str = f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})'
+
+
+config = Config()
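As committed, Config uses a plain dict as a dataclass field default and assigns self.active_label inside __post_init__ on a frozen instance; standard CPython rejects both (mutable default raises ValueError at class-definition time, assignment on a frozen dataclass raises FrozenInstanceError). A small sketch of the usual workaround follows; the class and field names are illustrative, not the committed code.

# Sketch only: field(default_factory=...) for the mutable default and
# object.__setattr__ for the derived field on a frozen dataclass.
from dataclasses import dataclass, field

@dataclass(frozen=True)
class FrozenConfigSketch:
    dmax_threshold: float = 0.6
    pdc50_threshold: float = 6.0
    e3_ligase2uniprot: dict = field(default_factory=lambda: {'VHL': 'P40337'})
    active_label: str = field(init=False)

    def __post_init__(self):
        # Frozen instances cannot use normal attribute assignment here.
        object.__setattr__(
            self,
            'active_label',
            f'Active (Dmax {self.dmax_threshold}, pDC50 {self.pdc50_threshold})',
        )

config_sketch = FrozenConfigSketch()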
protac_degradation_predictor/data/PROTAC-DB.csv ADDED
The diff for this file is too large to render. See raw diff
 
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.2.pkl → protac_degradation_predictor/data/cell2embedding.pkl RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:647f8b8e72f9f1f72ecdb8733b306df290f62fee42d30ab6da0f26cf3ed3b010
-size 45164
+oid sha256:627e8ce3842afeb6bb7d5caa5ec1ba034c36dc77fab70734e15dca340a7fd718
+size 3550864
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.2.pkl → protac_degradation_predictor/data/uniprot2embedding.h5 RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:375e5f4f5080d2cf654b932faedeb2b0e9433d0c738542098f789beecd980c65
-size 45164
+oid sha256:19f4b8c73652392db7840962d1a7817c7e899716e2bb758e4947c8c2bb265336
+size 51089512
protac_degradation_predictor/data_utils.py ADDED
@@ -0,0 +1,46 @@
+import os
+import pkg_resources
+import pickle
+from typing import Dict
+
+from config import config
+
+import h5py
+import numpy as np
+import pandas as pd
+from rdkit import Chem
+from rdkit.Chem import AllChem
+from joblib import Memory
+
+
+home_dir = os.path.expanduser('~')
+cachedir = os.path.join(home_dir, '.cache', 'protac_degradation_predictor')
+memory = Memory(cachedir, verbose=0)
+
+
+@memory.cache
+def load_protein2embedding() -> Dict[str, np.ndarray]:
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/uniprot2embedding.h5')
+    protein2embedding = {}
+    with h5py.File(embeddings_path, "r") as file:
+        for sequence_id in file.keys():
+            embedding = file[sequence_id][:]
+            protein2embedding[sequence_id] = np.array(embedding)
+    return protein2embedding
+
+
+@memory.cache
+def load_cell2embedding() -> Dict[str, np.ndarray]:
+    embeddings_path = pkg_resources.resource_stream(__name__, 'data/cell2embedding.pkl')
+    with open(embeddings_path, 'rb') as f:
+        cell2embedding = pickle.load(f)
+    return cell2embedding
+
+
+def get_fingerprint(smiles: str) -> np.ndarray:
+    morgan_fpgen = AllChem.GetMorganGenerator(
+        radius=config.morgan_radius,
+        fpSize=config.fingerprint_size,
+        includeChirality=True,
+    )
+    return morgan_fpgen.GetFingerprint(Chem.MolFromSmiles(smiles))
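Despite the np.ndarray annotation, GetFingerprint on an RDKit Morgan generator returns an ExplicitBitVect; callers that want the array form expected by the embedding dictionaries would typically convert it, roughly as in this sketch (the SMILES string is just an illustrative input, reusing the config imported above).

# Sketch only: convert the bit vector to a NumPy array.
import numpy as np

fp = get_fingerprint('CC(=O)Oc1ccccc1C(=O)O')  # placeholder SMILES (aspirin)
fp_array = np.array(fp, dtype=np.float32)
assert fp_array.shape == (config.fingerprint_size,)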
protac_degradation_predictor/optuna_utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import os
2
+ from typing import Literal, List, Tuple, Optional, Dict
3
+
4
+ from pytorch_models import train_model
5
+ from sklearn_models import (
6
+ train_sklearn_model,
7
+ suggest_random_forest,
8
+ suggest_logistic_regression,
9
+ suggest_svc,
10
+ suggest_gradient_boosting,
11
+ )
12
+
13
+ import optuna
14
+ from optuna.samplers import TPESampler
15
+ import joblib
16
+ import pandas as pd
17
+ from sklearn.ensemble import (
18
+ RandomForestClassifier,
19
+ GradientBoostingClassifier,
20
+ )
21
+ from sklearn.linear_model import LogisticRegression
22
+ from sklearn.svm import SVC
23
+
24
+
25
+ def pytorch_model_objective(
26
+ trial: optuna.Trial,
27
+ protein2embedding: Dict,
28
+ cell2embedding: Dict,
29
+ smiles2fp: Dict,
30
+ train_df: pd.DataFrame,
31
+ val_df: pd.DataFrame,
32
+ hidden_dim_options: List[int] = [256, 512, 768],
33
+ batch_size_options: List[int] = [8, 16, 32],
34
+ learning_rate_options: Tuple[float, float] = (1e-5, 1e-3),
35
+ smote_k_neighbors_options: List[int] = list(range(3, 16)),
36
+ dropout_options: Tuple[float, float] = (0.1, 0.5),
37
+ fast_dev_run: bool = False,
38
+ active_label: str = 'Active',
39
+ disabled_embeddings: List[str] = [],
40
+ max_epochs: int = 100,
41
+ ) -> float:
42
+ """ Objective function for hyperparameter optimization.
43
+
44
+ Args:
45
+ trial (optuna.Trial): The Optuna trial object.
46
+ train_df (pd.DataFrame): The training set.
47
+ val_df (pd.DataFrame): The validation set.
48
+ hidden_dim_options (List[int]): The hidden dimension options.
49
+ batch_size_options (List[int]): The batch size options.
50
+ learning_rate_options (Tuple[float, float]): The learning rate options.
51
+ smote_k_neighbors_options (List[int]): The SMOTE k neighbors options.
52
+ dropout_options (Tuple[float, float]): The dropout options.
53
+ fast_dev_run (bool): Whether to run a fast development run.
54
+ active_label (str): The active label column.
55
+ disabled_embeddings (List[str]): The list of disabled embeddings.
56
+ """
57
+ # Generate the hyperparameters
58
+ hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
59
+ batch_size = trial.suggest_categorical('batch_size', batch_size_options)
60
+ learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
61
+ join_embeddings = trial.suggest_categorical('join_embeddings', ['beginning', 'concat', 'sum'])
62
+ smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
63
+ use_smote = trial.suggest_categorical('use_smote', [True, False])
64
+ apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
65
+ dropout = trial.suggest_float('dropout', *dropout_options)
66
+
67
+ # Train the model with the current set of hyperparameters
68
+ _, _, metrics = train_model(
69
+ protein2embedding,
70
+ cell2embedding,
71
+ smiles2fp,
72
+ train_df,
73
+ val_df,
74
+ hidden_dim=hidden_dim,
75
+ batch_size=batch_size,
76
+ join_embeddings=join_embeddings,
77
+ learning_rate=learning_rate,
78
+ dropout=dropout,
79
+ max_epochs=max_epochs,
80
+ smote_k_neighbors=smote_k_neighbors,
81
+ apply_scaling=apply_scaling,
82
+ use_smote=use_smote,
83
+ use_logger=False,
84
+ fast_dev_run=fast_dev_run,
85
+ active_label=active_label,
86
+ disabled_embeddings=disabled_embeddings,
87
+ )
88
+
89
+ # Metrics is a dictionary containing at least the validation loss
90
+ val_loss = metrics['val_loss']
91
+ val_acc = metrics['val_acc']
92
+ val_roc_auc = metrics['val_roc_auc']
93
+
94
+ # Optuna aims to minimize the pytorch_model_objective
95
+ return val_loss - val_acc - val_roc_auc
96
+
97
+
98
+ def hyperparameter_tuning_and_training(
99
+ protein2embedding: Dict,
100
+ cell2embedding: Dict,
101
+ smiles2fp: Dict,
102
+ train_df: pd.DataFrame,
103
+ val_df: pd.DataFrame,
104
+ test_df: Optional[pd.DataFrame] = None,
105
+ fast_dev_run: bool = False,
106
+ n_trials: int = 50,
107
+ logger_name: str = 'protac_hparam_search',
108
+ active_label: str = 'Active',
109
+ disabled_embeddings: List[str] = [],
110
+ study_filename: Optional[str] = None,
111
+ ) -> tuple:
112
+ """ Hyperparameter tuning and training of a PROTAC model.
113
+
114
+ Args:
115
+ train_df (pd.DataFrame): The training set.
116
+ val_df (pd.DataFrame): The validation set.
117
+ test_df (pd.DataFrame): The test set.
118
+ fast_dev_run (bool): Whether to run a fast development run.
119
+ n_trials (int): The number of hyperparameter optimization trials.
120
+ logger_name (str): The name of the logger.
121
+ active_label (str): The active label column.
122
+ disabled_embeddings (List[str]): The list of disabled embeddings.
123
+
124
+ Returns:
125
+ tuple: The trained model, the trainer, and the best metrics.
126
+ """
127
+ # Define the search space
128
+ hidden_dim_options = [256, 512, 768]
129
+ batch_size_options = [8, 16, 32]
130
+ learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
131
+ smote_k_neighbors_options = list(range(3, 16))
132
+
133
+ # Set the verbosity of Optuna
134
+ optuna.logging.set_verbosity(optuna.logging.WARNING)
135
+ # Create an Optuna study object
136
+ sampler = TPESampler(seed=42, multivariate=True)
137
+ study = optuna.create_study(direction='minimize', sampler=sampler)
138
+
139
+ study_loaded = False
140
+ if study_filename:
141
+ if os.path.exists(study_filename):
142
+ study = joblib.load(study_filename)
143
+ study_loaded = True
144
+ print(f'Loaded study from {study_filename}')
145
+
146
+ if not study_loaded:
147
+ study.optimize(
148
+ lambda trial: pytorch_model_objective(
149
+ trial=trial,
150
+ protein2embedding=protein2embedding,
151
+ cell2embedding=cell2embedding,
152
+ smiles2fp=smiles2fp,
153
+ train_df=train_df,
154
+ val_df=val_df,
155
+ hidden_dim_options=hidden_dim_options,
156
+ batch_size_options=batch_size_options,
157
+ learning_rate_options=learning_rate_options,
158
+ smote_k_neighbors_options=smote_k_neighbors_options,
159
+ fast_dev_run=fast_dev_run,
160
+ active_label=active_label,
161
+ disabled_embeddings=disabled_embeddings,
162
+ ),
163
+ n_trials=n_trials,
164
+ )
165
+ if study_filename:
166
+ joblib.dump(study, study_filename)
167
+
168
+ # Retrain the model with the best hyperparameters
169
+ model, trainer, metrics = train_model(
170
+ protein2embedding=protein2embedding,
171
+ cell2embedding=cell2embedding,
172
+ smiles2fp=smiles2fp,
173
+ train_df=train_df,
174
+ val_df=val_df,
175
+ test_df=test_df,
176
+ use_logger=True,
177
+ logger_name=logger_name,
178
+ fast_dev_run=fast_dev_run,
179
+ active_label=active_label,
180
+ disabled_embeddings=disabled_embeddings,
181
+ **study.best_params,
182
+ )
183
+
184
+ # Report the best hyperparameters found
185
+ metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
186
+
187
+ # Return the best metrics
188
+ return model, trainer, metrics
189
+
190
+
191
+ def sklearn_model_objective(
192
+ trial: optuna.Trial,
193
+ protein2embedding: Dict,
194
+ cell2embedding: Dict,
195
+ smiles2fp: Dict,
196
+ train_df: pd.DataFrame,
197
+ val_df: pd.DataFrame,
198
+ model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
199
+ active_label: str = 'Active',
200
+ ) -> float:
201
+ """ Objective function for hyperparameter optimization.
202
+
203
+ Args:
204
+ trial (optuna.Trial): The Optuna trial object.
205
+ train_df (pd.DataFrame): The training set.
206
+ val_df (pd.DataFrame): The validation set.
207
+ model_type (str): The model type.
208
+ hyperparameters (Dict): The hyperparameters for the model.
209
+ fast_dev_run (bool): Whether to run a fast development run.
210
+ active_label (str): The active label column.
211
+ """
212
+
213
+ # Generate the hyperparameters
214
+ use_single_scaler = trial.suggest_categorical('use_single_scaler', [True, False])
215
+ if model_type == 'RandomForest':
216
+ clf = suggest_random_forest(trial)
217
+ elif model_type == 'SVC':
218
+ clf = suggest_svc(trial)
219
+ elif model_type == 'LogisticRegression':
220
+ clf = suggest_logistic_regression(trial)
221
+ elif model_type == 'GradientBoosting':
222
+ clf = suggest_gradient_boosting(trial)
223
+ else:
224
+ raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
225
+
226
+ # Train the model with the current set of hyperparameters
227
+ _, metrics = train_sklearn_model(
228
+ clf=clf,
229
+ protein2embedding=protein2embedding,
230
+ cell2embedding=cell2embedding,
231
+ smiles2fp=smiles2fp,
232
+ train_df=train_df,
233
+ val_df=val_df,
234
+ active_label=active_label,
235
+ use_single_scaler=use_single_scaler,
236
+ )
237
+
238
+ # Metrics is a dictionary containing at least the validation loss
239
+ val_acc = metrics['val_acc']
240
+ val_roc_auc = metrics['val_roc_auc']
241
+
242
+ # Optuna aims to minimize the sklearn_model_objective
243
+ return - val_acc - val_roc_auc
244
+
245
+
246
+ def hyperparameter_tuning_and_training_sklearn(
247
+ protein2embedding: Dict,
248
+ cell2embedding: Dict,
249
+ smiles2fp: Dict,
250
+ train_df: pd.DataFrame,
251
+ val_df: pd.DataFrame,
252
+ test_df: Optional[pd.DataFrame] = None,
253
+ model_type: Literal['RandomForest', 'SVC', 'LogisticRegression', 'GradientBoosting'] = 'RandomForest',
254
+ active_label: str = 'Active',
255
+ n_trials: int = 50,
256
+ logger_name: str = 'protac_hparam_search',
257
+ study_filename: Optional[str] = None,
258
+ ) -> Tuple:
259
+ # Set the verbosity of Optuna
260
+ optuna.logging.set_verbosity(optuna.logging.WARNING)
261
+ # Create an Optuna study object
262
+ sampler = TPESampler(seed=42, multivariate=True)
263
+ study = optuna.create_study(direction='minimize', sampler=sampler)
264
+
265
+ study_loaded = False
266
+ if study_filename:
267
+ if os.path.exists(study_filename):
268
+ study = joblib.load(study_filename)
269
+ study_loaded = True
270
+ print(f'Loaded study from {study_filename}')
271
+
272
+ if not study_loaded:
273
+ study.optimize(
274
+ lambda trial: sklearn_model_objective(
275
+ trial=trial,
276
+ protein2embedding=protein2embedding,
277
+ cell2embedding=cell2embedding,
278
+ smiles2fp=smiles2fp,
279
+ train_df=train_df,
280
+ val_df=val_df,
281
+ model_type=model_type,
282
+ active_label=active_label,
283
+ ),
284
+ n_trials=n_trials,
285
+ )
286
+ if study_filename:
287
+ joblib.dump(study, study_filename)
288
+
289
+ # Retrain the model with the best hyperparameters
290
+ best_hyperparameters = {k.replace('model_', ''): v for k, v in study.best_params.items() if k.startswith('model_')}
291
+ if model_type == 'RandomForest':
292
+ clf = RandomForestClassifier(random_state=42, **best_hyperparameters)
293
+ elif model_type == 'SVC':
294
+ clf = SVC(random_state=42, probability=True, **best_hyperparameters)
295
+ elif model_type == 'LogisticRegression':
296
+ clf = LogisticRegression(random_state=42, max_iter=1000, **best_hyperparameters)
297
+ elif model_type == 'GradientBoosting':
298
+ clf = GradientBoostingClassifier(random_state=42, **best_hyperparameters)
299
+ else:
300
+ raise ValueError(f'Invalid model type: {model_type}. Available: RandomForest, SVC, LogisticRegression, GradientBoosting.')
301
+
302
+ model, metrics = train_sklearn_model(
303
+ clf=clf,
304
+ protein2embedding=protein2embedding,
305
+ cell2embedding=cell2embedding,
306
+ smiles2fp=smiles2fp,
307
+ train_df=train_df,
308
+ val_df=val_df,
309
+ test_df=test_df,
310
+ active_label=active_label,
311
+ use_single_scaler=study.best_params['use_single_scaler'],
312
+ )
313
+
314
+ # Report the best hyperparameters found
315
+ metrics.update({f'hparam_{k}': v for k, v in study.best_params.items()})
316
+
317
+ # Return the best metrics
318
+ return model, metrics
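A rough sketch of driving the Optuna search defined above; all data arguments are placeholders assumed to be prepared as in the training utilities, and the study filename is illustrative (it only enables resuming a previously saved study).

# Hypothetical call into hyperparameter_tuning_and_training (sketch only).
model, trainer, metrics = hyperparameter_tuning_and_training(
    protein2embedding,
    cell2embedding,
    smiles2fp,
    train_df,
    val_df,
    test_df=test_df,
    n_trials=50,
    logger_name='protac_hparam_search',
    active_label='Active (Dmax 0.6, pDC50 6.0)',
    study_filename='study_fold_0.pkl',  # illustrative; reloads the study if it exists
)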
protac_degradation_predictor/protac_dataset.py ADDED
@@ -0,0 +1,193 @@
1
+ from typing import Literal, List, Tuple, Optional, Dict
2
+
3
+ from torch.utils.data import Dataset
4
+ import numpy as np
5
+ from imblearn.over_sampling import SMOTE, ADASYN
6
+ import pandas as pd
7
+ from sklearn.preprocessing import StandardScaler
8
+
9
+
10
+ class PROTAC_Dataset(Dataset):
11
+ def __init__(
12
+ self,
13
+ protac_df: pd.DataFrame,
14
+ protein2embedding: Dict,
15
+ cell2embedding: Dict,
16
+ smiles2fp: Dict,
17
+ use_smote: bool = False,
18
+ oversampler: Optional[SMOTE | ADASYN] = None,
19
+ active_label: str = 'Active',
20
+ ):
21
+ """ Initialize the PROTAC dataset
22
+
23
+ Args:
24
+ protac_df (pd.DataFrame): The PROTAC dataframe
25
+ protein2embedding (dict): Dictionary of protein embeddings
26
+ cell2embedding (dict): Dictionary of cell line embeddings
27
+ smiles2fp (dict): Dictionary of SMILES to fingerprint
28
+ use_smote (bool): Whether to use SMOTE for oversampling
29
+ use_ored_activity (bool): Whether to use the 'Active - OR' column
30
+ """
31
+ # Filter out examples with NaN in active_col column
32
+ self.data = protac_df # [~protac_df[active_col].isna()]
33
+ self.protein2embedding = protein2embedding
34
+ self.cell2embedding = cell2embedding
35
+ self.smiles2fp = smiles2fp
36
+ self.active_label = active_label
37
+ self.use_single_scaler = None
38
+
39
+ self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
40
+ self.protein_emb_dim = protein2embedding[list(
41
+ protein2embedding.keys())[0]].shape[0]
42
+ self.cell_emb_dim = cell2embedding[list(
43
+ cell2embedding.keys())[0]].shape[0]
44
+
45
+ # Look up the embeddings
46
+ self.data = pd.DataFrame({
47
+ 'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp[x].astype(np.float32)).tolist(),
48
+ 'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
49
+ 'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding[x].astype(np.float32)).tolist(),
50
+ 'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding[x].astype(np.float32)).tolist(),
51
+ self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
52
+ })
53
+
54
+ # Apply SMOTE
55
+ self.use_smote = use_smote
56
+ self.oversampler = oversampler
57
+ if self.use_smote:
58
+ self.apply_smote()
59
+
60
+ def apply_smote(self):
61
+ # Prepare the dataset for SMOTE
62
+ features = []
63
+ labels = []
64
+ for _, row in self.data.iterrows():
65
+ features.append(np.hstack([
66
+ row['Smiles'],
67
+ row['Uniprot'],
68
+ row['E3 Ligase Uniprot'],
69
+ row['Cell Line Identifier'],
70
+ ]))
71
+ labels.append(row[self.active_label])
72
+
73
+ # Convert to numpy array
74
+ features = np.array(features).astype(np.float32)
75
+ labels = np.array(labels).astype(np.float32)
76
+
77
+ # Initialize SMOTE and fit
78
+ if self.oversampler is None:
79
+ oversampler = SMOTE(random_state=42)
80
+ else:
81
+ oversampler = self.oversampler
82
+ features_smote, labels_smote = oversampler.fit_resample(features, labels)
83
+
84
+ # Separate the features back into their respective embeddings
85
+ smiles_embs = features_smote[:, :self.smiles_emb_dim]
86
+ poi_embs = features_smote[:,
87
+ self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]
88
+ e3_embs = features_smote[:, self.smiles_emb_dim +
89
+ self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]
90
+ cell_embs = features_smote[:, -self.cell_emb_dim:]
91
+
92
+ # Reconstruct the dataframe with oversampled data
93
+ df_smote = pd.DataFrame({
94
+ 'Smiles': list(smiles_embs),
95
+ 'Uniprot': list(poi_embs),
96
+ 'E3 Ligase Uniprot': list(e3_embs),
97
+ 'Cell Line Identifier': list(cell_embs),
98
+ self.active_label: labels_smote
99
+ })
100
+ self.data = df_smote
101
+
102
+ def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
103
+ """ Fit the scalers for the data.
104
+
105
+ Args:
106
+ use_single_scaler (bool): Whether to use a single scaler for all features.
107
+ scaler_kwargs: Keyword arguments for the StandardScaler.
108
+
109
+ Returns:
110
+ dict: The fitted scalers.
111
+ """
112
+ if use_single_scaler:
113
+ self.use_single_scaler = True
114
+ scaler = StandardScaler(**scaler_kwargs)
115
+ embeddings = np.hstack([
116
+ np.array(self.data['Smiles'].tolist()),
117
+ np.array(self.data['Uniprot'].tolist()),
118
+ np.array(self.data['E3 Ligase Uniprot'].tolist()),
119
+ np.array(self.data['Cell Line Identifier'].tolist()),
120
+ ])
121
+ scaler.fit(embeddings)
122
+ return scaler
123
+ else:
124
+ self.use_single_scaler = False
125
+ scalers = {}
126
+ scalers['Smiles'] = StandardScaler(**scaler_kwargs)
127
+ scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
128
+ scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
129
+ scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
130
+
131
+ scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
132
+ scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
133
+ scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
134
+ scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
135
+
136
+ return scalers
137
+
138
+ def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
139
+ """ Apply scaling to the data.
140
+
141
+ Args:
142
+ scalers (dict): The scalers for each feature.
143
+ use_single_scaler (bool): Whether to use a single scaler for all features.
144
+ """
145
+ if self.use_single_scaler is None:
146
+ raise ValueError(
147
+ "The fit_scaling method must be called before apply_scaling.")
148
+ if use_single_scaler != self.use_single_scaler:
149
+ raise ValueError(
150
+ f"The use_single_scaler parameter must be the same as the one used in the fit_scaling method. Got {use_single_scaler}, previously {self.use_single_scaler}.")
151
+ if use_single_scaler:
152
+ embeddings = np.hstack([
153
+ np.array(self.data['Smiles'].tolist()),
154
+ np.array(self.data['Uniprot'].tolist()),
155
+ np.array(self.data['E3 Ligase Uniprot'].tolist()),
156
+ np.array(self.data['Cell Line Identifier'].tolist()),
157
+ ])
158
+ scaled_embeddings = scalers.transform(embeddings)
159
+ self.data = pd.DataFrame({
160
+ 'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
161
+ 'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
162
+ 'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
163
+ 'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
164
+ self.active_label: self.data[self.active_label]
165
+ })
166
+ else:
167
+ self.data['Smiles'] = self.data['Smiles'].apply(lambda x: scalers['Smiles'].transform(x[np.newaxis, :])[0])
168
+ self.data['Uniprot'] = self.data['Uniprot'].apply(lambda x: scalers['Uniprot'].transform(x[np.newaxis, :])[0])
169
+ self.data['E3 Ligase Uniprot'] = self.data['E3 Ligase Uniprot'].apply(lambda x: scalers['E3 Ligase Uniprot'].transform(x[np.newaxis, :])[0])
170
+ self.data['Cell Line Identifier'] = self.data['Cell Line Identifier'].apply(lambda x: scalers['Cell Line Identifier'].transform(x[np.newaxis, :])[0])
171
+
172
+ def get_numpy_arrays(self):
173
+ X = np.hstack([
174
+ np.array(self.data['Smiles'].tolist()),
175
+ np.array(self.data['Uniprot'].tolist()),
176
+ np.array(self.data['E3 Ligase Uniprot'].tolist()),
177
+ np.array(self.data['Cell Line Identifier'].tolist()),
178
+ ]).copy()
179
+ y = self.data[self.active_label].values.copy()
180
+ return X, y
181
+
182
+ def __len__(self):
183
+ return len(self.data)
184
+
185
+ def __getitem__(self, idx):
186
+ elem = {
187
+ 'smiles_emb': self.data['Smiles'].iloc[idx],
188
+ 'poi_emb': self.data['Uniprot'].iloc[idx],
189
+ 'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
190
+ 'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
191
+ 'active': self.data[self.active_label].iloc[idx],
192
+ }
193
+ return elem
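For reference, the dataset is consumed through a standard PyTorch DataLoader, roughly as below; the DataFrame and embedding dictionaries are placeholders, and the batch keys follow __getitem__ above.

# Sketch only: build the dataset and iterate one batch.
from torch.utils.data import DataLoader

train_ds = PROTAC_Dataset(
    train_df,
    protein2embedding,
    cell2embedding,
    smiles2fp,
    use_smote=True,
    active_label='Active',
)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
batch = next(iter(train_loader))
# batch is a dict of tensors: 'smiles_emb', 'poi_emb', 'e3_emb', 'cell_emb', 'active'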
protac_degradation_predictor/protac_degradation_predictor.py ADDED
@@ -0,0 +1,88 @@
+import pkg_resources
+import logging
+
+from pytorch_models import PROTAC_Model, load_model
+from data_utils import (
+    load_protein2embedding,
+    load_cell2embedding,
+    get_fingerprint,
+)
+from config import config
+
+import numpy as np
+import torch
+from torch import sigmoid
+
+package_name = 'protac_degradation_predictor'
+
+def get_protac_active_proba(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+) -> bool:
+    ckpt_path = pkg_resources.resource_stream(__name__, 'data/model.ckpt')
+    model = load_model(ckpt_path).to(device)
+    protein2embedding = load_protein2embedding()
+    cell2embedding = load_cell2embedding()
+
+    # Setup default embeddings
+    if e3_ligase not in config.e3_ligase2uniprot:
+        available_e3_ligases = ', '.join(list(config.e3_ligase2uniprot.keys()))
+        logging.warning(f"The E3 ligase {e3_ligase} is not in the database. Using the default E3 ligase. Available E3 ligases are: {available_e3_ligases}")
+    if target_uniprot not in protein2embedding:
+        logging.warning(f"The target protein {target_uniprot} is not in the database. Using the default target protein.")
+    if cell_line not in load_cell2embedding():
+        logging.warning(f"The cell line {cell_line} is not in the database. Using the default cell line.")
+
+    default_protein_emb = np.zeros(config.protein_embedding_size)
+    default_cell_emb = np.zeros(config.cell_embedding_size)
+
+    # Convert the E3 ligase to Uniprot ID
+    e3_ligase_uniprot = config.e3_ligase2uniprot.get(e3_ligase, '')
+
+    # Get the embeddings
+    poi_emb = protein2embedding.get(target_uniprot, default_protein_emb)
+    e3_emb = protein2embedding.get(e3_ligase_uniprot, default_protein_emb)
+    cell_emb = cell2embedding.get(cell_line, default_cell_emb)
+    smiles_emb = get_fingerprint(protac_smiles)
+
+    # Convert to torch tensors
+    poi_emb = torch.tensor(poi_emb).to(device)
+    e3_emb = torch.tensor(e3_emb).to(device)
+    cell_emb = torch.tensor(cell_emb).to(device)
+    smiles_emb = torch.tensor(smiles_emb).to(device)
+
+    return model(poi_emb, e3_emb, cell_emb, smiles_emb).item()
+
+
+def is_protac_active(
+    protac_smiles: str,
+    e3_ligase: str,
+    target_uniprot: str,
+    cell_line: str,
+    device: str = 'cpu',
+    proba_threshold: float = 0.5,
+) -> bool:
+    """ Predict whether a PROTAC is active or not.
+
+    Args:
+        protac_smiles (str): The SMILES of the PROTAC.
+        e3_ligase (str): The Uniprot ID of the E3 ligase.
+        target_uniprot (str): The Uniprot ID of the target protein.
+        cell_line (str): The cell line identifier.
+        device (str): The device to run the model on.
+        proba_threshold (float): The probability threshold.
+
+    Returns:
+        bool: Whether the PROTAC is active or not.
+    """
+    pred = get_protac_active_proba(
+        protac_smiles,
+        e3_ligase,
+        target_uniprot,
+        cell_line,
+        device,
+    )
+    return sigmoid(pred) > proba_threshold
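In this module get_protac_active_proba returns the raw model output (a logit) and is_protac_active applies sigmoid before thresholding; the data/model.ckpt it loads does not appear in this commit's file list, so the sketch below assumes a checkpoint is shipped later. The inputs are illustrative placeholders, not validated examples.

# Hypothetical call of the top-level prediction helper (sketch only).
active = is_protac_active(
    protac_smiles='CC(=O)Oc1ccccc1C(=O)O',  # placeholder SMILES
    e3_ligase='VHL',
    target_uniprot='P04637',  # p53, illustrative target
    cell_line='HeLa',
    proba_threshold=0.5,
)
print(f'Predicted active: {active}')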
protac_degradation_predictor/pytorch_models.py ADDED
@@ -0,0 +1,471 @@
1
+ import warnings
2
+ from typing import Literal, List, Tuple, Optional, Dict
3
+
4
+ from protac_dataset import PROTAC_Dataset
5
+ from config import Config
6
+
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ import pytorch_lightning as pl
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torchmetrics import (
16
+ Accuracy,
17
+ AUROC,
18
+ Precision,
19
+ Recall,
20
+ F1Score,
21
+ MetricCollection,
22
+ )
23
+ from imblearn.over_sampling import SMOTE
24
+
25
+
26
+ class PROTAC_Predictor(nn.Module):
27
+
28
+ def __init__(
29
+ self,
30
+ hidden_dim: int,
31
+ smiles_emb_dim: int = Config.fingerprint_size,
32
+ poi_emb_dim: int = Config.protein_embedding_size,
33
+ e3_emb_dim: int = Config.protein_embedding_size,
34
+ cell_emb_dim: int = Config.cell_embedding_size,
35
+ dropout: float = 0.2,
36
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
37
+ disabled_embeddings: list = [],
38
+ ):
39
+ """ Initialize the PROTAC model.
40
+
41
+ Args:
42
+ hidden_dim (int): The hidden dimension of the model
43
+ smiles_emb_dim (int): The dimension of the SMILES embeddings
44
+ poi_emb_dim (int): The dimension of the POI embeddings
45
+ e3_emb_dim (int): The dimension of the E3 Ligase embeddings
46
+ cell_emb_dim (int): The dimension of the cell line embeddings
47
+ dropout (float): The dropout rate
48
+ join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
49
+ disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
50
+ """
51
+ super().__init__()
52
+ self.poi_emb_dim = poi_emb_dim
53
+ self.e3_emb_dim = e3_emb_dim
54
+ self.cell_emb_dim = cell_emb_dim
55
+ self.smiles_emb_dim = smiles_emb_dim
56
+ self.hidden_dim = hidden_dim
57
+ self.join_embeddings = join_embeddings
58
+ self.disabled_embeddings = disabled_embeddings
59
+ # Set our init args as class attributes
60
+ self.__dict__.update(locals())
61
+
62
+ # Define "surrogate models" branches
63
+ if self.join_embeddings != 'beginning':
64
+ if 'poi' not in self.disabled_embeddings:
65
+ self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
66
+ if 'e3' not in self.disabled_embeddings:
67
+ self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
68
+ if 'cell' not in self.disabled_embeddings:
69
+ self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
70
+ if 'smiles' not in self.disabled_embeddings:
71
+ self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
72
+
73
+ # Define hidden dimension for joining layer
74
+ if self.join_embeddings == 'beginning':
75
+ joint_dim = smiles_emb_dim if 'smiles' not in self.disabled_embeddings else 0
76
+ joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
77
+ joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
78
+ joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
79
+ elif self.join_embeddings == 'concat':
80
+ joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
81
+ elif self.join_embeddings == 'sum':
82
+ joint_dim = hidden_dim
83
+
84
+ self.fc0 = nn.Linear(joint_dim, joint_dim)
85
+ self.fc1 = nn.Linear(joint_dim, hidden_dim)
86
+ self.fc2 = nn.Linear(hidden_dim, hidden_dim)
87
+ self.fc3 = nn.Linear(hidden_dim, 1)
88
+
89
+ self.dropout = nn.Dropout(p=dropout)
90
+
91
+
92
+ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
93
+ embeddings = []
94
+ if self.join_embeddings == 'beginning':
95
+ if 'poi' not in self.disabled_embeddings:
96
+ embeddings.append(poi_emb)
97
+ if 'e3' not in self.disabled_embeddings:
98
+ embeddings.append(e3_emb)
99
+ if 'cell' not in self.disabled_embeddings:
100
+ embeddings.append(cell_emb)
101
+ if 'smiles' not in self.disabled_embeddings:
102
+ embeddings.append(smiles_emb)
103
+ x = torch.cat(embeddings, dim=1)
104
+ x = self.dropout(F.relu(self.fc0(x)))
105
+ else:
106
+ if 'poi' not in self.disabled_embeddings:
107
+ embeddings.append(self.poi_emb(poi_emb))
108
+ if 'e3' not in self.disabled_embeddings:
109
+ embeddings.append(self.e3_emb(e3_emb))
110
+ if 'cell' not in self.disabled_embeddings:
111
+ embeddings.append(self.cell_emb(cell_emb))
112
+ if 'smiles' not in self.disabled_embeddings:
113
+ embeddings.append(self.smiles_emb(smiles_emb))
114
+ if self.join_embeddings == 'concat':
115
+ x = torch.cat(embeddings, dim=1)
116
+ elif self.join_embeddings == 'sum':
117
+ if len(embeddings) > 1:
118
+ embeddings = torch.stack(embeddings, dim=1)
119
+ x = torch.sum(embeddings, dim=1)
120
+ else:
121
+ x = embeddings[0]
122
+ x = self.dropout(F.relu(self.fc1(x)))
123
+ x = self.dropout(F.relu(self.fc2(x)))
124
+ x = self.fc3(x)
125
+ return x
126
+
127
+
128
+
129
+ class PROTAC_Model(pl.LightningModule):
130
+
131
+ def __init__(
132
+ self,
133
+ hidden_dim: int,
134
+ smiles_emb_dim: int = 224,
135
+ poi_emb_dim: int = 1024,
136
+ e3_emb_dim: int = 1024,
137
+ cell_emb_dim: int = 768,
138
+ batch_size: int = 32,
139
+ learning_rate: float = 1e-3,
140
+ dropout: float = 0.2,
141
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
142
+ train_dataset: PROTAC_Dataset = None,
143
+ val_dataset: PROTAC_Dataset = None,
144
+ test_dataset: PROTAC_Dataset = None,
145
+ disabled_embeddings: list = [],
146
+ apply_scaling: bool = False,
147
+ ):
148
+ """ Initialize the PROTAC Pytorch Lightning model.
149
+
150
+ Args:
151
+ hidden_dim (int): The hidden dimension of the model
152
+ smiles_emb_dim (int): The dimension of the SMILES embeddings
153
+ poi_emb_dim (int): The dimension of the POI embeddings
154
+ e3_emb_dim (int): The dimension of the E3 Ligase embeddings
155
+ cell_emb_dim (int): The dimension of the cell line embeddings
156
+ batch_size (int): The batch size
157
+ learning_rate (float): The learning rate
158
+ dropout (float): The dropout rate
159
+ join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
160
+ train_dataset (PROTAC_Dataset): The training dataset
161
+ val_dataset (PROTAC_Dataset): The validation dataset
162
+ test_dataset (PROTAC_Dataset): The test dataset
163
+ disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
164
+ apply_scaling (bool): Whether to apply scaling to the embeddings
165
+ """
166
+ super().__init__()
167
+ self.poi_emb_dim = poi_emb_dim
168
+ self.e3_emb_dim = e3_emb_dim
169
+ self.cell_emb_dim = cell_emb_dim
170
+ self.smiles_emb_dim = smiles_emb_dim
171
+ self.hidden_dim = hidden_dim
172
+ self.batch_size = batch_size
173
+ self.learning_rate = learning_rate
174
+ self.join_embeddings = join_embeddings
175
+ self.train_dataset = train_dataset
176
+ self.val_dataset = val_dataset
177
+ self.test_dataset = test_dataset
178
+ self.disabled_embeddings = disabled_embeddings
179
+ self.apply_scaling = apply_scaling
180
+ # Set our init args as class attributes
181
+ self.__dict__.update(locals()) # Add arguments as attributes
182
+ # Save the arguments passed to init
183
+ ignore_args_as_hyperparams = [
184
+ 'train_dataset',
185
+ 'test_dataset',
186
+ 'val_dataset',
187
+ ]
188
+ self.save_hyperparameters(ignore=ignore_args_as_hyperparams)
189
+
190
+ self.model = PROTAC_Predictor(
191
+ hidden_dim=hidden_dim,
192
+ smiles_emb_dim=smiles_emb_dim,
193
+ poi_emb_dim=poi_emb_dim,
194
+ e3_emb_dim=e3_emb_dim,
195
+ cell_emb_dim=cell_emb_dim,
196
+ dropout=dropout,
197
+ join_embeddings=join_embeddings,
198
+ disabled_embeddings=disabled_embeddings,
199
+ )
200
+
201
+ stages = ['train_metrics', 'val_metrics', 'test_metrics']
202
+ self.metrics = nn.ModuleDict({s: MetricCollection({
203
+ 'acc': Accuracy(task='binary'),
204
+ 'roc_auc': AUROC(task='binary'),
205
+ 'precision': Precision(task='binary'),
206
+ 'recall': Recall(task='binary'),
207
+ 'f1_score': F1Score(task='binary'),
208
+ 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
209
+ 'hp_metric': Accuracy(task='binary'),
210
+ }, prefix=s.replace('metrics', '')) for s in stages})
211
+
212
+ # Misc settings
213
+ self.missing_dataset_error = \
214
+ '''Class variable `{0}` is None. If the model was loaded from a checkpoint, the dataset must be set manually:
215
+
216
+ model = {1}.load_from_checkpoint('checkpoint.ckpt')
217
+ model.{0} = my_{0}
218
+ '''
219
+
220
+ # Apply scaling in datasets
221
+ if self.apply_scaling:
222
+ use_single_scaler = True if self.join_embeddings == 'beginning' else False
223
+ self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
224
+ self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
225
+ self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
226
+ if self.test_dataset:
227
+ self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
228
+
229
+ def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
230
+ return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
231
+
232
+ def step(self, batch, batch_idx, stage):
233
+ poi_emb = batch['poi_emb']
234
+ e3_emb = batch['e3_emb']
235
+ cell_emb = batch['cell_emb']
236
+ smiles_emb = batch['smiles_emb']
237
+ y = batch['active'].float().unsqueeze(1)
238
+
239
+ y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
240
+ loss = F.binary_cross_entropy_with_logits(y_hat, y)
241
+
242
+ self.metrics[f'{stage}_metrics'].update(y_hat, y)
243
+ self.log(f'{stage}_loss', loss, on_epoch=True, prog_bar=True)
244
+ self.log_dict(self.metrics[f'{stage}_metrics'], on_epoch=True)
245
+
246
+ return loss
247
+
248
+ def training_step(self, batch, batch_idx):
249
+ return self.step(batch, batch_idx, 'train')
250
+
251
+ def validation_step(self, batch, batch_idx):
252
+ return self.step(batch, batch_idx, 'val')
253
+
254
+ def test_step(self, batch, batch_idx):
255
+ return self.step(batch, batch_idx, 'test')
256
+
257
+ def configure_optimizers(self):
258
+ return optim.Adam(self.parameters(), lr=self.learning_rate)
259
+
260
+ def predict_step(self, batch, batch_idx):
261
+ poi_emb = batch['poi_emb']
262
+ e3_emb = batch['e3_emb']
263
+ cell_emb = batch['cell_emb']
264
+ smiles_emb = batch['smiles_emb']
265
+
266
+ if self.apply_scaling:
267
+ if self.join_embeddings == 'beginning':
268
+ embeddings = np.hstack([
269
+ np.array(smiles_emb.tolist()),
270
+ np.array(poi_emb.tolist()),
271
+ np.array(e3_emb.tolist()),
272
+ np.array(cell_emb.tolist()),
273
+ ])
274
+ embeddings = self.scalers.transform(embeddings)
275
+ smiles_emb = embeddings[:, :self.smiles_emb_dim]
276
+ poi_emb = embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.poi_emb_dim]
277
+ e3_emb = embeddings[:, self.smiles_emb_dim+self.poi_emb_dim:self.smiles_emb_dim+2*self.poi_emb_dim]
278
+ cell_emb = embeddings[:, -self.cell_emb_dim:]
279
+ else:
280
+ poi_emb = self.scalers['Uniprot'].transform(poi_emb)
281
+ e3_emb = self.scalers['E3 Ligase Uniprot'].transform(e3_emb)
282
+ cell_emb = self.scalers['Cell Line Identifier'].transform(cell_emb)
283
+ smiles_emb = self.scalers['Smiles'].transform(smiles_emb)
284
+
285
+ y_hat = self.forward(poi_emb, e3_emb, cell_emb, smiles_emb)
286
+ return torch.sigmoid(y_hat)
287
+
288
+ def train_dataloader(self):
289
+ if self.train_dataset is None:
290
+ format = 'train_dataset', self.__class__.__name__
291
+ raise ValueError(self.missing_dataset_error.format(*format))
292
+
293
+ return DataLoader(
294
+ self.train_dataset,
295
+ batch_size=self.batch_size,
296
+ shuffle=True,
297
+ # drop_last=True,
298
+ )
299
+
300
+ def val_dataloader(self):
301
+ if self.val_dataset is None:
302
+ format = 'val_dataset', self.__class__.__name__
303
+ raise ValueError(self.missing_dataset_error.format(*format))
304
+ return DataLoader(
305
+ self.val_dataset,
306
+ batch_size=self.batch_size,
307
+ shuffle=False,
308
+ )
309
+
310
+ def test_dataloader(self):
311
+ if self.test_dataset is None:
312
+ format = 'test_dataset', self.__class__.__name__
313
+ raise ValueError(self.missing_dataset_error.format(*format))
314
+ return DataLoader(
315
+ self.test_dataset,
316
+ batch_size=self.batch_size,
317
+ shuffle=False,
318
+ )
319
+
320
+
321
+ def train_model(
322
+ protein2embedding: Dict,
323
+ cell2embedding: Dict,
324
+ smiles2fp: Dict,
325
+ train_df: pd.DataFrame,
326
+ val_df: pd.DataFrame,
327
+ test_df: Optional[pd.DataFrame] = None,
328
+ hidden_dim: int = 768,
329
+ batch_size: int = 8,
330
+ learning_rate: float = 2e-5,
331
+ dropout: float = 0.2,
332
+ max_epochs: int = 50,
333
+ smiles_emb_dim: int = 224,
334
+ join_embeddings: Literal['beginning', 'concat', 'sum'] = 'concat',
335
+ smote_k_neighbors:int = 5,
336
+ use_smote: bool = True,
337
+ apply_scaling: bool = False,
338
+ active_label: str = 'Active',
339
+ fast_dev_run: bool = False,
340
+ use_logger: bool = True,
341
+ logger_name: str = 'protac',
342
+ disabled_embeddings: List[str] = [],
343
+ ) -> tuple:
344
+ """ Train a PROTAC model using the given datasets and hyperparameters.
345
+
346
+ Args:
347
+ protein2embedding (dict): Dictionary of protein embeddings.
348
+ cell2embedding (dict): Dictionary of cell line embeddings.
349
+ smiles2fp (dict): Dictionary of SMILES to fingerprint.
350
+ train_df (pd.DataFrame): The training set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
351
+ val_df (pd.DataFrame): The validation set. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
352
+ test_df (pd.DataFrame): The test set. If provided, the returned metrics will include test performance. It must include the following columns: 'Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier', <active_label>.
353
+ hidden_dim (int): The hidden dimension of the model.
354
+ batch_size (int): The batch size.
355
+ learning_rate (float): The learning rate.
356
+ max_epochs (int): The maximum number of epochs.
357
+ smiles_emb_dim (int): The dimension of the SMILES embeddings.
358
+ smote_k_neighbors (int): The number of neighbors for the SMOTE oversampler.
359
+ fast_dev_run (bool): Whether to run a fast development run.
360
+ disabled_embeddings (list): The list of disabled embeddings.
361
+
362
+ Returns:
363
+ tuple: The trained model, the trainer, and the metrics.
364
+ """
365
+ oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
366
+ train_ds = PROTAC_Dataset(
367
+ train_df,
368
+ protein2embedding,
369
+ cell2embedding,
370
+ smiles2fp,
371
+ use_smote=use_smote,
372
+ oversampler=oversampler if use_smote else None,
373
+ active_label=active_label,
374
+ )
375
+ val_ds = PROTAC_Dataset(
376
+ val_df,
377
+ protein2embedding,
378
+ cell2embedding,
379
+ smiles2fp,
380
+ active_label=active_label,
381
+ )
382
+ if test_df is not None:
383
+ test_ds = PROTAC_Dataset(
384
+ test_df,
385
+ protein2embedding,
386
+ cell2embedding,
387
+ smiles2fp,
388
+ active_label=active_label,
389
+ )
390
+ logger = pl.loggers.TensorBoardLogger(
391
+ save_dir='../logs',
392
+ name=logger_name,
393
+ )
394
+ callbacks = [
395
+ pl.callbacks.EarlyStopping(
396
+ monitor='train_loss',
397
+ patience=10,
398
+ mode='min',
399
+ verbose=False,
400
+ ),
401
+ pl.callbacks.EarlyStopping(
402
+ monitor='val_loss',
403
+ patience=5,
404
+ mode='min',
405
+ verbose=False,
406
+ ),
407
+ pl.callbacks.EarlyStopping(
408
+ monitor='val_acc',
409
+ patience=10,
410
+ mode='max',
411
+ verbose=False,
412
+ ),
413
+ # pl.callbacks.ModelCheckpoint(
414
+ # monitor='val_acc',
415
+ # mode='max',
416
+ # verbose=True,
417
+ # filename='{epoch}-{val_metrics_opt_score:.4f}',
418
+ # ),
419
+ ]
420
+ # Define Trainer
421
+ trainer = pl.Trainer(
422
+ logger=logger if use_logger else False,
423
+ callbacks=callbacks,
424
+ max_epochs=max_epochs,
425
+ fast_dev_run=fast_dev_run,
426
+ enable_model_summary=False,
427
+ enable_checkpointing=False,
428
+ enable_progress_bar=False,
429
+ devices=1,
430
+ num_nodes=1,
431
+ )
432
+ model = PROTAC_Model(
433
+ hidden_dim=hidden_dim,
434
+ smiles_emb_dim=smiles_emb_dim,
435
+ poi_emb_dim=1024,
436
+ e3_emb_dim=1024,
437
+ cell_emb_dim=768,
438
+ batch_size=batch_size,
439
+ join_embeddings=join_embeddings,
440
+ dropout=dropout,
441
+ learning_rate=learning_rate,
442
+ apply_scaling=apply_scaling,
443
+ train_dataset=train_ds,
444
+ val_dataset=val_ds,
445
+ test_dataset=test_ds if test_df is not None else None,
446
+ disabled_embeddings=disabled_embeddings,
447
+ )
448
+ with warnings.catch_warnings():
449
+ warnings.simplefilter("ignore")
450
+ trainer.fit(model)
451
+ metrics = trainer.validate(model, verbose=False)[0]
452
+ if test_df is not None:
453
+ test_metrics = trainer.test(model, verbose=False)[0]
454
+ metrics.update(test_metrics)
455
+ return model, trainer, metrics
456
+
457
+
458
+ def load_model(
459
+ ckpt_path: str,
460
+ ) -> PROTAC_Model:
461
+ """ Load a PROTAC model from a checkpoint.
462
+
463
+ Args:
464
+ ckpt_path (str): The path to the checkpoint.
465
+
466
+ Returns:
467
+ PROTAC_Model: The loaded model.
468
+ """
469
+ model = PROTAC_Model.load_from_checkpoint(ckpt_path)
470
+ model.eval()
471
+ return model
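A short sketch of reloading a checkpoint for inference; as the missing_dataset_error message above notes, datasets are not stored in the checkpoint and must be re-attached. The checkpoint path and the dataset object are placeholders.

# Sketch only: restore a trained PROTAC_Model and predict on new data.
import pytorch_lightning as pl

model = load_model('path/to/checkpoint.ckpt')  # illustrative path
model.test_dataset = my_test_dataset           # a PROTAC_Dataset prepared by the caller
trainer = pl.Trainer(devices=1, enable_progress_bar=False)
probs = trainer.predict(model, model.test_dataloader())  # sigmoid probabilities per batch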
protac_degradation_predictor/sklearn_models.py ADDED
@@ -0,0 +1,243 @@
1
+ from typing import Literal, List, Tuple, Optional, Dict
2
+
3
+ from protac_dataset import PROTAC_Dataset
4
+
5
+ import pandas as pd
6
+ from sklearn.base import ClassifierMixin
7
+ from sklearn.ensemble import (
8
+ RandomForestClassifier,
9
+ GradientBoostingClassifier,
10
+ )
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.svm import SVC
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torchmetrics import (
17
+ Accuracy,
18
+ AUROC,
19
+ Precision,
20
+ Recall,
21
+ F1Score,
22
+ MetricCollection,
23
+ )
24
+ import optuna
25
+
26
+
27
+ def train_sklearn_model(
28
+ clf: ClassifierMixin,
29
+ protein2embedding: Dict,
30
+ cell2embedding: Dict,
31
+ smiles2fp: Dict,
32
+ train_df: pd.DataFrame,
33
+ val_df: pd.DataFrame,
34
+ test_df: Optional[pd.DataFrame] = None,
35
+ active_label: str = 'Active',
36
+ use_single_scaler: bool = True,
37
+ ) -> Tuple[ClassifierMixin, Dict]:
38
+ """ Train a classifier model on train and val sets and evaluate it on a test set.
39
+
40
+ Args:
41
+ clf: The classifier model to train and evaluate.
42
+ train_df (pd.DataFrame): The training set.
43
+ val_df (pd.DataFrame): The validation set.
44
+ test_df (Optional[pd.DataFrame]): The test set.
45
+
46
+ Returns:
47
+ Tuple[ClassifierMixin, nn.ModuleDict]: The trained model and the metrics.
48
+ """
49
+ # Initialize the datasets
50
+ train_ds = PROTAC_Dataset(
51
+ train_df,
52
+ protein2embedding,
53
+ cell2embedding,
54
+ smiles2fp,
55
+ active_label=active_label,
56
+ use_smote=False,
57
+ )
58
+ scaler = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
59
+ train_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
60
+ val_ds = PROTAC_Dataset(
61
+ val_df,
62
+ protein2embedding,
63
+ cell2embedding,
64
+ smiles2fp,
65
+ active_label=active_label,
66
+ use_smote=False,
67
+ )
68
+ val_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
69
+ if test_df is not None:
70
+ test_ds = PROTAC_Dataset(
71
+ test_df,
72
+ protein2embedding,
73
+ cell2embedding,
74
+ smiles2fp,
75
+ active_label=active_label,
76
+ use_smote=False,
77
+ )
78
+ test_ds.apply_scaling(scaler, use_single_scaler=use_single_scaler)
79
+
80
+ # Get the numpy arrays
81
+ X_train, y_train = train_ds.get_numpy_arrays()
82
+ X_val, y_val = val_ds.get_numpy_arrays()
83
+ if test_df is not None:
84
+ X_test, y_test = test_ds.get_numpy_arrays()
85
+
86
+ # Train the model
87
+ clf.fit(X_train, y_train)
88
+ # Define the metrics as a module dict
89
+ stages = ['train_metrics', 'val_metrics', 'test_metrics']
90
+ metrics = nn.ModuleDict({s: MetricCollection({
91
+ 'acc': Accuracy(task='binary'),
92
+ 'roc_auc': AUROC(task='binary'),
93
+ 'precision': Precision(task='binary'),
94
+ 'recall': Recall(task='binary'),
95
+ 'f1_score': F1Score(task='binary'),
96
+ 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
97
+ 'hp_metric': Accuracy(task='binary'),
98
+ }, prefix=s.replace('metrics', '')) for s in stages})
99
+
100
+ # Get the predictions
101
+ metrics_out = {}
102
+
103
+ y_pred = torch.tensor(clf.predict_proba(X_train)[:, 1])
104
+ y_true = torch.tensor(y_train)
105
+ metrics['train_metrics'].update(y_pred, y_true)
106
+ metrics_out.update(metrics['train_metrics'].compute())
107
+
108
+ y_pred = torch.tensor(clf.predict_proba(X_val)[:, 1])
109
+ y_true = torch.tensor(y_val)
110
+ metrics['val_metrics'].update(y_pred, y_true)
111
+ metrics_out.update(metrics['val_metrics'].compute())
112
+
113
+ if test_df is not None:
114
+ y_pred = torch.tensor(clf.predict_proba(X_test)[:, 1])
115
+ y_true = torch.tensor(y_test)
116
+ metrics['test_metrics'].update(y_pred, y_true)
117
+ metrics_out.update(metrics['test_metrics'].compute())
118
+
119
+ return clf, metrics_out
120
+
121
+
122
+ def suggest_random_forest(
123
+ trial: optuna.Trial,
124
+ ) -> ClassifierMixin:
125
+ """ Suggest hyperparameters for a Random Forest classifier.
126
+
127
+ Args:
128
+ trial (optuna.Trial): The Optuna trial object.
129
+
130
+ Returns:
131
+ ClassifierMixin: The Random Forest classifier with the suggested hyperparameters.
132
+ """
133
+ n_estimators = trial.suggest_int('model_n_estimators', 10, 1000)
134
+ max_depth = trial.suggest_int('model_max_depth', 2, 100)
135
+ min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
136
+ min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
137
+ max_features = trial.suggest_categorical('model_max_features', [None, 'sqrt', 'log2'])
138
+ criterion = trial.suggest_categorical('model_criterion', ['gini', 'entropy'])
139
+
140
+ clf = RandomForestClassifier(
141
+ n_estimators=n_estimators,
142
+ max_depth=max_depth,
143
+ min_samples_split=min_samples_split,
144
+ min_samples_leaf=min_samples_leaf,
145
+ max_features=max_features,
146
+ criterion=criterion,
147
+ random_state=42,
148
+ )
149
+
150
+ return clf
151
+
152
+
153
+ def suggest_logistic_regression(
154
+ trial: optuna.Trial,
155
+ ) -> ClassifierMixin:
156
+ """ Suggest hyperparameters for a Logistic Regression classifier.
157
+
158
+ Args:
159
+ trial (optuna.Trial): The Optuna trial object.
160
+
161
+ Returns:
162
+ ClassifierMixin: The Logistic Regression classifier with the suggested hyperparameters.
163
+ """
164
+ # Suggest values for the logistic regression hyperparameters
165
+ C = trial.suggest_loguniform('model_C', 1e-4, 1e2)
166
+ penalty = trial.suggest_categorical('model_penalty', ['l1', 'l2', 'elasticnet', None])
167
+ solver = trial.suggest_categorical('model_solver', ['newton-cholesky', 'lbfgs', 'liblinear', 'sag', 'saga'])
+ # The 'elasticnet' penalty additionally requires an l1_ratio value
+ l1_ratio = trial.suggest_float('model_l1_ratio', 0.0, 1.0) if penalty == 'elasticnet' else None
168
+
169
+ # Check solver compatibility
170
+ if penalty == 'l1' and solver not in ['liblinear', 'saga']:
171
+ raise optuna.exceptions.TrialPruned()
+ if penalty == 'elasticnet' and solver != 'saga':
+ raise optuna.exceptions.TrialPruned()
172
+ if penalty is None and solver not in ['newton-cholesky', 'lbfgs', 'sag']:
173
+ raise optuna.exceptions.TrialPruned()
174
+
175
+ # Configure the classifier with the trial's suggested parameters
176
+ clf = LogisticRegression(
177
+ C=C,
178
+ penalty=penalty,
179
+ solver=solver,
+ l1_ratio=l1_ratio,
180
+ max_iter=1000,
181
+ random_state=42,
182
+ )
183
+
184
+ return clf
185
+
186
+
187
+ def suggest_svc(
188
+ trial: optuna.Trial,
189
+ ) -> ClassifierMixin:
190
+ """ Suggest hyperparameters for an SVC classifier.
191
+
192
+ Args:
193
+ trial (optuna.Trial): The Optuna trial object.
194
+
195
+ Returns:
196
+ ClassifierMixin: The SVC classifier with the suggested hyperparameters.
197
+ """
198
+ C = trial.suggest_loguniform('model_C', 1e-4, 1e2)
199
+ kernel = trial.suggest_categorical('model_kernel', ['linear', 'poly', 'rbf', 'sigmoid'])
200
+ gamma = trial.suggest_categorical('model_gamma', ['scale', 'auto'])
201
+ degree = trial.suggest_int('model_degree', 2, 5) if kernel == 'poly' else 3
202
+
203
+ clf = SVC(
204
+ C=C,
205
+ kernel=kernel,
206
+ gamma=gamma,
207
+ degree=degree,
208
+ probability=True,
209
+ random_state=42,
210
+ )
211
+
212
+ return clf
213
+
214
+
215
+ def suggest_gradient_boosting(
216
+ trial: optuna.Trial,
217
+ ) -> ClassifierMixin:
218
+ """ Suggest hyperparameters for a Gradient Boosting classifier.
219
+
220
+ Args:
221
+ trial (optuna.Trial): The Optuna trial object.
222
+
223
+ Returns:
224
+ ClassifierMixin: The Gradient Boosting classifier with the suggested hyperparameters.
225
+ """
226
+ n_estimators = trial.suggest_int('model_n_estimators', 50, 500)
227
+ learning_rate = trial.suggest_loguniform('model_learning_rate', 0.01, 1)
228
+ max_depth = trial.suggest_int('model_max_depth', 3, 10)
229
+ min_samples_split = trial.suggest_int('model_min_samples_split', 2, 10)
230
+ min_samples_leaf = trial.suggest_int('model_min_samples_leaf', 1, 10)
231
+ max_features = trial.suggest_categorical('model_max_features', ['sqrt', 'log2', None])
232
+
233
+ clf = GradientBoostingClassifier(
234
+ n_estimators=n_estimators,
235
+ learning_rate=learning_rate,
236
+ max_depth=max_depth,
237
+ min_samples_split=min_samples_split,
238
+ min_samples_leaf=min_samples_leaf,
239
+ max_features=max_features,
240
+ random_state=42,
241
+ )
242
+
243
+ return clf
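The suggest_* helpers above are intended to be called from an Optuna objective together with the train-and-evaluate helper at the top of this file. A minimal sketch of that wiring follows; the helper name train_and_evaluate_clf and the in-scope train_df/val_df dataframes are assumptions, not part of this diff:

import optuna

def objective(trial: optuna.Trial) -> float:
    # Sample a classifier for this trial; swap in suggest_svc, suggest_logistic_regression
    # or suggest_gradient_boosting to tune a different model family.
    clf = suggest_random_forest(trial)
    # Assumed name for the train/evaluate helper above; its full signature
    # (embedding lookups, etc.) is outside this excerpt.
    _, metrics = train_and_evaluate_clf(clf, train_df=train_df, val_df=val_df)
    # MetricCollection prefixes keys with 'val_', so the composite score is 'val_opt_score'.
    return float(metrics['val_opt_score'])

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)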
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_0_test_split_0.1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:27aa74ab92272f4c455f8eb32e7da3d3b71e213937f9a96a7d07e3ca61af06fb
3
  size 45164
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d2f036ce141fbeb81930cc9ce49dbd6effc76221b26b92ae0498af1c34289f3
3
  size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_1_test_split_0.1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e8b8cd9536d7d1fab755506dc6cddf5d0d66ae3743f3a687dc5e82d44b134cd7
3
  size 45164
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8ce36b5f52f8f88105c3ec0c5b60f865e1b054aff8f9e96c21f1e037eaa65af
3
  size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2503fafc20807b89d6663a978a1a52d923cd36a43e09b2cd5761bfd42c505942
3
  size 45164
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d22f4f54d46ca72b8585645fdfac43683a23dcc00d80fb8bd1f785d4eb4a9594
3
  size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_2_test_split_0.2.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f7298a04d0888f4de87041efd6b78c42e13c3f1630c43567d582bc7710a40847
3
- size 45164
 
 
 
 
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ac2296caf215b91d3afbf2f31a625cdade8a1557f791e43350fe86e04373c6f7
3
  size 45164
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54f869b328af6667567bd5cc805ce63fc5434ed1b77afc1e66d95b8f02e40642
3
  size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_3_test_split_0.2.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d2a7a9f6ed11e1b5b6f876dd927612d03f4780f9db3e65b9f1ebb8fbd853677f
3
- size 45164
 
 
 
 
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.1.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:994f1b499f359d0fb32eb708054826edecd317862cff181455fae8040330f6b9
3
  size 45164
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27f0e76e7f89950199c843699c000eaad8628441c84aec394e20c23f701b1609
3
  size 45164
reports/study_Active_Dmax_0.6_pDC50_6.0_tanimoto_fold_4_test_split_0.2.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd4c40033da16a1cee16fd998c82e0403a31db4d14ba0604160ea143bae03668
3
- size 45164
 
 
 
 
setup.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import setuptools
2
+
3
+ setuptools.setup(
4
+ name="protac_degradation_predictor",
5
+ version="0.0.1",
6
+ author="Stefano Ribes",
7
+ url="https://github.com/ribesstefano/PROTAC-Degradation-Predictor",
8
+ author_email="ribes.stefano@gmail.com",
9
+ description="A package to predict PROTAC-induced protein degradation.",
10
+ long_description=open("README.md").read(),
11
+ packages=setuptools.find_packages(),
12
+ install_requires=["torch", "pytorch_lightning", "sklearn", "imblearn", "pandas", "joblib", "h5py", "optuna", "torchmetrics"],
13
+ classifiers=[
14
+ "Programming Language :: Python :: 3",
15
+ "Programming Language :: Python :: 3.6",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ],
19
+ include_package_data=True,
20
+ package_data={"": ["data/*.h5", "data/*.pkl", "data/*.csv"]},
21
+ )
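With include_package_data=True and the package_data globs above, the .csv/.pkl/.h5 files under the package's data/ folder are shipped with the built distribution. Note that install_requires must list the PyPI distribution names (scikit-learn, imbalanced-learn) rather than the importable module names (sklearn, imblearn). A minimal sketch (assumed usage, not part of this commit) for checking that the bundled data files are reachable after installation:

import pkg_resources

# List whatever files setup.py packaged under protac_degradation_predictor/data/.
for name in pkg_resources.resource_listdir('protac_degradation_predictor', 'data'):
    print(name)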