Commit 33f1644 by ribesstefano
Parent(s): 4bf0ec2

Added XGBoost Optuna training + Added Ablation studies with zeroed input vectors

Changed files:
- protac_degradation_predictor/__init__.py +3 -1
- protac_degradation_predictor/optuna_utils.py +12 -0
- protac_degradation_predictor/optuna_utils_xgboost.py +323 -0
- protac_degradation_predictor/protac_dataset.py +31 -4
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random.csv +29 -0
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_tanimoto.csv +29 -0
- reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_uniprot.csv +29 -0
- src/run_xgboost_experiments.py +329 -0
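The ablation reports below come from runs where one or more input branches (POI, E3 ligase, cell line, or compound SMILES) are disabled, i.e. replaced by a zero vector, as described in the `disabled_embeddings` docstring further down. As an illustrative sketch only (not code from this commit), the zeroing step on a plain dictionary of feature blocks could look like this; the block names and sizes are hypothetical:

import numpy as np

def zero_out_embeddings(features: dict, disabled: list) -> dict:
    """Return a copy of `features` where every disabled block is a zero vector
    of the same shape (illustrative helper, not part of this commit)."""
    return {
        name: np.zeros_like(vec) if name in disabled else vec
        for name, vec in features.items()
    }

# Example: ablate the E3 and cell-line embeddings, keep SMILES and POI.
sample = {
    'smiles': np.random.rand(1024),  # hypothetical fingerprint length
    'poi': np.random.rand(1024),
    'e3': np.random.rand(1024),
    'cell': np.random.rand(768),
}
ablated = zero_out_embeddings(sample, disabled=['e3', 'cell'])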
protac_degradation_predictor/__init__.py
CHANGED
@@ -17,7 +17,9 @@ from .sklearn_models import (
 )
 from .optuna_utils import (
     hyperparameter_tuning_and_training,
-
+)
+from .optuna_utils_xgboost import (
+    xgboost_hyperparameter_tuning_and_training,
 )
 from .protac_degradation_predictor import (
     get_protac_active_proba,
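With the export above, the new tuning entry point is reachable from the package namespace. A minimal sketch (the `pdp` alias mirrors the one used in src/run_xgboost_experiments.py); both spellings should resolve to the same callable after this commit:

import protac_degradation_predictor as pdp
from protac_degradation_predictor import xgboost_hyperparameter_tuning_and_training

# Package-level attribute and direct import point to the same function.
assert pdp.xgboost_hyperparameter_tuning_and_training is xgboost_hyperparameter_tuning_and_training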
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -234,6 +234,18 @@ def pytorch_model_objective(

     # Optuna aims to minimize the pytorch_model_objective
     return - val_roc_auc
+    # # Get the majority vote for the test predictions
+    # if test_df is not None and not fast_dev_run:
+    #     majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
+    #     majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
+    #     trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
+    #     logging.info(f'Majority vote metrics: {majority_vote_metrics}')
+
+    # # Get the average validation accuracy and ROC AUC accross the folds
+    # val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
+
+    # # Optuna aims to minimize the pytorch_model_objective
+    # return - val_roc_auc


 def hyperparameter_tuning_and_training(
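The commented-out block above refers to `get_majority_vote_metrics`, which scores an ensemble of per-fold (or per-model) test predictions. As a rough illustration of that idea only, and not the package's actual implementation, averaging the predicted probabilities and thresholding the consensus could look like this:

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

def majority_vote_metrics_sketch(preds_per_model, y_true, threshold=0.5):
    """Average predicted probabilities across models, then score the
    thresholded consensus (illustrative only, not the library function)."""
    mean_proba = np.mean(np.stack(preds_per_model), axis=0)
    consensus = (mean_proba > threshold).astype(int)
    return {
        'test_acc': accuracy_score(y_true, consensus),
        'test_roc_auc': roc_auc_score(y_true, mean_proba),
    }

# Three hypothetical models' probabilities for four test samples:
preds = [np.array([0.9, 0.2, 0.6, 0.4]),
         np.array([0.8, 0.3, 0.4, 0.6]),
         np.array([0.7, 0.1, 0.7, 0.2])]
print(majority_vote_metrics_sketch(preds, y_true=np.array([1, 0, 1, 0])))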
protac_degradation_predictor/optuna_utils_xgboost.py
ADDED
@@ -0,0 +1,323 @@
from typing import Optional, Dict
import logging
import os

from .optuna_utils import get_majority_vote_metrics, get_dataframe_stats
from .protac_dataset import get_datasets

import optuna
import xgboost as xgb
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import xgboost as xgb
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
import joblib
from optuna.samplers import TPESampler
import torch


xgb.set_config(verbosity=0)


def train_and_evaluate_xgboost(
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        params: dict,
        test_df: Optional[pd.DataFrame] = None,
        active_label: str = 'Active',
        num_boost_round: int = 100,
        shuffle_train_data: bool = False,
) -> tuple:
    """
    Train and evaluate an XGBoost model with the given parameters.

    Args:
        train_df (pd.DataFrame): The training and validation data.
        test_df (pd.DataFrame): The test data.
        params (dict): Hyperparameters for the XGBoost model.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of epochs.

    Returns:
        tuple: The trained model, test predictions, and metrics.
    """
    # Get datasets and their numpy arrays
    train_ds, val_ds, test_ds = get_datasets(
        protein2embedding=protein2embedding,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        train_df=train_df,
        val_df=val_df,
        test_df=test_df,
        disabled_embeddings=[],
        active_label=active_label,
        apply_scaling=False,
    )
    X_train, y_train = train_ds.get_numpy_arrays()
    X_val, y_val = val_ds.get_numpy_arrays()

    # Shuffle the training data
    if shuffle_train_data:
        idx = np.random.permutation(len(X_train))
        X_train, y_train = X_train[idx], y_train[idx]

    # Setup training and validation data in XGBoost data format
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    evallist = [(dval, 'eval'), (dtrain, 'train')]

    # Setup test data
    if test_df is not None:
        X_test, y_test = test_ds.get_numpy_arrays()
        dtest = xgb.DMatrix(X_test, label=y_test)
        evallist.append((dtest, 'test'))

    model = xgb.train(
        params,
        dtrain,
        num_boost_round=num_boost_round,
        evals=evallist,
        early_stopping_rounds=10,
        verbose_eval=False,
    )

    # Evaluate model
    val_pred = model.predict(dval)
    val_pred_binary = (val_pred > 0.5).astype(int)
    metrics = {
        'val_accuracy': accuracy_score(y_val, val_pred_binary),
        'val_roc_auc': roc_auc_score(y_val, val_pred),
        'val_precision': precision_score(y_val, val_pred_binary),
        'val_recall': recall_score(y_val, val_pred_binary),
        'val_f1_score': f1_score(y_val, val_pred_binary),
    }
    preds = {'val_pred': val_pred}

    if test_df is not None:
        test_pred = model.predict(dtest)
        test_pred_binary = (test_pred > 0.5).astype(int)
        metrics.update({
            'test_accuracy': accuracy_score(y_test, test_pred_binary),
            'test_roc_auc': roc_auc_score(y_test, test_pred),
            'test_precision': precision_score(y_test, test_pred_binary),
            'test_recall': recall_score(y_test, test_pred_binary),
            'test_f1_score': f1_score(y_test, test_pred_binary),
        })
        preds.update({'test_pred': test_pred})

    return model, preds, metrics


def xgboost_model_objective(
        trial: optuna.Trial,
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        kf: StratifiedKFold,
        groups: Optional[np.array] = None,
        active_label: str = 'Active',
        num_boost_round: int = 100,
) -> float:
    """ Objective function for hyperparameter optimization with XGBoost.

    Args:
        trial (optuna.Trial): The Optuna trial object.
        train_val_df (pd.DataFrame): The training and validation data.
        kf (StratifiedKFold): Stratified K-Folds cross-validator.
        test_df (Optional[pd.DataFrame]): The test data.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of epochs.
        use_logger (bool): Whether to use logging.
    """
    # Suggest hyperparameters to be used across the CV folds
    params = {
        'booster': 'gbtree',
        'tree_method': 'hist',  # if torch.cuda.is_available() else 'hist',
        'objective': 'binary:logistic',
        'eval_metric': 'auc',
        'eta': trial.suggest_float('eta', 1e-4, 1e-1, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'min_child_weight': trial.suggest_float('min_child_weight', 1e-3, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 1e-4, 1e-1, log=True),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
    }

    X = train_val_df.copy().drop(columns=active_label)
    y = train_val_df[active_label].tolist()
    report = []
    val_preds = []

    for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
        logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
        train_df = train_val_df.iloc[train_index]
        val_df = train_val_df.iloc[val_index]

        # Get some statistics from the dataframes
        stats = {
            'model_type': 'XGBoost',
            'fold': k,
            'train_len': len(train_df),
            'val_len': len(val_df),
            'train_perc': len(train_df) / len(train_val_df),
            'val_perc': len(val_df) / len(train_val_df),
        }
        stats.update(get_dataframe_stats(train_df, val_df, active_label=active_label))
        if groups is not None:
            stats['train_unique_groups'] = len(np.unique(groups[train_index]))
            stats['val_unique_groups'] = len(np.unique(groups[val_index]))

        _, preds, metrics = train_and_evaluate_xgboost(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_df,
            val_df=val_df,
            params=params,
            active_label=active_label,
            num_boost_round=num_boost_round,
        )
        stats.update(metrics)
        report.append(stats.copy())
        val_preds.append(preds['val_pred'])

    # Save the report in the trial
    trial.set_user_attr('report', report)
    trial.set_user_attr('val_preds', val_preds)
    trial.set_user_attr('params', params)

    # Get the average validation metrics across the folds
    mean_val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
    logging.info(f'\tMean val ROC AUC: {mean_val_roc_auc:.4f}')

    # Optuna aims to minimize the objective, so return the negative ROC AUC
    return -mean_val_roc_auc


def xgboost_hyperparameter_tuning_and_training(
        protein2embedding: Dict,
        cell2embedding: Dict,
        smiles2fp: Dict,
        train_val_df: pd.DataFrame,
        test_df: pd.DataFrame,
        kf: StratifiedKFold,
        groups: Optional[np.array] = None,
        split_type: str = 'random',
        n_models_for_test: int = 3,
        n_trials: int = 50,
        active_label: str = 'Active',
        num_boost_round: int = 100,
        study_filename: Optional[str] = None,
        force_study: bool = False,
) -> dict:
    """ Hyperparameter tuning and training of an XGBoost model.

    Args:
        train_val_df (pd.DataFrame): The training and validation data.
        test_df (pd.DataFrame): The test data.
        kf (StratifiedKFold): Stratified K-Folds cross-validator.
        groups (Optional[np.array]): Group labels for the samples used while splitting the dataset into train/test set.
        split_type (str): Type of the data split.
        n_models_for_test (int): Number of models to train for testing.
        fast_dev_run (bool): Whether to run a fast development run.
        n_trials (int): Number of trials for hyperparameter optimization.
        logger_save_dir (str): Directory to save logs.
        logger_name (str): Name of the logger.
        active_label (str): The active label column.
        num_boost_round (int): Maximum number of epochs.
        study_filename (Optional[str]): File name to save/load the Optuna study.
        force_study (bool): Whether to force the study optimization even if the study file exists.

    Returns:
        dict: A dictionary containing reports from the CV and test.
    """
    # Set the verbosity of Optuna
    optuna.logging.set_verbosity(optuna.logging.WARNING)

    # Create an Optuna study object
    sampler = TPESampler(seed=42)
    study = optuna.create_study(direction='minimize', sampler=sampler)

    study_loaded = False
    if study_filename and not force_study:
        if os.path.exists(study_filename):
            study = joblib.load(study_filename)
            study_loaded = True
            logging.info(f'Loaded study from {study_filename}')

    if not study_loaded or force_study:
        study.optimize(
            lambda trial: xgboost_model_objective(
                trial=trial,
                protein2embedding=protein2embedding,
                cell2embedding=cell2embedding,
                smiles2fp=smiles2fp,
                train_val_df=train_val_df,
                kf=kf,
                groups=groups,
                active_label=active_label,
                num_boost_round=num_boost_round,
            ),
            n_trials=n_trials,
        )
        if study_filename:
            joblib.dump(study, study_filename)

    cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
    hparam_report = pd.DataFrame([study.best_params])

    # Retrain N models with the best hyperparameters (measure model uncertainty)
    best_models = []
    test_report = []
    test_preds = []
    for i in range(n_models_for_test):
        logging.info(f'Training best model {i + 1}/{n_models_for_test}')
        model, preds, metrics = train_and_evaluate_xgboost(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_df=train_val_df,
            val_df=test_df,
            params=study.best_trial.user_attrs['params'],
            active_label=active_label,
            num_boost_round=num_boost_round,
            shuffle_train_data=True,
        )
        metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
        metrics['model_type'] = 'XGBoost'
        metrics['test_model_id'] = i
        metrics.update(get_dataframe_stats(
            train_val_df,
            test_df=test_df,
            active_label=active_label,
        ))
        test_report.append(metrics.copy())
        test_preds.append(torch.tensor(preds['val_pred']))
        best_models.append(model)
    test_report = pd.DataFrame(test_report)

    # Get the majority vote for the test predictions
    majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
    majority_vote_report = pd.DataFrame([majority_vote_metrics])
    majority_vote_report['model_type'] = 'XGBoost'

    # Add a column with the split_type to all reports
    for report in [cv_report, hparam_report, test_report, majority_vote_report]:
        report['split_type'] = split_type

    # Return the reports
    return {
        'cv_report': cv_report,
        'hparam_report': hparam_report,
        'test_report': test_report,
        'majority_vote_report': majority_vote_report,
    }
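A hedged usage sketch for the entry point defined above: the feature dictionaries and DataFrames are placeholders for whatever the calling script prepares (src/run_xgboost_experiments.py below does exactly that), and the study filename is hypothetical.

from sklearn.model_selection import StratifiedKFold
import protac_degradation_predictor as pdp

def run_xgboost_tuning(protein2embedding, cell2embedding, smiles2fp,
                       train_val_df, test_df, active_label='Active'):
    """Sketch: hand precomputed features and data frames to the new entry point."""
    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    reports = pdp.xgboost_hyperparameter_tuning_and_training(
        protein2embedding=protein2embedding,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        train_val_df=train_val_df,
        test_df=test_df,
        kf=kf,
        split_type='random',
        n_models_for_test=3,
        n_trials=50,
        active_label=active_label,
        num_boost_round=100,
        study_filename='xgboost_study_random.pkl',  # hypothetical cache path
    )
    # Each entry is a pandas DataFrame, ready to be written out as a report.
    return reports['cv_report'], reports['test_report'], reports['majority_vote_report']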
protac_degradation_predictor/protac_dataset.py
CHANGED
@@ -42,7 +42,11 @@ class PROTAC_Dataset(Dataset):
             cell2embedding (dict): Dictionary of cell line embeddings
             smiles2fp (dict): Dictionary of SMILES to fingerprint
             use_smote (bool): Whether to use SMOTE for oversampling
-
+            oversampler (SMOTE | ADASYN): The oversampler to use
+            active_label (str): The column containing the active/inactive information
+            disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
+            scaler (StandardScaler | dict): The scaler to use for the embeddings
+            use_single_scaler (bool): Whether to use a single scaler for all features
         """
         # Filter out examples with NaN in active_label column
         self.data = protac_df # [~protac_df[active_label].isna()]
@@ -124,7 +128,7 @@ class PROTAC_Dataset(Dataset):
         self.data = df_smote
 
     def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
-        """ Fit the scalers for the data.
+        """ Fit the scalers for the data and save them in the dataset class.
 
         Args:
             use_single_scaler (bool): Whether to use a single scaler for all features.
@@ -288,8 +292,25 @@ def get_datasets(
     disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
     scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
     use_single_scaler: Optional[bool] = None,
+    apply_scaling: bool = False,
 ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
-    """ Get the datasets for training the PROTAC model.
+    """ Get the datasets for training the PROTAC model.
+
+    Args:
+        train_df (pd.DataFrame): The training data.
+        val_df (pd.DataFrame): The validation data.
+        test_df (pd.DataFrame): The test data.
+        protein2embedding (dict): Dictionary of protein embeddings.
+        cell2embedding (dict): Dictionary of cell line embeddings.
+        smiles2fp (dict): Dictionary of SMILES to fingerprint.
+        use_smote (bool): Whether to use SMOTE for oversampling.
+        smote_k_neighbors (int): The number of neighbors to use for SMOTE.
+        active_label (str): The active label column.
+        disabled_embeddings (list): The list of embeddings to disable.
+        scaler (StandardScaler | dict): The scaler to use for the embeddings.
+        use_single_scaler (bool): Whether to use a single scaler for all features.
+        apply_scaling (bool): Whether to apply scaling to the data now. Defaults to False (the Pytorch Lightning model does that).
+    """
     oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
     train_ds = PROTAC_Dataset(
         train_df,
@@ -313,6 +334,10 @@
         scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
         use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
     )
+    train_scalers = None
+    if apply_scaling:
+        train_scalers = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
+        val_ds.apply_scaling(train_scalers, use_single_scaler=use_single_scaler)
     if test_df is not None:
         test_ds = PROTAC_Dataset(
             test_df,
@@ -321,9 +346,11 @@
             smiles2fp,
             active_label=active_label,
             disabled_embeddings=disabled_embeddings,
-            scaler=
+            scaler=train_scalers if apply_scaling else scaler,
             use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
         )
+        if apply_scaling:
+            test_ds.apply_scaling(train_ds.scaler, use_single_scaler=use_single_scaler)
     else:
         test_ds = None
     return train_ds, val_ds, test_ds
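The new `apply_scaling` flag lets non-Lightning consumers such as the XGBoost path obtain scaled NumPy arrays straight from the datasets: scalers are fit on the training split and reused for the validation and test splits. A sketch of the intended call, again with the embedding dictionaries and DataFrames as placeholders:

from protac_degradation_predictor.protac_dataset import get_datasets

def make_scaled_arrays(protein2embedding, cell2embedding, smiles2fp,
                       train_df, val_df, test_df=None):
    """Sketch: fit scalers on the training split and reuse them for val/test."""
    train_ds, val_ds, test_ds = get_datasets(
        protein2embedding=protein2embedding,
        cell2embedding=cell2embedding,
        smiles2fp=smiles2fp,
        train_df=train_df,
        val_df=val_df,
        test_df=test_df,
        active_label='Active',
        disabled_embeddings=[],
        apply_scaling=True,  # new in this commit: scale now instead of in the Lightning model
    )
    X_train, y_train = train_ds.get_numpy_arrays()
    X_val, y_val = val_ds.get_numpy_arrays()
    return (X_train, y_train), (X_val, y_val), test_ds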
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random.csv
ADDED
@@ -0,0 +1,29 @@
test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
0.7269228100776672,0.604651153087616,0.6730769276618958,0.546875,0.875,0.7173913717269897,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
0.6971672177314758,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6717391014099121,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
0.6542536020278931,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7141305208206177,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.689673900604248,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled e3,0.5074626803398132,random
0.7447491884231567,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.70923912525177,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
0.7114118933677673,0.604651153087616,0.5405405163764954,0.5882353186607361,0.5,0.6630434989929199,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
0.6734361052513123,0.6162790656089783,0.6373626589775085,0.5686274766921997,0.7250000238418579,0.6940217614173889,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
,0.5930232405662537,,0.5806451439857483,0.44999998807907104,0.6809782981872559,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi,0.5070422291755676,random
0.7288045883178711,0.6162790656089783,0.6796116232872009,0.5555555820465088,0.875,0.717663049697876,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
0.6981603503227234,0.6395348906517029,0.5866666436195374,0.6285714507102966,0.550000011920929,0.6709238886833191,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
0.6586534380912781,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7122282385826111,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
,0.6279069781303406,,0.6333333253860474,0.4749999940395355,0.688858687877655,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled cell,0.5428571701049805,random
0.7676423788070679,0.4651162922382355,0.6349206566810608,0.4651162922382355,1.0,0.7361413240432739,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
0.7521520256996155,0.5348837375640869,0.0,0.0,0.0,0.7638586759567261,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
0.7137073278427124,0.5930232405662537,0.2857142984867096,0.7777777910232544,0.17499999701976776,0.727989137172699,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
,0.5348837375640869,,0.0,0.0,0.7638587951660156,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled smiles,0.0,random
0.7207046151161194,0.6162790656089783,0.6796116232872009,0.5555555820465088,0.875,0.7160326242446899,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
0.6998258829116821,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6720108985900879,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
0.6533703207969666,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7122282385826111,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.688858687877655,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled e3 cell,0.5074626803398132,random
0.7362547516822815,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.710326075553894,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
0.7125736474990845,0.6162790656089783,0.5479452013969421,0.6060606241226196,0.5,0.6619565486907959,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
0.6676729321479797,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.6945651769638062,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
,0.6162790656089783,,0.6206896305084229,0.44999998807907104,0.6836956143379211,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi e3,0.52173912525177,random
0.7300900816917419,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.706793487071991,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
0.7153109908103943,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6611412763595581,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
0.6669936180114746,0.6279069781303406,0.6279069781303406,0.5869565010070801,0.675000011920929,0.6932065486907959,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.6834239363670349,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi e3 cell,0.5074626803398132,random
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_tanimoto.csv
ADDED
@@ -0,0 +1,29 @@
test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
0.8296061754226685,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7832207083702087,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
0.6474169492721558,0.6000000238418579,0.6600000262260437,0.523809552192688,0.8918918967247009,0.7668918371200562,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
0.6295721530914307,0.7529411911964417,0.7042253613471985,0.7352941036224365,0.6756756901741028,0.8141891956329346,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
,0.7529411911964417,,0.75,0.6486486196517944,0.8023648858070374,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled e3,0.695652186870575,tanimoto
0.8408050537109375,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7691441774368286,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
0.6602048277854919,0.5764706134796143,0.6470588445663452,0.5076923370361328,0.8918918967247009,0.7494369745254517,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
0.634836733341217,0.7411764860153198,0.6944444179534912,0.7142857313156128,0.6756756901741028,0.7849099636077881,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
,0.7411764860153198,,0.7272727489471436,0.6486486196517944,0.7770270109176636,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi,0.6857143044471741,tanimoto
0.835131824016571,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7736486196517944,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
0.6562066674232483,0.5882353186607361,0.6534653306007385,0.515625,0.8918918967247009,0.75,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
0.6323299407958984,0.729411780834198,0.6760563254356384,0.7058823704719543,0.6486486196517944,0.8001126050949097,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
,0.729411780834198,,0.71875,0.6216216087341309,0.7905405163764954,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled cell,0.6666666865348816,tanimoto
0.8332716226577759,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.798704981803894,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
0.765400767326355,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7919481992721558,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
0.6887043118476868,0.4941176474094391,0.632478654384613,0.4625000059604645,1.0,0.8110923171043396,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
,0.4941176474094391,,0.4625000059604645,1.0,0.8110923767089844,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled smiles,0.632478654384613,tanimoto
0.825886070728302,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7787162065505981,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
0.6474983096122742,0.6000000238418579,0.6600000262260437,0.523809552192688,0.8918918967247009,0.7567567825317383,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
0.6309086680412292,0.7411764860153198,0.6857143044471741,0.7272727489471436,0.6486486196517944,0.8119369149208069,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
,0.7411764860153198,,0.7419354915618896,0.6216216087341309,0.8006756901741028,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled e3 cell,0.6764705777168274,tanimoto
0.8314616680145264,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7697072625160217,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
0.651317298412323,0.6117647290229797,0.6666666865348816,0.5322580933570862,0.8918918967247009,0.7488738894462585,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
0.633421003818512,0.7529411911964417,0.7042253613471985,0.7352941036224365,0.6756756901741028,0.795045018196106,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
,0.7529411911964417,,0.75,0.6486486196517944,0.7837837934494019,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi e3,0.695652186870575,tanimoto
0.8277769088745117,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7629504203796387,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
0.6514514088630676,0.6000000238418579,0.6530612111091614,0.5245901346206665,0.8648648858070374,0.7438063025474548,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
0.6348393559455872,0.7411764860153198,0.6857143044471741,0.7272727489471436,0.6486486196517944,0.7837837934494019,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
,0.7411764860153198,,0.7419354915618896,0.6216216087341309,0.7742117047309875,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi e3 cell,0.6764705777168274,tanimoto
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_uniprot.csv
ADDED
@@ -0,0 +1,29 @@
test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
0.7041562795639038,0.4588235318660736,0.0,0.0,0.0,0.5156075954437256,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
0.6916469931602478,0.4588235318660736,0.0,0.0,0.0,0.4420289397239685,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
0.6960257887840271,0.4588235318660736,0.0,0.0,0.0,0.4303233027458191,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
,0.4588235318660736,,0.0,0.0,0.5156075954437256,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled e3,0.0,uniprot
0.7039564251899719,0.4588235318660736,0.0,0.0,0.0,0.532608687877655,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
0.6913965940475464,0.4588235318660736,0.0,0.0,0.0,0.46739131212234497,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
0.6957095265388489,0.4588235318660736,0.0,0.0,0.0,0.45234110951423645,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
,0.4588235318660736,,0.0,0.0,0.532608687877655,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi,0.0,uniprot
0.7036164402961731,0.4588235318660736,0.0,0.0,0.0,0.530379056930542,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
0.6914005875587463,0.4588235318660736,0.0,0.0,0.0,0.48188406229019165,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
0.695412814617157,0.4588235318660736,0.0,0.0,0.0,0.4763098955154419,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
,0.4588235318660736,,0.0,0.0,0.530379056930542,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled cell,0.0,uniprot
0.697465717792511,0.4588235318660736,0.0,0.0,0.0,0.6223523020744324,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
0.6916133761405945,0.4588235318660736,0.0,0.0,0.0,0.6636008620262146,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
0.6932395696640015,0.4588235318660736,0.0,0.0,0.0,0.651337742805481,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
,0.4588235318660736,,0.0,0.0,0.6223522424697876,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled smiles,0.0,uniprot
0.704821765422821,0.4588235318660736,0.0,0.0,0.0,0.518673300743103,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
0.6916972398757935,0.4588235318660736,0.0,0.0,0.0,0.45234113931655884,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
0.6962708830833435,0.4588235318660736,0.0,0.0,0.0,0.42892974615097046,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
,0.4588235318660736,,0.0,0.0,0.5186733603477478,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled e3 cell,0.0,uniprot
0.7051585912704468,0.4588235318660736,0.0,0.0,0.0,0.5103121399879456,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
0.6916910409927368,0.4588235318660736,0.0,0.0,0.0,0.44732439517974854,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
0.6965663433074951,0.4588235318660736,0.0,0.0,0.0,0.40328872203826904,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
,0.4588235318660736,,0.0,0.0,0.5103121399879456,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi e3,0.0,uniprot
0.7058382034301758,0.4588235318660736,0.0,0.0,0.0,0.5080825090408325,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
0.6917427778244019,0.4588235318660736,0.0,0.0,0.0,0.450111448764801,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
0.6968205571174622,0.4588235318660736,0.0,0.0,0.0,0.4155518114566803,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
,0.4588235318660736,,0.0,0.0,0.5080825090408325,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi e3 cell,0.0,uniprot
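The three ablation reports above share one schema, so they can be concatenated and summarized per disabled embedding and split type; a small pandas sketch using the file names as listed:

import glob
import pandas as pd

# Collect the three ablation reports added by this commit.
files = glob.glob('reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_*.csv')
df = pd.concat([pd.read_csv(f) for f in files], ignore_index=True)

# Mean test ROC AUC per ablated input and per split type,
# keeping single-model rows separate from the majority-vote rows.
summary = (df.groupby(['split_type', 'disabled_embeddings', 'majority_vote'])['test_roc_auc']
             .mean()
             .unstack('split_type'))
print(summary.round(3))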
src/run_xgboost_experiments.py
ADDED
@@ -0,0 +1,329 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from collections import defaultdict
|
4 |
+
import warnings
|
5 |
+
import logging
|
6 |
+
from typing import Literal
|
7 |
+
|
8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
9 |
+
|
10 |
+
import protac_degradation_predictor as pdp
|
11 |
+
from protac_degradation_predictor.optuna_utils import get_dataframe_stats
|
12 |
+
|
13 |
+
import pytorch_lightning as pl
|
14 |
+
from rdkit import Chem
|
15 |
+
from rdkit.Chem import AllChem
|
16 |
+
from rdkit import DataStructs
|
17 |
+
from jsonargparse import CLI
|
18 |
+
import pandas as pd
|
19 |
+
from tqdm import tqdm
|
20 |
+
import numpy as np
|
21 |
+
from sklearn.preprocessing import OrdinalEncoder
|
22 |
+
from sklearn.model_selection import (
|
23 |
+
StratifiedKFold,
|
24 |
+
StratifiedGroupKFold,
|
25 |
+
)
|
26 |
+
|
27 |
+
# Ignore UserWarning from Matplotlib
|
28 |
+
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
29 |
+
# Ignore UserWarning from PyTorch Lightning
|
30 |
+
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
31 |
+
|
32 |
+
|
33 |
+
root = logging.getLogger()
|
34 |
+
root.setLevel(logging.DEBUG)
|
35 |
+
|
36 |
+
handler = logging.StreamHandler(sys.stdout)
|
37 |
+
handler.setLevel(logging.DEBUG)
|
38 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
39 |
+
handler.setFormatter(formatter)
|
40 |
+
root.addHandler(handler)
|
41 |
+
|
42 |
+
|
43 |
+
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
|
44 |
+
""" Get the indices of the test set using a random split.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
48 |
+
test_split (float): The percentage of the active PROTACs to use as the test set.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
pd.Index: The indices of the test set.
|
52 |
+
"""
|
53 |
+
test_df = active_df.sample(frac=test_split, random_state=42)
|
54 |
+
return test_df.index
|
55 |
+
|
56 |
+
|
57 |
+
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
|
58 |
+
""" Get the indices of the test set using the E3 ligase split.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
pd.Index: The indices of the test set.
|
65 |
+
"""
|
66 |
+
encoder = OrdinalEncoder()
|
67 |
+
active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
|
68 |
+
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
|
69 |
+
return test_df.index
|
70 |
+
|
71 |
+
|
72 |
+
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
|
73 |
+
""" Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
|
80 |
+
"""
|
81 |
+
unique_smiles = protac_df['Smiles'].unique().tolist()
|
82 |
+
|
83 |
+
smiles2fp = {}
|
84 |
+
for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
|
85 |
+
smiles2fp[smiles] = pdp.get_fingerprint(smiles)
|
86 |
+
|
87 |
+
# # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
|
88 |
+
# tanimoto_matrix = defaultdict(list)
|
89 |
+
# for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
|
90 |
+
# fp1 = smiles2fp[smiles1]
|
91 |
+
# # TODO: Use BulkTanimotoSimilarity for better performance
|
92 |
+
# for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
|
93 |
+
# fp2 = smiles2fp[smiles2]
|
94 |
+
# tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
|
95 |
+
# tanimoto_matrix[smiles1].append(tanimoto_dist)
|
96 |
+
# avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
|
97 |
+
# protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
|
98 |
+
|
99 |
+
|
100 |
+
tanimoto_matrix = defaultdict(list)
|
101 |
+
fps = list(smiles2fp.values())
|
102 |
+
|
103 |
+
# Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
|
104 |
+
for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
|
105 |
+
similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates
|
106 |
+
for j, similarity in enumerate(similarities):
|
107 |
+
distance = 1 - similarity
|
108 |
+
tanimoto_matrix[smiles1].append(distance) # Store as distance
|
109 |
+
if i != i + j:
|
110 |
+
tanimoto_matrix[unique_smiles[i + j]].append(distance) # Symmetric filling
|
111 |
+
|
112 |
+
# Calculate average Tanimoto distance for each unique SMILES
|
113 |
+
avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
|
114 |
+
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
|
115 |
+
|
116 |
+
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
|
117 |
+
|
118 |
+
return smiles2fp, protac_df
|
119 |
+
|
120 |
+
|
121 |
+
def get_tanimoto_split_indices(
        active_df: pd.DataFrame,
        active_col: str,
        test_split: float,
        n_bins_tanimoto: int = 200,
) -> pd.Index:
    """ Get the indices of the test set using the Tanimoto-based split.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
        active_col (str): The column containing the active/inactive information.
        test_split (float): The percentage of the active PROTACs to use as the test set.
        n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity.

    Returns:
        pd.Index: The indices of the test set.
    """
    tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
    encoder = OrdinalEncoder()
    active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
    # Sort the groups so that samples with the highest average Tanimoto distance,
    # i.e., the least similar ones, are placed in the test set first
    tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index

    test_df = []
    # For each group, get the number of active and inactive entries. Then, add those
    # entries to test_df if: 1) the length of test_df plus the group entries stays below
    # test_split of the length of active_df, and 2) the percentage of True and False
    # entries in the active_col of test_df stays roughly balanced (around 50%).
    for group in tanimoto_groups:
        group_df = active_df[active_df['Tanimoto Group'] == group]
        if test_df == []:
            test_df.append(group_df)
            continue

        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()
        num_inactive_group = num_entries - num_active_group

        tmp_test_df = pd.concat(test_df)
        num_entries_test = len(tmp_test_df)
        num_active_test = tmp_test_df[active_col].sum()
        num_inactive_test = num_entries_test - num_active_test

        # Check if the group entries can be added to the test_df
        if num_entries_test + num_entries < test_split * len(active_df):
            # Add anything at the beginning
            if num_entries_test + num_entries < test_split / 2 * len(active_df):
                test_df.append(group_df)
                continue
            # Be more selective and make sure that the percentage of active and
            # inactive entries stays balanced
            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                    test_df.append(group_df)
    test_df = pd.concat(test_df)
    return test_df.index

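# For illustration, a minimal sketch of the binning step used above: pd.cut buckets the
# continuous 'Avg Tanimoto' values into intervals, and OrdinalEncoder turns each interval
# into an integer group label that can later be used as a CV group. The toy values and
# the bin count here are arbitrary; interval labels are cast to strings before encoding
# to stay version-agnostic.
def _tanimoto_binning_sketch() -> pd.DataFrame:
    toy_df = pd.DataFrame({'Avg Tanimoto': [0.10, 0.12, 0.35, 0.40, 0.80, 0.82]})
    bins = pd.cut(toy_df['Avg Tanimoto'], bins=3)
    encoder = OrdinalEncoder()
    # Encode the interval labels into integer group ids
    toy_df['Tanimoto Group'] = encoder.fit_transform(
        bins.astype(str).to_numpy().reshape(-1, 1)
    ).ravel().astype(int)
    return toy_df
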
def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
    """ Get the indices of the test set using the target-based split.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
        active_col (str): The column containing the active/inactive information.
        test_split (float): The percentage of the active PROTACs to use as the test set.

    Returns:
        pd.Index: The indices of the test set.
    """
    encoder = OrdinalEncoder()
    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)

    test_df = []
    # For each group, get the number of active and inactive entries. Then, add those
    # entries to test_df if: 1) the length of test_df plus the group entries stays below
    # test_split of the length of active_df, and 2) the percentage of True and False
    # entries in the active_col of test_df stays roughly balanced (around 50%).
    # Start the loop from the groups containing the smallest number of entries.
    for group in reversed(active_df['Uniprot'].value_counts().index):
        group_df = active_df[active_df['Uniprot'] == group]
        if test_df == []:
            test_df.append(group_df)
            continue

        num_entries = len(group_df)
        num_active_group = group_df[active_col].sum()
        num_inactive_group = num_entries - num_active_group

        tmp_test_df = pd.concat(test_df)
        num_entries_test = len(tmp_test_df)
        num_active_test = tmp_test_df[active_col].sum()
        num_inactive_test = num_entries_test - num_active_test

        # Check if the group entries can be added to the test_df
        if num_entries_test + num_entries < test_split * len(active_df):
            # Add anything at the beginning
            if num_entries_test + num_entries < test_split / 2 * len(active_df):
                test_df.append(group_df)
                continue
            # Be more selective and make sure that the percentage of active and
            # inactive entries stays balanced
            if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
                if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
                    test_df.append(group_df)
    test_df = pd.concat(test_df)
    return test_df.index

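# The main() below pairs StratifiedGroupKFold with the 'Tanimoto Group' / 'Uniprot Group'
# columns created above. A minimal sketch of what that buys us: every row sharing a group
# label stays in the same fold, so validation folds only contain unseen groups, while the
# class balance is approximately preserved. The toy labels and groups here are arbitrary.
def _grouped_cv_sketch():
    import numpy as np
    from sklearn.model_selection import StratifiedGroupKFold

    y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
    groups = np.array(['P1', 'P1', 'P2', 'P2', 'P3', 'P3', 'P4', 'P4'])
    X = np.zeros((len(y), 1))  # Features do not affect the split itself
    kf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=42)
    for fold, (train_idx, val_idx) in enumerate(kf.split(X, y, groups)):
        print(f'Fold {fold}: validation groups = {sorted(set(groups[val_idx]))}')
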
def main(
    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
    n_trials: int = 100,
    test_split: float = 0.1,
    cv_n_splits: int = 5,
    num_boost_round: int = 100,
    force_study: bool = False,
    experiments: str | Literal['all', 'random', 'e3_ligase', 'tanimoto', 'uniprot'] = 'all',
):
    """ Train XGBoost-based PROTAC models using the given datasets and hyperparameters.

    Args:
        active_col (str): The column containing the active/inactive information.
        n_trials (int): The number of hyperparameter optimization trials.
        test_split (float): The percentage of the data to use as the test set.
        cv_n_splits (int): The number of cross-validation splits.
        num_boost_round (int): The number of XGBoost boosting rounds.
        force_study (bool): Whether to force re-running the Optuna study.
        experiments (str): Which experiment(s), i.e., test split type(s), to run.
    """
    pl.seed_everything(42)

    # Set the column to predict
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    # Get the Dmax and pDC50 thresholds from the active_col name
    Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
    pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())

    # Load the PROTAC dataset
    protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
    # Map the E3 ligase Iap to IAP
    protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
    protac_df[active_col] = protac_df.apply(
        lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
    )
    smiles2fp, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)

    ## Get the test sets
    test_indeces = {}
    active_df = protac_df[protac_df[active_col].notna()].copy()

    if experiments == 'random' or experiments == 'all':
        test_indeces['random'] = get_random_split_indices(active_df, test_split)
    if experiments == 'uniprot' or experiments == 'all':
        test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
    if experiments == 'e3_ligase' or experiments == 'all':
        test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
    if experiments == 'tanimoto' or experiments == 'all':
        test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)

    # Make the directory ../reports if it does not exist
    if not os.path.exists('../reports'):
        os.makedirs('../reports')

    # Load the embedding dictionaries
    protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')

    # Cross-validation training
    reports = defaultdict(list)
    for split_type, indeces in test_indeces.items():
        test_df = active_df.loc[indeces].copy()
        train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()

        # Get the CV object
        if split_type == 'random':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = None
        elif split_type == 'e3_ligase':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['E3 Group'].to_numpy()
        elif split_type == 'tanimoto':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Tanimoto Group'].to_numpy()
        elif split_type == 'uniprot':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Uniprot Group'].to_numpy()

        # Start the experiment
        experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
        optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            test_df=test_df,
            kf=kf,
            groups=group,
            split_type=split_type,
            n_models_for_test=3,
            n_trials=n_trials,
            active_label=active_col,
            num_boost_round=num_boost_round,
            study_filename=f'../reports/study_xgboost_{experiment_name}.pkl',
            force_study=force_study,
        )

        # Save the reports to file
        for report_name, report in optuna_reports.items():
            report.to_csv(f'../reports/xgboost_{report_name}_{experiment_name}.csv', index=False)
            reports[report_name].append(report.copy())


if __name__ == '__main__':
    cli = CLI(main)
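# Example invocations, assuming the CLI wrapper exposes main()'s keyword arguments as
# command-line flags (jsonargparse-style); the exact flag syntax depends on the CLI
# backend actually imported in this script:
#
#   python run_xgboost_experiments.py --n_trials 50 --experiments tanimoto
#   python run_xgboost_experiments.py --experiments all --num_boost_round 200 --force_study true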