Commit bda3015
Parent: fda7af7

Added softmax in model + Fixed and updated some hparams + Fixed bug in tanimoto distance (it was treated as similarity)
protac_degradation_predictor/optuna_utils.py
CHANGED

@@ -55,14 +55,17 @@ def get_dataframe_stats(
     stats['train_len'] = len(train_df)
     stats['train_active_perc'] = train_df[active_label].sum() / len(train_df)
     stats['train_inactive_perc'] = (len(train_df) - train_df[active_label].sum()) / len(train_df)
+    stats['train_avg_tanimoto_dist'] = train_df['Avg Tanimoto'].mean()
     if val_df is not None:
         stats['val_len'] = len(val_df)
         stats['val_active_perc'] = val_df[active_label].sum() / len(val_df)
         stats['val_inactive_perc'] = (len(val_df) - val_df[active_label].sum()) / len(val_df)
+        stats['val_avg_tanimoto_dist'] = val_df['Avg Tanimoto'].mean()
     if test_df is not None:
         stats['test_len'] = len(test_df)
         stats['test_active_perc'] = test_df[active_label].sum() / len(test_df)
         stats['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
+        stats['test_avg_tanimoto_dist'] = test_df['Avg Tanimoto'].mean()
     if train_df is not None and val_df is not None:
         leaking_uniprot = list(set(train_df['Uniprot']).intersection(set(val_df['Uniprot'])))
         leaking_smiles = list(set(train_df['Smiles']).intersection(set(val_df['Smiles'])))
@@ -98,6 +101,10 @@ def pytorch_model_objective(
         active_label: str = 'Active',
         disabled_embeddings: List[str] = [],
         max_epochs: int = 100,
+        use_logger: bool = False,
+        logger_save_dir: str = 'logs',
+        logger_name: str = 'cv_model',
+        enable_checkpointing: bool = False,
 ) -> float:
     """ Objective function for hyperparameter optimization.

@@ -116,11 +123,11 @@ def pytorch_model_objective(
     """
     # Suggest hyperparameters to be used accross the CV folds
     hidden_dim = trial.suggest_categorical('hidden_dim', hidden_dim_options)
-    batch_size = trial.suggest_categorical('batch_size', batch_size_options)
+    batch_size = 128  # trial.suggest_categorical('batch_size', batch_size_options)
     learning_rate = trial.suggest_float('learning_rate', *learning_rate_options, log=True)
     smote_k_neighbors = trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options)
     use_smote = trial.suggest_categorical('use_smote', [True, False])
-    apply_scaling = trial.suggest_categorical('apply_scaling', [True, False])
+    apply_scaling = True  # trial.suggest_categorical('apply_scaling', [True, False])
     dropout = trial.suggest_float('dropout', *dropout_options)

     # Start the CV over the folds
@@ -166,11 +173,14 @@ def pytorch_model_objective(
             smote_k_neighbors=smote_k_neighbors,
             apply_scaling=apply_scaling,
             use_smote=use_smote,
-            use_logger=False,
             fast_dev_run=fast_dev_run,
             active_label=active_label,
             return_predictions=True,
             disabled_embeddings=disabled_embeddings,
+            use_logger=use_logger,
+            logger_save_dir=logger_save_dir,
+            logger_name=f'{logger_name}_fold{k}',
+            enable_checkpointing=enable_checkpointing,
         )
         if test_df is not None:
             _, _, metrics, val_pred, test_pred = ret
@@ -246,11 +256,13 @@ def hyperparameter_tuning_and_training(
     pl.seed_everything(42)

     # Define the search space
-    hidden_dim_options = [32, 64, 128, 256, 512
-    batch_size_options = [4, 8, 16, 32, 64, 128]
-    learning_rate_options = (1e-
+    hidden_dim_options = [32, 64, 128, 256, 512]
+    batch_size_options = [128, 128]  # [4, 8, 16, 32, 64, 128]
+    learning_rate_options = (1e-6, 1e-3)  # min and max values for loguniform distribution
     smote_k_neighbors_options = list(range(3, 16))
-
+    # NOTE: We want Optuna to explore the combination (very low dropout, very
+    # small hidden_dim)
+    dropout_options = (0, 0.5)

     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -293,6 +305,31 @@ def hyperparameter_tuning_and_training(
     cv_report = pd.DataFrame(study.best_trial.user_attrs['report'])
     hparam_report = pd.DataFrame([study.best_params])

+    # Train the best CV models and store their checkpoints by running the objective
+    pytorch_model_objective(
+        trial=study.best_trial,
+        protein2embedding=protein2embedding,
+        cell2embedding=cell2embedding,
+        smiles2fp=smiles2fp,
+        train_val_df=train_val_df,
+        kf=kf,
+        groups=groups,
+        test_df=test_df,
+        hidden_dim_options=hidden_dim_options,
+        batch_size_options=batch_size_options,
+        learning_rate_options=learning_rate_options,
+        smote_k_neighbors_options=smote_k_neighbors_options,
+        dropout_options=dropout_options,
+        fast_dev_run=fast_dev_run,
+        active_label=active_label,
+        max_epochs=max_epochs,
+        disabled_embeddings=[],
+        use_logger=True,
+        logger_save_dir=logger_save_dir,
+        logger_name=f'{logger_name}_{split_type}_cv_model',
+        enable_checkpointing=True,
+    )
+
     # Retrain N models with the best hyperparameters (measure model uncertainty)
     test_report = []
     test_preds = []
@@ -315,6 +352,8 @@ def hyperparameter_tuning_and_training(
         enable_checkpointing=True,
         checkpoint_model_name=f'best_model_n{i}_{split_type}',
         return_predictions=True,
+        batch_size=128,
+        apply_scaling=True,
         **study.best_params,
     )
     # Rename the keys in the metrics dictionary
@@ -371,6 +410,8 @@ def hyperparameter_tuning_and_training(
         logger_save_dir=logger_save_dir,
         logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
         disabled_embeddings=disabled_embeddings,
+        batch_size=128,
+        apply_scaling=True,
         **study.best_params,
     )
     # Rename the keys in the metrics dictionary
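Net effect of these hunks: `batch_size` and `apply_scaling` are pinned instead of sampled, so they no longer appear in `study.best_params` and must be passed explicitly when retraining (hence the `batch_size=128, apply_scaling=True` lines added before `**study.best_params`). A minimal, self-contained sketch of this fixed-versus-searched pattern; `dummy_validation_score` is a made-up stand-in, not a function from this repository:

import optuna


def dummy_validation_score(hidden_dim, learning_rate, dropout):
    # Toy stand-in for "train a model and return its validation score".
    return learning_rate * hidden_dim * (1.0 - dropout)


def objective(trial: optuna.Trial) -> float:
    # Searched hyperparameters, mirroring the updated search space above.
    hidden_dim = trial.suggest_categorical('hidden_dim', [32, 64, 128, 256, 512])
    learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-3, log=True)
    dropout = trial.suggest_float('dropout', 0.0, 0.5)

    # Pinned hyperparameters: not suggested, so Optuna never records them and
    # study.best_params will not contain these keys; a real objective would
    # forward them to its training call.
    batch_size = 128
    apply_scaling = True

    return dummy_validation_score(hidden_dim, learning_rate, dropout)


study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20)
print(study.best_params)  # only hidden_dim, learning_rate and dropout appear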
protac_degradation_predictor/pytorch_models.py
CHANGED

@@ -63,15 +63,29 @@ class PROTAC_Predictor(nn.Module):
         self.__dict__.update(locals())

         # Define "surrogate models" branches
+        # NOTE: The softmax is used to ensure that the embeddings are normalized
+        # and can be summed on a "similar scale".
         if self.join_embeddings != 'beginning':
             if 'poi' not in self.disabled_embeddings:
-                self.poi_emb = nn.Linear(poi_emb_dim, hidden_dim)
+                self.poi_emb = nn.Sequential(
+                    nn.Linear(poi_emb_dim, hidden_dim),
+                    nn.Softmax(dim=1),
+                )
             if 'e3' not in self.disabled_embeddings:
-                self.e3_emb = nn.Linear(e3_emb_dim, hidden_dim)
+                self.e3_emb = nn.Sequential(
+                    nn.Linear(e3_emb_dim, hidden_dim),
+                    nn.Softmax(dim=1),
+                )
             if 'cell' not in self.disabled_embeddings:
-                self.cell_emb = nn.Linear(cell_emb_dim, hidden_dim)
+                self.cell_emb = nn.Sequential(
+                    nn.Linear(cell_emb_dim, hidden_dim),
+                    nn.Softmax(dim=1),
+                )
             if 'smiles' not in self.disabled_embeddings:
-                self.smiles_emb = nn.Linear(smiles_emb_dim, hidden_dim)
+                self.smiles_emb = nn.Sequential(
+                    nn.Linear(smiles_emb_dim, hidden_dim),
+                    nn.Softmax(dim=1),
+                )

         # Define hidden dimension for joining layer
         if self.join_embeddings == 'beginning':
@@ -95,6 +109,7 @@ class PROTAC_Predictor(nn.Module):
     def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
         embeddings = []
         if self.join_embeddings == 'beginning':
+            # TODO: Remove this if-branch
             if 'poi' not in self.disabled_embeddings:
                 embeddings.append(poi_emb)
             if 'e3' not in self.disabled_embeddings:
@@ -123,7 +138,6 @@ class PROTAC_Predictor(nn.Module):
         else:
             x = embeddings[0]
         x = self.dropout(F.relu(self.fc1(x)))
-        x = self.dropout(F.relu(self.fc2(x)))
         x = self.fc3(x)
         return x

@@ -137,7 +151,7 @@ class PROTAC_Model(pl.LightningModule):
         poi_emb_dim: int = config.protein_embedding_size,
         e3_emb_dim: int = config.protein_embedding_size,
         cell_emb_dim: int = config.cell_embedding_size,
-        batch_size: int =
+        batch_size: int = 128,
         learning_rate: float = 1e-3,
         dropout: float = 0.2,
         join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
@@ -145,7 +159,7 @@ class PROTAC_Model(pl.LightningModule):
         val_dataset: PROTAC_Dataset = None,
         test_dataset: PROTAC_Dataset = None,
         disabled_embeddings: list = [],
-        apply_scaling: bool = False,
+        apply_scaling: bool = True,
     ):
         """ Initialize the PROTAC Pytorch Lightning model.

@@ -388,7 +402,7 @@ def train_model(
         val_df: pd.DataFrame,
         test_df: Optional[pd.DataFrame] = None,
         hidden_dim: int = 768,
-        batch_size: int =
+        batch_size: int = 128,
         learning_rate: float = 2e-5,
         dropout: float = 0.2,
         max_epochs: int = 50,
@@ -399,7 +413,7 @@ def train_model(
         join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
         smote_k_neighbors:int = 5,
         use_smote: bool = True,
-        apply_scaling: bool = False,
+        apply_scaling: bool = True,
         active_label: str = 'Active',
         fast_dev_run: bool = False,
         use_logger: bool = True,
@@ -508,7 +522,7 @@ def train_model(
         logger=loggers if use_logger else False,
         callbacks=callbacks,
         max_epochs=max_epochs,
-        val_check_interval=0.5,
+        # val_check_interval=0.5,
         fast_dev_run=fast_dev_run,
         enable_model_summary=False,
         enable_checkpointing=enable_checkpointing,
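The "softmax in model" part of the commit is the branch rewrite above: each modality branch becomes `nn.Sequential(nn.Linear(...), nn.Softmax(dim=1))`, so, per the NOTE, every branch emits a non-negative vector summing to one per sample, letting the branches be summed on a similar scale. A standalone sketch with made-up dimensions (the names and sizes are illustrative, not the repository's):

import torch
import torch.nn as nn

poi_emb_dim, e3_emb_dim, hidden_dim = 1024, 1024, 256  # illustrative sizes

# One branch per input modality: a linear projection followed by a softmax
# over the hidden dimension, as in PROTAC_Predictor above.
poi_branch = nn.Sequential(nn.Linear(poi_emb_dim, hidden_dim), nn.Softmax(dim=1))
e3_branch = nn.Sequential(nn.Linear(e3_emb_dim, hidden_dim), nn.Softmax(dim=1))

poi = torch.randn(8, poi_emb_dim)  # batch of 8 random "POI embeddings"
e3 = torch.randn(8, e3_emb_dim)

joined = poi_branch(poi) + e3_branch(e3)  # the 'sum' joining strategy
print(joined.shape)                # torch.Size([8, 256])
print(poi_branch(poi).sum(dim=1))  # every row sums to 1 after the softmax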
src/plot_experiment_results.py
CHANGED

@@ -12,7 +12,9 @@ import numpy as np
 palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']


-def plot_training_curves(df, split_type):
+def plot_training_curves(df, split_type, stage='test'):
+    Stage = 'Test' if stage == 'test' else 'Validation'
+
     # Clean the data
     df = df.dropna(how='all', axis=1)

@@ -26,14 +28,14 @@ def plot_training_curves(df, split_type):

     # Plot training loss
     ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
-    ax[0].plot(epoch_data.index, epoch_data['test_loss'], label='Test Loss', linestyle='--')
+    ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
     ax[0].set_ylabel('Loss')
     ax[0].legend(loc='lower right')
     ax[0].grid(axis='both', alpha=0.5)

     # Plot training accuracy
     ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
-    ax[1].plot(epoch_data.index, epoch_data['test_acc'], label='Test Accuracy', linestyle='--')
+    ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
     ax[1].set_ylabel('Accuracy')
     ax[1].legend(loc='lower right')
     ax[1].grid(axis='both', alpha=0.5)
@@ -44,7 +46,7 @@ def plot_training_curves(df, split_type):

     # Plot training ROC-AUC
     ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
-    ax[2].plot(epoch_data.index, epoch_data['test_roc_auc'], label='Test ROC-AUC', linestyle='--')
+    ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
     ax[2].set_ylabel('ROC-AUC')
     ax[2].legend(loc='lower right')
     ax[2].grid(axis='both', alpha=0.5)
@@ -270,10 +272,18 @@ def plot_ablation_study(report):
     plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')


+def plot_majority_voting_performance(df):
+    # cv_models,test_acc,test_roc_auc,split_type
+    # Melt the dataframe
+    df = df.melt(id_vars=['cv_models', 'test_acc', 'test_roc_auc', 'split_type'], var_name='Metric', value_name='Score')
+    print(df)
+
+
 def main():
     active_col = 'Active (Dmax 0.6, pDC50 6.0)'
     test_split = 0.1
     n_models_for_test = 3
+    cv_n_folds = 5

     active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
     report_base_name = f'{active_name}_test_split_{test_split}'
@@ -300,25 +310,36 @@ def main():
         pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
         pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
         ]),
+        'majority_vote': pd.concat([
+            pd.read_csv(f'reports/majority_vote_report_{report_base_name}_random.csv'),
+            pd.read_csv(f'reports/majority_vote_report_{report_base_name}_uniprot.csv'),
+            pd.read_csv(f'reports/majority_vote_report_{report_base_name}_tanimoto.csv'),
+        ]),
     }


-    for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
+    for split_type in ['random', 'tanimoto', 'uniprot']:
+        for i in range(n_models_for_test):
             logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
-            metrics
-            metrics[
+            metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
+            metrics['model_id'] = i
             # Rename 'val_' columns to 'test_' columns
-            metrics
-            plot_training_curves(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
+            metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
+            plot_training_curves(metrics, f'{split_type}_best_model_n{i}')

+        for i in range(cv_n_folds):
+            # logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
+            logs_dir = f'{split_type}_cv_model_fold{i}'
+            metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
+            metrics['fold'] = i
+            plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')

     df_val = reports['cv_train']
     df_test = reports['test']
     plot_performance_metrics(df_val, df_test, title=f'{active_name}_metrics')

+    plot_majority_voting_performance(reports['majority_vote'])
+
     reports['test']['disabled_embeddings'] = pd.NA
     plot_ablation_study(pd.concat([
         reports['ablation'],
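`plot_training_curves` now also renders validation curves for the checkpointed CV models (`stage='val'`), reading the per-run `metrics.csv` files written by the logger enabled in `pytorch_model_objective`. A rough sketch of that metrics-file-to-curves flow; the run path and the per-epoch aggregation are assumptions based on PyTorch Lightning's CSVLogger layout, since the script's `epoch_data` construction is outside these hunks:

import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical run directory; Lightning's CSVLogger writes one metrics.csv per run.
metrics = pd.read_csv('logs/my_run/my_run/metrics.csv')

# Train and validation metrics are logged at different steps, leaving
# NaN-padded columns; collapse them to one row per epoch before plotting.
epoch_data = metrics.dropna(how='all', axis=1).groupby('epoch').mean(numeric_only=True)

fig, ax = plt.subplots()
ax.plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
ax.plot(epoch_data.index, epoch_data['val_loss'], label='Validation Loss', linestyle='--')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc='lower right')
ax.grid(axis='both', alpha=0.5)
fig.savefig('training_curves.png', bbox_inches='tight')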
src/run_experiments.py
CHANGED

@@ -8,6 +8,7 @@ from typing import Literal
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

 import protac_degradation_predictor as pdp
+from protac_degradation_predictor.optuna_utils import get_dataframe_stats

 import pytorch_lightning as pl
 from rdkit import Chem
@@ -77,21 +78,38 @@ def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
     Returns:
         tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
     """
+    unique_smiles = protac_df['Smiles'].unique().tolist()
+
     smiles2fp = {}
-    for smiles in tqdm(
+    for smiles in tqdm(unique_smiles, desc='Precomputing fingerprints'):
         smiles2fp[smiles] = pdp.get_fingerprint(smiles)

-    # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
+    # # Get the pair-wise tanimoto similarity between the PROTAC fingerprints
+    # tanimoto_matrix = defaultdict(list)
+    # for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
+    #     fp1 = smiles2fp[smiles1]
+    #     # TODO: Use BulkTanimotoSimilarity for better performance
+    #     for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
+    #         fp2 = smiles2fp[smiles2]
+    #         tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
+    #         tanimoto_matrix[smiles1].append(tanimoto_dist)
+    # avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
+    # protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
+
+
     tanimoto_matrix = defaultdict(list)
-    for i, smiles1 in enumerate(tqdm(protac_df['Smiles'].unique(), desc='Computing Tanimoto similarity')):
-        fp1 = smiles2fp[smiles1]
-        # TODO: Use BulkTanimotoSimilarity for better performance
-        for j, smiles2 in enumerate(protac_df['Smiles'].unique()[i:]):
-            fp2 = smiles2fp[smiles2]
-            tanimoto_dist = 1 - DataStructs.TanimotoSimilarity(fp1, fp2)
-            tanimoto_matrix[smiles1].append(tanimoto_dist)
+    fps = list(smiles2fp.values())
+
+    # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
+    for i, (smiles1, fp1) in enumerate(tqdm(zip(unique_smiles, fps), desc='Computing Tanimoto similarity', total=len(fps))):
+        similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:])  # Only compute for i to end, avoiding duplicates
+        for j, similarity in enumerate(similarities):
+            distance = 1 - similarity
+            tanimoto_matrix[smiles1].append(distance)  # Store as distance
+            if i != i + j:
+                tanimoto_matrix[unique_smiles[i + j]].append(distance)  # Symmetric filling
+
+    # Calculate average Tanimoto distance for each unique SMILES
     avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
     protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

@@ -256,7 +274,7 @@ def main(
         test_indeces['e3_ligase'] = get_e3_ligase_split_indices(active_df)
     if experiments == 'tanimoto' or experiments == 'all':
         test_indeces['tanimoto'] = get_tanimoto_split_indices(active_df, active_col, test_split)
-
+
     # Make directory ../reports if it does not exist
     if not os.path.exists('../reports'):
         os.makedirs('../reports')
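The tanimoto fix from the commit title shows up here and in get_dataframe_stats: the values stored in the 'Avg Tanimoto' column are distances, i.e. `1 - TanimotoSimilarity`, not similarities, and the rewritten loop makes that conversion explicit while switching to RDKit's much faster `BulkTanimotoSimilarity`. A self-contained sketch of the same pattern on toy molecules (the SMILES and Morgan fingerprint settings are arbitrary; the project's own `pdp.get_fingerprint` is not used here):

from collections import defaultdict

import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

# Toy SMILES standing in for protac_df['Smiles'].unique(); arbitrary molecules.
smiles_list = ['CCO', 'CCN', 'c1ccccc1', 'CC(=O)O']
fps = [
    AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), 2, nBits=2048)
    for s in smiles_list
]

# Upper-triangular pass with BulkTanimotoSimilarity, converting each
# similarity into a distance (1 - similarity) and filling both directions.
distances = defaultdict(list)
for i, (smiles, fp) in enumerate(zip(smiles_list, fps)):
    for j, sim in enumerate(DataStructs.BulkTanimotoSimilarity(fp, fps[i:])):
        dist = 1 - sim
        distances[smiles].append(dist)
        if j != 0:  # equivalent to the `if i != i + j:` guard above
            distances[smiles_list[i + j]].append(dist)

avg_dist = {s: float(np.mean(d)) for s, d in distances.items()}
print(avg_dist)  # average Tanimoto *distance* per molecule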