ribesstefano commited on
Commit
33f1644
1 Parent(s): 4bf0ec2

Added XGBoost Optuna training + Added Ablation studies with zeroed input vectors

Browse files
protac_degradation_predictor/__init__.py CHANGED
@@ -17,7 +17,9 @@ from .sklearn_models import (
17
  )
18
  from .optuna_utils import (
19
  hyperparameter_tuning_and_training,
20
- hyperparameter_tuning_and_training_sklearn,
 
 
21
  )
22
  from .protac_degradation_predictor import (
23
  get_protac_active_proba,
 
17
  )
18
  from .optuna_utils import (
19
  hyperparameter_tuning_and_training,
20
+ )
21
+ from .optuna_utils_xgboost import (
22
+ xgboost_hyperparameter_tuning_and_training,
23
  )
24
  from .protac_degradation_predictor import (
25
  get_protac_active_proba,
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -234,6 +234,18 @@ def pytorch_model_objective(
234
 
235
  # Optuna aims to minimize the pytorch_model_objective
236
  return - val_roc_auc
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
 
239
  def hyperparameter_tuning_and_training(
 
234
 
235
  # Optuna aims to minimize the pytorch_model_objective
236
  return - val_roc_auc
237
+ # # Get the majority vote for the test predictions
238
+ # if test_df is not None and not fast_dev_run:
239
+ # majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
240
+ # majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
241
+ # trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
242
+ # logging.info(f'Majority vote metrics: {majority_vote_metrics}')
243
+
244
+ # # Get the average validation accuracy and ROC AUC accross the folds
245
+ # val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
246
+
247
+ # # Optuna aims to minimize the pytorch_model_objective
248
+ # return - val_roc_auc
249
 
250
 
251
  def hyperparameter_tuning_and_training(
protac_degradation_predictor/optuna_utils_xgboost.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict
2
+ import logging
3
+ import os
4
+
5
+ from .optuna_utils import get_majority_vote_metrics, get_dataframe_stats
6
+ from .protac_dataset import get_datasets
7
+
8
+ import optuna
9
+ import xgboost as xgb
10
+ import pandas as pd
11
+ import numpy as np
12
+ from sklearn.model_selection import StratifiedKFold
13
+ from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
14
+ import xgboost as xgb
15
+ import pandas as pd
16
+ import numpy as np
17
+ from sklearn.model_selection import StratifiedKFold
18
+ from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
19
+ import joblib
20
+ from optuna.samplers import TPESampler
21
+ import torch
22
+
23
+
24
+ xgb.set_config(verbosity=0)
25
+
26
+
27
+ def train_and_evaluate_xgboost(
28
+ protein2embedding: Dict,
29
+ cell2embedding: Dict,
30
+ smiles2fp: Dict,
31
+ train_df: pd.DataFrame,
32
+ val_df: pd.DataFrame,
33
+ params: dict,
34
+ test_df: Optional[pd.DataFrame] = None,
35
+ active_label: str = 'Active',
36
+ num_boost_round: int = 100,
37
+ shuffle_train_data: bool = False,
38
+ ) -> tuple:
39
+ """
40
+ Train and evaluate an XGBoost model with the given parameters.
41
+
42
+ Args:
43
+ train_df (pd.DataFrame): The training and validation data.
44
+ test_df (pd.DataFrame): The test data.
45
+ params (dict): Hyperparameters for the XGBoost model.
46
+ active_label (str): The active label column.
47
+ num_boost_round (int): Maximum number of epochs.
48
+
49
+ Returns:
50
+ tuple: The trained model, test predictions, and metrics.
51
+ """
52
+ # Get datasets and their numpy arrays
53
+ train_ds, val_ds, test_ds = get_datasets(
54
+ protein2embedding=protein2embedding,
55
+ cell2embedding=cell2embedding,
56
+ smiles2fp=smiles2fp,
57
+ train_df=train_df,
58
+ val_df=val_df,
59
+ test_df=test_df,
60
+ disabled_embeddings=[],
61
+ active_label=active_label,
62
+ apply_scaling=False,
63
+ )
64
+ X_train, y_train = train_ds.get_numpy_arrays()
65
+ X_val, y_val = val_ds.get_numpy_arrays()
66
+
67
+ # Shuffle the training data
68
+ if shuffle_train_data:
69
+ idx = np.random.permutation(len(X_train))
70
+ X_train, y_train = X_train[idx], y_train[idx]
71
+
72
+ # Setup training and validation data in XGBoost data format
73
+ dtrain = xgb.DMatrix(X_train, label=y_train)
74
+ dval = xgb.DMatrix(X_val, label=y_val)
75
+ evallist = [(dval, 'eval'), (dtrain, 'train')]
76
+
77
+ # Setup test data
78
+ if test_df is not None:
79
+ X_test, y_test = test_ds.get_numpy_arrays()
80
+ dtest = xgb.DMatrix(X_test, label=y_test)
81
+ evallist.append((dtest, 'test'))
82
+
83
+ model = xgb.train(
84
+ params,
85
+ dtrain,
86
+ num_boost_round=num_boost_round,
87
+ evals=evallist,
88
+ early_stopping_rounds=10,
89
+ verbose_eval=False,
90
+ )
91
+
92
+ # Evaluate model
93
+ val_pred = model.predict(dval)
94
+ val_pred_binary = (val_pred > 0.5).astype(int)
95
+ metrics = {
96
+ 'val_accuracy': accuracy_score(y_val, val_pred_binary),
97
+ 'val_roc_auc': roc_auc_score(y_val, val_pred),
98
+ 'val_precision': precision_score(y_val, val_pred_binary),
99
+ 'val_recall': recall_score(y_val, val_pred_binary),
100
+ 'val_f1_score': f1_score(y_val, val_pred_binary),
101
+ }
102
+ preds = {'val_pred': val_pred}
103
+
104
+ if test_df is not None:
105
+ test_pred = model.predict(dtest)
106
+ test_pred_binary = (test_pred > 0.5).astype(int)
107
+ metrics.update({
108
+ 'test_accuracy': accuracy_score(y_test, test_pred_binary),
109
+ 'test_roc_auc': roc_auc_score(y_test, test_pred),
110
+ 'test_precision': precision_score(y_test, test_pred_binary),
111
+ 'test_recall': recall_score(y_test, test_pred_binary),
112
+ 'test_f1_score': f1_score(y_test, test_pred_binary),
113
+ })
114
+ preds.update({'test_pred': test_pred})
115
+
116
+ return model, preds, metrics
117
+
118
+
119
+ def xgboost_model_objective(
120
+ trial: optuna.Trial,
121
+ protein2embedding: Dict,
122
+ cell2embedding: Dict,
123
+ smiles2fp: Dict,
124
+ train_val_df: pd.DataFrame,
125
+ kf: StratifiedKFold,
126
+ groups: Optional[np.array] = None,
127
+ active_label: str = 'Active',
128
+ num_boost_round: int = 100,
129
+ ) -> float:
130
+ """ Objective function for hyperparameter optimization with XGBoost.
131
+
132
+ Args:
133
+ trial (optuna.Trial): The Optuna trial object.
134
+ train_val_df (pd.DataFrame): The training and validation data.
135
+ kf (StratifiedKFold): Stratified K-Folds cross-validator.
136
+ test_df (Optional[pd.DataFrame]): The test data.
137
+ active_label (str): The active label column.
138
+ num_boost_round (int): Maximum number of epochs.
139
+ use_logger (bool): Whether to use logging.
140
+ """
141
+ # Suggest hyperparameters to be used across the CV folds
142
+ params = {
143
+ 'booster': 'gbtree',
144
+ 'tree_method': 'hist', # if torch.cuda.is_available() else 'hist',
145
+ 'objective': 'binary:logistic',
146
+ 'eval_metric': 'auc',
147
+ 'eta': trial.suggest_float('eta', 1e-4, 1e-1, log=True),
148
+ 'max_depth': trial.suggest_int('max_depth', 3, 10),
149
+ 'min_child_weight': trial.suggest_float('min_child_weight', 1e-3, 10.0, log=True),
150
+ 'gamma': trial.suggest_float('gamma', 1e-4, 1e-1, log=True),
151
+ 'subsample': trial.suggest_float('subsample', 0.5, 1.0),
152
+ 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
153
+ }
154
+
155
+ X = train_val_df.copy().drop(columns=active_label)
156
+ y = train_val_df[active_label].tolist()
157
+ report = []
158
+ val_preds = []
159
+
160
+ for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
161
+ logging.info(f'Fold {k + 1}/{kf.get_n_splits()}')
162
+ train_df = train_val_df.iloc[train_index]
163
+ val_df = train_val_df.iloc[val_index]
164
+
165
+ # Get some statistics from the dataframes
166
+ stats = {
167
+ 'model_type': 'XGBoost',
168
+ 'fold': k,
169
+ 'train_len': len(train_df),
170
+ 'val_len': len(val_df),
171
+ 'train_perc': len(train_df) / len(train_val_df),
172
+ 'val_perc': len(val_df) / len(train_val_df),
173
+ }
174
+ stats.update(get_dataframe_stats(train_df, val_df, active_label=active_label))
175
+ if groups is not None:
176
+ stats['train_unique_groups'] = len(np.unique(groups[train_index]))
177
+ stats['val_unique_groups'] = len(np.unique(groups[val_index]))
178
+
179
+ _, preds, metrics = train_and_evaluate_xgboost(
180
+ protein2embedding=protein2embedding,
181
+ cell2embedding=cell2embedding,
182
+ smiles2fp=smiles2fp,
183
+ train_df=train_df,
184
+ val_df=val_df,
185
+ params=params,
186
+ active_label=active_label,
187
+ num_boost_round=num_boost_round,
188
+ )
189
+ stats.update(metrics)
190
+ report.append(stats.copy())
191
+ val_preds.append(preds['val_pred'])
192
+
193
+ # Save the report in the trial
194
+ trial.set_user_attr('report', report)
195
+ trial.set_user_attr('val_preds', val_preds)
196
+ trial.set_user_attr('params', params)
197
+
198
+ # Get the average validation metrics across the folds
199
+ mean_val_roc_auc = np.mean([r['val_roc_auc'] for r in report])
200
+ logging.info(f'\tMean val ROC AUC: {mean_val_roc_auc:.4f}')
201
+
202
+ # Optuna aims to minimize the objective, so return the negative ROC AUC
203
+ return -mean_val_roc_auc
204
+
205
+
206
+ def xgboost_hyperparameter_tuning_and_training(
207
+ protein2embedding: Dict,
208
+ cell2embedding: Dict,
209
+ smiles2fp: Dict,
210
+ train_val_df: pd.DataFrame,
211
+ test_df: pd.DataFrame,
212
+ kf: StratifiedKFold,
213
+ groups: Optional[np.array] = None,
214
+ split_type: str = 'random',
215
+ n_models_for_test: int = 3,
216
+ n_trials: int = 50,
217
+ active_label: str = 'Active',
218
+ num_boost_round: int = 100,
219
+ study_filename: Optional[str] = None,
220
+ force_study: bool = False,
221
+ ) -> dict:
222
+ """ Hyperparameter tuning and training of an XGBoost model.
223
+
224
+ Args:
225
+ train_val_df (pd.DataFrame): The training and validation data.
226
+ test_df (pd.DataFrame): The test data.
227
+ kf (StratifiedKFold): Stratified K-Folds cross-validator.
228
+ groups (Optional[np.array]): Group labels for the samples used while splitting the dataset into train/test set.
229
+ split_type (str): Type of the data split.
230
+ n_models_for_test (int): Number of models to train for testing.
231
+ fast_dev_run (bool): Whether to run a fast development run.
232
+ n_trials (int): Number of trials for hyperparameter optimization.
233
+ logger_save_dir (str): Directory to save logs.
234
+ logger_name (str): Name of the logger.
235
+ active_label (str): The active label column.
236
+ num_boost_round (int): Maximum number of epochs.
237
+ study_filename (Optional[str]): File name to save/load the Optuna study.
238
+ force_study (bool): Whether to force the study optimization even if the study file exists.
239
+
240
+ Returns:
241
+ dict: A dictionary containing reports from the CV and test.
242
+ """
243
+ # Set the verbosity of Optuna
244
+ optuna.logging.set_verbosity(optuna.logging.WARNING)
245
+
246
+ # Create an Optuna study object
247
+ sampler = TPESampler(seed=42)
248
+ study = optuna.create_study(direction='minimize', sampler=sampler)
249
+
250
+ study_loaded = False
251
+ if study_filename and not force_study:
252
+ if os.path.exists(study_filename):
253
+ study = joblib.load(study_filename)
254
+ study_loaded = True
255
+ logging.info(f'Loaded study from {study_filename}')
256
+
257
+ if not study_loaded or force_study:
258
+ study.optimize(
259
+ lambda trial: xgboost_model_objective(
260
+ trial=trial,
261
+ protein2embedding=protein2embedding,
262
+ cell2embedding=cell2embedding,
263
+ smiles2fp=smiles2fp,
264
+ train_val_df=train_val_df,
265
+ kf=kf,
266
+ groups=groups,
267
+ active_label=active_label,
268
+ num_boost_round=num_boost_round,
269
+ ),
270
+ n_trials=n_trials,
271
+ )
272
+ if study_filename:
273
+ joblib.dump(study, study_filename)
274
+
275
+ cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
276
+ hparam_report = pd.DataFrame([study.best_params])
277
+
278
+ # Retrain N models with the best hyperparameters (measure model uncertainty)
279
+ best_models = []
280
+ test_report = []
281
+ test_preds = []
282
+ for i in range(n_models_for_test):
283
+ logging.info(f'Training best model {i + 1}/{n_models_for_test}')
284
+ model, preds, metrics = train_and_evaluate_xgboost(
285
+ protein2embedding=protein2embedding,
286
+ cell2embedding=cell2embedding,
287
+ smiles2fp=smiles2fp,
288
+ train_df=train_val_df,
289
+ val_df=test_df,
290
+ params=study.best_trial.user_attrs['params'],
291
+ active_label=active_label,
292
+ num_boost_round=num_boost_round,
293
+ shuffle_train_data=True,
294
+ )
295
+ metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
296
+ metrics['model_type'] = 'XGBoost'
297
+ metrics['test_model_id'] = i
298
+ metrics.update(get_dataframe_stats(
299
+ train_val_df,
300
+ test_df=test_df,
301
+ active_label=active_label,
302
+ ))
303
+ test_report.append(metrics.copy())
304
+ test_preds.append(torch.tensor(preds['val_pred']))
305
+ best_models.append(model)
306
+ test_report = pd.DataFrame(test_report)
307
+
308
+ # Get the majority vote for the test predictions
309
+ majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
310
+ majority_vote_report = pd.DataFrame([majority_vote_metrics])
311
+ majority_vote_report['model_type'] = 'XGBoost'
312
+
313
+ # Add a column with the split_type to all reports
314
+ for report in [cv_report, hparam_report, test_report, majority_vote_report]:
315
+ report['split_type'] = split_type
316
+
317
+ # Return the reports
318
+ return {
319
+ 'cv_report': cv_report,
320
+ 'hparam_report': hparam_report,
321
+ 'test_report': test_report,
322
+ 'majority_vote_report' :majority_vote_report,
323
+ }
protac_degradation_predictor/protac_dataset.py CHANGED
@@ -42,7 +42,11 @@ class PROTAC_Dataset(Dataset):
42
  cell2embedding (dict): Dictionary of cell line embeddings
43
  smiles2fp (dict): Dictionary of SMILES to fingerprint
44
  use_smote (bool): Whether to use SMOTE for oversampling
45
- use_ored_activity (bool): Whether to use the 'Active - OR' column
 
 
 
 
46
  """
47
  # Filter out examples with NaN in active_label column
48
  self.data = protac_df # [~protac_df[active_label].isna()]
@@ -124,7 +128,7 @@ class PROTAC_Dataset(Dataset):
124
  self.data = df_smote
125
 
126
  def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
127
- """ Fit the scalers for the data.
128
 
129
  Args:
130
  use_single_scaler (bool): Whether to use a single scaler for all features.
@@ -288,8 +292,25 @@ def get_datasets(
288
  disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
289
  scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
290
  use_single_scaler: Optional[bool] = None,
 
291
  ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
292
- """ Get the datasets for training the PROTAC model. """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
294
  train_ds = PROTAC_Dataset(
295
  train_df,
@@ -313,6 +334,10 @@ def get_datasets(
313
  scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
314
  use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
315
  )
 
 
 
 
316
  if test_df is not None:
317
  test_ds = PROTAC_Dataset(
318
  test_df,
@@ -321,9 +346,11 @@ def get_datasets(
321
  smiles2fp,
322
  active_label=active_label,
323
  disabled_embeddings=disabled_embeddings,
324
- scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
325
  use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
326
  )
 
 
327
  else:
328
  test_ds = None
329
  return train_ds, val_ds, test_ds
 
42
  cell2embedding (dict): Dictionary of cell line embeddings
43
  smiles2fp (dict): Dictionary of SMILES to fingerprint
44
  use_smote (bool): Whether to use SMOTE for oversampling
45
+ oversampler (SMOTE | ADASYN): The oversampler to use
46
+ active_label (str): The column containing the active/inactive information
47
+ disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
48
+ scaler (StandardScaler | dict): The scaler to use for the embeddings
49
+ use_single_scaler (bool): Whether to use a single scaler for all features
50
  """
51
  # Filter out examples with NaN in active_label column
52
  self.data = protac_df # [~protac_df[active_label].isna()]
 
128
  self.data = df_smote
129
 
130
  def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
131
+ """ Fit the scalers for the data and save them in the dataset class.
132
 
133
  Args:
134
  use_single_scaler (bool): Whether to use a single scaler for all features.
 
292
  disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
293
  scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
294
  use_single_scaler: Optional[bool] = None,
295
+ apply_scaling: bool = False,
296
  ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
297
+ """ Get the datasets for training the PROTAC model.
298
+
299
+ Args:
300
+ train_df (pd.DataFrame): The training data.
301
+ val_df (pd.DataFrame): The validation data.
302
+ test_df (pd.DataFrame): The test data.
303
+ protein2embedding (dict): Dictionary of protein embeddings.
304
+ cell2embedding (dict): Dictionary of cell line embeddings.
305
+ smiles2fp (dict): Dictionary of SMILES to fingerprint.
306
+ use_smote (bool): Whether to use SMOTE for oversampling.
307
+ smote_k_neighbors (int): The number of neighbors to use for SMOTE.
308
+ active_label (str): The active label column.
309
+ disabled_embeddings (list): The list of embeddings to disable.
310
+ scaler (StandardScaler | dict): The scaler to use for the embeddings.
311
+ use_single_scaler (bool): Whether to use a single scaler for all features.
312
+ apply_scaling (bool): Whether to apply scaling to the data now. Defaults to False (the Pytorch Lightning model does that).
313
+ """
314
  oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
315
  train_ds = PROTAC_Dataset(
316
  train_df,
 
334
  scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
335
  use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
336
  )
337
+ train_scalers = None
338
+ if apply_scaling:
339
+ train_scalers = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
340
+ val_ds.apply_scaling(train_scalers, use_single_scaler=use_single_scaler)
341
  if test_df is not None:
342
  test_ds = PROTAC_Dataset(
343
  test_df,
 
346
  smiles2fp,
347
  active_label=active_label,
348
  disabled_embeddings=disabled_embeddings,
349
+ scaler=train_scalers if apply_scaling else scaler,
350
  use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
351
  )
352
+ if apply_scaling:
353
+ test_ds.apply_scaling(train_ds.scaler, use_single_scaler=use_single_scaler)
354
  else:
355
  test_ds = None
356
  return train_ds, val_ds, test_ds
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_random.csv ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
2
+ 0.7269228100776672,0.604651153087616,0.6730769276618958,0.546875,0.875,0.7173913717269897,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
3
+ 0.6971672177314758,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6717391014099121,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
4
+ 0.6542536020278931,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7141305208206177,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3,,random
5
+ ,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.689673900604248,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled e3,0.5074626803398132,random
6
+ 0.7447491884231567,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.70923912525177,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
7
+ 0.7114118933677673,0.604651153087616,0.5405405163764954,0.5882353186607361,0.5,0.6630434989929199,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
8
+ 0.6734361052513123,0.6162790656089783,0.6373626589775085,0.5686274766921997,0.7250000238418579,0.6940217614173889,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi,,random
9
+ ,0.5930232405662537,,0.5806451439857483,0.44999998807907104,0.6809782981872559,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi,0.5070422291755676,random
10
+ 0.7288045883178711,0.6162790656089783,0.6796116232872009,0.5555555820465088,0.875,0.717663049697876,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
11
+ 0.6981603503227234,0.6395348906517029,0.5866666436195374,0.6285714507102966,0.550000011920929,0.6709238886833191,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
12
+ 0.6586534380912781,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7122282385826111,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled cell,,random
13
+ ,0.6279069781303406,,0.6333333253860474,0.4749999940395355,0.688858687877655,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled cell,0.5428571701049805,random
14
+ 0.7676423788070679,0.4651162922382355,0.6349206566810608,0.4651162922382355,1.0,0.7361413240432739,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
15
+ 0.7521520256996155,0.5348837375640869,0.0,0.0,0.0,0.7638586759567261,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
16
+ 0.7137073278427124,0.5930232405662537,0.2857142984867096,0.7777777910232544,0.17499999701976776,0.727989137172699,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled smiles,,random
17
+ ,0.5348837375640869,,0.0,0.0,0.7638587951660156,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled smiles,0.0,random
18
+ 0.7207046151161194,0.6162790656089783,0.6796116232872009,0.5555555820465088,0.875,0.7160326242446899,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
19
+ 0.6998258829116821,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6720108985900879,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
20
+ 0.6533703207969666,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.7122282385826111,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled e3 cell,,random
21
+ ,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.688858687877655,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled e3 cell,0.5074626803398132,random
22
+ 0.7362547516822815,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.710326075553894,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
23
+ 0.7125736474990845,0.6162790656089783,0.5479452013969421,0.6060606241226196,0.5,0.6619565486907959,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
24
+ 0.6676729321479797,0.6395348906517029,0.6436781883239746,0.5957446694374084,0.699999988079071,0.6945651769638062,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3,,random
25
+ ,0.6162790656089783,,0.6206896305084229,0.44999998807907104,0.6836956143379211,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi e3,0.52173912525177,random
26
+ 0.7300900816917419,0.5930232405662537,0.6534653306007385,0.5409836173057556,0.824999988079071,0.706793487071991,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
27
+ 0.7153109908103943,0.6162790656089783,0.5352112650871277,0.6129032373428345,0.4749999940395355,0.6611412763595581,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
28
+ 0.6669936180114746,0.6279069781303406,0.6279069781303406,0.5869565010070801,0.675000011920929,0.6932065486907959,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,False,Pytorch,disabled poi e3 cell,,random
29
+ ,0.6162790656089783,,0.6296296119689941,0.42500001192092896,0.6834239363670349,771,0.5149156939040207,0.48508430609597925,0.3768059369269877,86,0.46511627906976744,0.5348837209302325,0.38114673659326254,34,44,0.8326848249027238,0.10246433203631647,True,Pytorch,disabled poi e3 cell,0.5074626803398132,random
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_tanimoto.csv ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
2
+ 0.8296061754226685,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7832207083702087,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
3
+ 0.6474169492721558,0.6000000238418579,0.6600000262260437,0.523809552192688,0.8918918967247009,0.7668918371200562,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
4
+ 0.6295721530914307,0.7529411911964417,0.7042253613471985,0.7352941036224365,0.6756756901741028,0.8141891956329346,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3,,tanimoto
5
+ ,0.7529411911964417,,0.75,0.6486486196517944,0.8023648858070374,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled e3,0.695652186870575,tanimoto
6
+ 0.8408050537109375,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7691441774368286,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
7
+ 0.6602048277854919,0.5764706134796143,0.6470588445663452,0.5076923370361328,0.8918918967247009,0.7494369745254517,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
8
+ 0.634836733341217,0.7411764860153198,0.6944444179534912,0.7142857313156128,0.6756756901741028,0.7849099636077881,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi,,tanimoto
9
+ ,0.7411764860153198,,0.7272727489471436,0.6486486196517944,0.7770270109176636,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi,0.6857143044471741,tanimoto
10
+ 0.835131824016571,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7736486196517944,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
11
+ 0.6562066674232483,0.5882353186607361,0.6534653306007385,0.515625,0.8918918967247009,0.75,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
12
+ 0.6323299407958984,0.729411780834198,0.6760563254356384,0.7058823704719543,0.6486486196517944,0.8001126050949097,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled cell,,tanimoto
13
+ ,0.729411780834198,,0.71875,0.6216216087341309,0.7905405163764954,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled cell,0.6666666865348816,tanimoto
14
+ 0.8332716226577759,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.798704981803894,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
15
+ 0.765400767326355,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7919481992721558,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
16
+ 0.6887043118476868,0.4941176474094391,0.632478654384613,0.4625000059604645,1.0,0.8110923171043396,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled smiles,,tanimoto
17
+ ,0.4941176474094391,,0.4625000059604645,1.0,0.8110923767089844,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled smiles,0.632478654384613,tanimoto
18
+ 0.825886070728302,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7787162065505981,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
19
+ 0.6474983096122742,0.6000000238418579,0.6600000262260437,0.523809552192688,0.8918918967247009,0.7567567825317383,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
20
+ 0.6309086680412292,0.7411764860153198,0.6857143044471741,0.7272727489471436,0.6486486196517944,0.8119369149208069,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled e3 cell,,tanimoto
21
+ ,0.7411764860153198,,0.7419354915618896,0.6216216087341309,0.8006756901741028,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled e3 cell,0.6764705777168274,tanimoto
22
+ 0.8314616680145264,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7697072625160217,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
23
+ 0.651317298412323,0.6117647290229797,0.6666666865348816,0.5322580933570862,0.8918918967247009,0.7488738894462585,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
24
+ 0.633421003818512,0.7529411911964417,0.7042253613471985,0.7352941036224365,0.6756756901741028,0.795045018196106,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3,,tanimoto
25
+ ,0.7529411911964417,,0.75,0.6486486196517944,0.7837837934494019,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi e3,0.695652186870575,tanimoto
26
+ 0.8277769088745117,0.43529412150382996,0.6065573692321777,0.43529412150382996,1.0,0.7629504203796387,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
27
+ 0.6514514088630676,0.6000000238418579,0.6530612111091614,0.5245901346206665,0.8648648858070374,0.7438063025474548,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
28
+ 0.6348393559455872,0.7411764860153198,0.6857143044471741,0.7272727489471436,0.6486486196517944,0.7837837934494019,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,False,Pytorch,disabled poi e3 cell,,tanimoto
29
+ ,0.7411764860153198,,0.7419354915618896,0.6216216087341309,0.7742117047309875,772,0.5181347150259067,0.48186528497409326,0.37254018872057115,85,0.43529411764705883,0.5647058823529412,0.4199408355934975,22,0,0.5699481865284974,0.0,True,Pytorch,disabled poi e3 cell,0.6764705777168274,tanimoto
reports/ablation_zero_vectors_report_Active_Dmax_0.6_pDC50_6.0_test_split_0.1_uniprot.csv ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ test_loss,test_acc,test_f1_score,test_precision,test_recall,test_roc_auc,train_len,train_active_perc,train_inactive_perc,train_avg_tanimoto_dist,test_len,test_active_perc,test_inactive_perc,test_avg_tanimoto_dist,num_leaking_uniprot_train_test,num_leaking_smiles_train_test,perc_leaking_uniprot_train_test,perc_leaking_smiles_train_test,majority_vote,model_type,disabled_embeddings,test_f1,split_type
2
+ 0.7041562795639038,0.4588235318660736,0.0,0.0,0.0,0.5156075954437256,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
3
+ 0.6916469931602478,0.4588235318660736,0.0,0.0,0.0,0.4420289397239685,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
4
+ 0.6960257887840271,0.4588235318660736,0.0,0.0,0.0,0.4303233027458191,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3,,uniprot
5
+ ,0.4588235318660736,,0.0,0.0,0.5156075954437256,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled e3,0.0,uniprot
6
+ 0.7039564251899719,0.4588235318660736,0.0,0.0,0.0,0.532608687877655,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
7
+ 0.6913965940475464,0.4588235318660736,0.0,0.0,0.0,0.46739131212234497,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
8
+ 0.6957095265388489,0.4588235318660736,0.0,0.0,0.0,0.45234110951423645,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi,,uniprot
9
+ ,0.4588235318660736,,0.0,0.0,0.532608687877655,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi,0.0,uniprot
10
+ 0.7036164402961731,0.4588235318660736,0.0,0.0,0.0,0.530379056930542,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
11
+ 0.6914005875587463,0.4588235318660736,0.0,0.0,0.0,0.48188406229019165,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
12
+ 0.695412814617157,0.4588235318660736,0.0,0.0,0.0,0.4763098955154419,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled cell,,uniprot
13
+ ,0.4588235318660736,,0.0,0.0,0.530379056930542,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled cell,0.0,uniprot
14
+ 0.697465717792511,0.4588235318660736,0.0,0.0,0.0,0.6223523020744324,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
15
+ 0.6916133761405945,0.4588235318660736,0.0,0.0,0.0,0.6636008620262146,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
16
+ 0.6932395696640015,0.4588235318660736,0.0,0.0,0.0,0.651337742805481,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled smiles,,uniprot
17
+ ,0.4588235318660736,,0.0,0.0,0.6223522424697876,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled smiles,0.0,uniprot
18
+ 0.704821765422821,0.4588235318660736,0.0,0.0,0.0,0.518673300743103,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
19
+ 0.6916972398757935,0.4588235318660736,0.0,0.0,0.0,0.45234113931655884,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
20
+ 0.6962708830833435,0.4588235318660736,0.0,0.0,0.0,0.42892974615097046,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled e3 cell,,uniprot
21
+ ,0.4588235318660736,,0.0,0.0,0.5186733603477478,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled e3 cell,0.0,uniprot
22
+ 0.7051585912704468,0.4588235318660736,0.0,0.0,0.0,0.5103121399879456,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
23
+ 0.6916910409927368,0.4588235318660736,0.0,0.0,0.0,0.44732439517974854,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
24
+ 0.6965663433074951,0.4588235318660736,0.0,0.0,0.0,0.40328872203826904,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3,,uniprot
25
+ ,0.4588235318660736,,0.0,0.0,0.5103121399879456,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi e3,0.0,uniprot
26
+ 0.7058382034301758,0.4588235318660736,0.0,0.0,0.0,0.5080825090408325,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
27
+ 0.6917427778244019,0.4588235318660736,0.0,0.0,0.0,0.450111448764801,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
28
+ 0.6968205571174622,0.4588235318660736,0.0,0.0,0.0,0.4155518114566803,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,False,Pytorch,disabled poi e3 cell,,uniprot
29
+ ,0.4588235318660736,,0.0,0.0,0.5080825090408325,772,0.5064766839378239,0.49352331606217614,0.3753049487934892,85,0.5411764705882353,0.4588235294117647,0.39483030881358294,0,6,0.0,0.011658031088082901,True,Pytorch,disabled poi e3 cell,0.0,uniprot
src/run_xgboost_experiments.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from collections import defaultdict
4
+ import warnings
5
+ import logging
6
+ from typing import Literal
7
+
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ import protac_degradation_predictor as pdp
11
+ from protac_degradation_predictor.optuna_utils import get_dataframe_stats
12
+
13
+ import pytorch_lightning as pl
14
+ from rdkit import Chem
15
+ from rdkit.Chem import AllChem
16
+ from rdkit import DataStructs
17
+ from jsonargparse import CLI
18
+ import pandas as pd
19
+ from tqdm import tqdm
20
+ import numpy as np
21
+ from sklearn.preprocessing import OrdinalEncoder
22
+ from sklearn.model_selection import (
23
+ StratifiedKFold,
24
+ StratifiedGroupKFold,
25
+ )
26
+
27
+ # Ignore UserWarning from Matplotlib
28
+ warnings.filterwarnings("ignore", ".*FixedLocator*")
29
+ # Ignore UserWarning from PyTorch Lightning
30
+ warnings.filterwarnings("ignore", ".*does not have many workers.*")
31
+
32
+
33
+ root = logging.getLogger()
34
+ root.setLevel(logging.DEBUG)
35
+
36
+ handler = logging.StreamHandler(sys.stdout)
37
+ handler.setLevel(logging.DEBUG)
38
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
39
+ handler.setFormatter(formatter)
40
+ root.addHandler(handler)
41
+
42
+
43
+ def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
44
+ """ Get the indices of the test set using a random split.
45
+
46
+ Args:
47
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
48
+ test_split (float): The percentage of the active PROTACs to use as the test set.
49
+
50
+ Returns:
51
+ pd.Index: The indices of the test set.
52
+ """
53
+ test_df = active_df.sample(frac=test_split, random_state=42)
54
+ return test_df.index
55
+
56
+
57
+ def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
58
+ """ Get the indices of the test set using the E3 ligase split.
59
+
60
+ Args:
61
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
62
+
63
+ Returns:
64
+ pd.Index: The indices of the test set.
65
+ """
66
+ encoder = OrdinalEncoder()
67
+ active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
68
+ test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
69
+ return test_df.index
70
+
71
+
72
+ def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
73
+ """ Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
74
+
75
+ Args:
76
+ protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
77
+
78
+ Returns:
79
+ tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
80
+ """
81
+ unique_smiles = protac_df['Smiles'].unique().tolist()
82
+
83
+ smiles2fp = {}
84
+ for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
85
+ smiles2fp[smiles] = pdp.get_fingerprint(smiles)
86
+
87
+ # # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
88
+ # tanimoto_matrix = defaultdict(list)
89
+ # for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
90
+ # fp1 = smiles2fp[smiles1]
91
+ # # TODO: Use BulkTanimotoSimilarity for better performance
92
+ # for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
93
+ # fp2 = smiles2fp[smiles2]
94
+ # tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
95
+ # tanimoto_matrix[smiles1].append(tanimoto_dist)
96
+ # avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
97
+ # protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
98
+
99
+
100
+ tanimoto_matrix = defaultdict(list)
101
+ fps = list(smiles2fp.values())
102
+
103
+ # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
104
+ for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
105
+ similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates
106
+ for j, similarity in enumerate(similarities):
107
+ distance = 1 - similarity
108
+ tanimoto_matrix[smiles1].append(distance) # Store as distance
109
+ if i != i + j:
110
+ tanimoto_matrix[unique_smiles[i + j]].append(distance) # Symmetric filling
111
+
112
+ # Calculate average Tanimoto distance for each unique SMILES
113
+ avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
114
+ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
115
+
116
+ smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
117
+
118
+ return smiles2fp, protac_df
119
+
120
+
121
+ def get_tanimoto_split_indices(
122
+ active_df: pd.DataFrame,
123
+ active_col: str,
124
+ test_split: float,
125
+ n_bins_tanimoto: int = 200,
126
+ ) -> pd.Index:
127
+ """ Get the indices of the test set using the Tanimoto-based split.
128
+
129
+ Args:
130
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
131
+ n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity.
132
+
133
+ Returns:
134
+ pd.Index: The indices of the test set.
135
+ """
136
+ tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
137
+ encoder = OrdinalEncoder()
138
+ active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
139
+ # Sort the groups so that samples with the highest tanimoto similarity,
140
+ # i.e., the "less similar" ones, are placed in the test set first
141
+ tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
142
+
143
+ test_df = []
144
+ # For each group, get the number of active and inactive entries. Then, add those
145
+ # entries to the test_df if: 1) the test_df lenght + the group entries is less
146
+ # 20% of the active_df lenght, and 2) the percentage of True and False entries
147
+ # in the active_col in test_df is roughly 50%.
148
+ for group in tanimoto_groups:
149
+ group_df = active_df[active_df['Tanimoto Group'] == group]
150
+ if test_df == []:
151
+ test_df.append(group_df)
152
+ continue
153
+
154
+ num_entries = len(group_df)
155
+ num_active_group = group_df[active_col].sum()
156
+ num_inactive_group = num_entries - num_active_group
157
+
158
+ tmp_test_df = pd.concat(test_df)
159
+ num_entries_test = len(tmp_test_df)
160
+ num_active_test = tmp_test_df[active_col].sum()
161
+ num_inactive_test = num_entries_test - num_active_test
162
+
163
+ # Check if the group entries can be added to the test_df
164
+ if num_entries_test + num_entries < test_split * len(active_df):
165
+ # Add anything at the beggining
166
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
167
+ test_df.append(group_df)
168
+ continue
169
+ # Be more selective and make sure that the percentage of active and
170
+ # inactive is balanced
171
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
172
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
173
+ test_df.append(group_df)
174
+ test_df = pd.concat(test_df)
175
+ return test_df.index
176
+
177
+
178
+ def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_split: float) -> pd.Index:
179
+ """ Get the indices of the test set using the target-based split.
180
+
181
+ Args:
182
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
183
+ active_col (str): The column containing the active/inactive information.
184
+ test_split (float): The percentage of the active PROTACs to use as the test set.
185
+
186
+ Returns:
187
+ pd.Index: The indices of the test set.
188
+ """
189
+ encoder = OrdinalEncoder()
190
+ active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
191
+
192
+ test_df = []
193
+ # For each group, get the number of active and inactive entries. Then, add those
194
+ # entries to the test_df if: 1) the test_df lenght + the group entries is less
195
+ # 20% of the active_df lenght, and 2) the percentage of True and False entries
196
+ # in the active_col in test_df is roughly 50%.
197
+ # Start the loop from the groups containing the smallest number of entries.
198
+ for group in reversed(active_df['Uniprot'].value_counts().index):
199
+ group_df = active_df[active_df['Uniprot'] == group]
200
+ if test_df == []:
201
+ test_df.append(group_df)
202
+ continue
203
+
204
+ num_entries = len(group_df)
205
+ num_active_group = group_df[active_col].sum()
206
+ num_inactive_group = num_entries - num_active_group
207
+
208
+ tmp_test_df = pd.concat(test_df)
209
+ num_entries_test = len(tmp_test_df)
210
+ num_active_test = tmp_test_df[active_col].sum()
211
+ num_inactive_test = num_entries_test - num_active_test
212
+
213
+ # Check if the group entries can be added to the test_df
214
+ if num_entries_test + num_entries < test_split * len(active_df):
215
+ # Add anything at the beggining
216
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
217
+ test_df.append(group_df)
218
+ continue
219
+ # Be more selective and make sure that the percentage of active and
220
+ # inactive is balanced
221
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
222
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
223
+ test_df.append(group_df)
224
+ test_df = pd.concat(test_df)
225
+ return test_df.index
226
+
227
+
228
+ def main(
229
+ active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
230
+ n_trials: int = 100,
231
+ test_split: float = 0.1,
232
+ cv_n_splits: int = 5,
233
+ num_boost_round: int = 100,
234
+ force_study: bool = False,
235
+ experiments: str | Literal['all', 'random', 'e3_ligase', 'tanimoto', 'uniprot'] = 'all',
236
+ ):
237
+ """ Train a PROTAC model using the given datasets and hyperparameters.
238
+
239
+ Args:
240
+ use_ored_activity (bool): Whether to use the 'Active - OR' column.
241
+ n_trials (int): The number of hyperparameter optimization trials.
242
+ n_splits (int): The number of cross-validation splits.
243
+ fast_dev_run (bool): Whether to run a fast development run.
244
+ """
245
+ pl.seed_everything(42)
246
+
247
+ # Set the Column to Predict
248
+ active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
249
+
250
+ # Get Dmax_threshold from the active_col
251
+ Dmax_threshold = float(active_col.split('Dmax')[1].split(',')[0].strip('(').strip(')').strip())
252
+ pDC50_threshold = float(active_col.split('pDC50')[1].strip('(').strip(')').strip())
253
+
254
+ # Load the PROTAC dataset
255
+ protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
256
+ # Map E3 Ligase Iap to IAP
257
+ protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
258
+ protac_df[active_col] = protac_df.apply(
259
+ lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1
260
+ )
261
+ smiles2fp, protac_df = get_smiles2fp_and_avg_tanimoto(protac_df)
262
+
263
+ ## Get the test sets
264
+ test_indeces = {}
265
+ active_df = protac_df[protac_df[active_col].notna()].copy()
266
+
267
+ if experiments == 'random' or experiments == 'all':
268
+ test_indeces['random'] = get_random_split_indices(active_df, test_split)
269
+ if experiments == 'uniprot' or experiments == 'all':
270
+ test_indeces['uniprot'] = get_target_split_indices(active_df, active_col, test_split)
271
+ if experiments == 'e3_ligase' or experiments == 'all':
272
+ test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
273
+ if experiments == 'tanimoto' or experiments == 'all':
274
+ test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
275
+
276
+ # Make directory ../reports if it does not exist
277
+ if not os.path.exists('../reports'):
278
+ os.makedirs('../reports')
279
+
280
+ # Load embedding dictionaries
281
+ protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
282
+ cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
283
+
284
+ # Cross-Validation Training
285
+ reports = defaultdict(list)
286
+ for split_type, indeces in test_indeces.items():
287
+ test_df = active_df.loc[indeces].copy()
288
+ train_val_df = active_df[~active_df.index.isin(test_df.index)].copy()
289
+
290
+ # Get the CV object
291
+ if split_type == 'random':
292
+ kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
293
+ group = None
294
+ elif split_type == 'e3_ligase':
295
+ kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
296
+ group = train_val_df['E3 Group'].to_numpy()
297
+ elif split_type == 'tanimoto':
298
+ kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
299
+ group = train_val_df['Tanimoto Group'].to_numpy()
300
+ elif split_type == 'uniprot':
301
+ kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
302
+ group = train_val_df['Uniprot Group'].to_numpy()
303
+
304
+ # Start the experiment
305
+ experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
306
+ optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
307
+ protein2embedding=protein2embedding,
308
+ cell2embedding=cell2embedding,
309
+ smiles2fp=smiles2fp,
310
+ train_val_df=train_val_df,
311
+ test_df=test_df,
312
+ kf=kf,
313
+ groups=group,
314
+ split_type=split_type,
315
+ n_models_for_test=3,
316
+ n_trials=n_trials,
317
+ active_label=active_col,
318
+ num_boost_round=num_boost_round,
319
+ study_filename=f'../reports/study_xgboost_{experiment_name}.pkl',
320
+ force_study=force_study,
321
+ )
322
+
323
+ # Save the reports to file
324
+ for report_name, report in optuna_reports.items():
325
+ report.to_csv(f'../reports/xgboost_{report_name}_{experiment_name}.csv', index=False)
326
+ reports[report_name].append(report.copy())
327
+
328
+ if __name__ == '__main__':
329
+ cli = CLI(main)