ribesstefano committed on
Commit
fda7af7
1 Parent(s): 62ccb16

Added FP search + Added train metrics to logs

notebooks/best_fingerprint_search.ipynb ADDED
@@ -0,0 +1,275 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "from collections import defaultdict\n",
+ "import warnings\n",
+ "import logging\n",
+ "from typing import Literal\n",
+ "\n",
+ "sys.path.append('~/PROTAC-Degradation-Predictor/protac_degradation_predictor')\n",
+ "import protac_degradation_predictor as pdp\n",
+ "\n",
+ "import pytorch_lightning as pl\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem import AllChem\n",
+ "from rdkit import DataStructs\n",
+ "from jsonargparse import CLI\n",
+ "import pandas as pd\n",
+ "# Import tqdm for notebook\n",
+ "from tqdm.notebook import tqdm\n",
+ "import numpy as np\n",
+ "from sklearn.preprocessing import OrdinalEncoder\n",
+ "from sklearn.model_selection import (\n",
+ " StratifiedKFold,\n",
+ " StratifiedGroupKFold,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "active_col = 'Active (Dmax 0.6, pDC50 6.0)'\n",
+ "pDC50_threshold = 6.0\n",
+ "Dmax_threshold = 0.6\n",
+ "\n",
+ "protac_df = pd.read_csv('~/PROTAC-Degradation-Predictor/data/PROTAC-Degradation-DB.csv')\n",
+ "protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')\n",
+ "protac_df[active_col] = protac_df.apply(\n",
+ " lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "771"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:\n",
+ " \"\"\" Get the indices of the test set using a random split.\n",
+ " \n",
+ " Args:\n",
+ " active_df (pd.DataFrame): The DataFrame containing the active PROTACs.\n",
+ " test_split (float): The percentage of the active PROTACs to use as the test set.\n",
+ " \n",
+ " Returns:\n",
+ " pd.Index: The indices of the test set.\n",
+ " \"\"\"\n",
+ " test_df = active_df.sample(frac=test_split, random_state=42)\n",
+ " return test_df.index\n",
+ "\n",
+ "active_df = protac_df[protac_df[active_col].notna()].copy()\n",
+ "test_split = 0.1\n",
+ "test_indices = get_random_split_indices(active_df, test_split)\n",
+ "train_val_df = active_df[~active_df.index.isin(test_indices)].copy()\n",
+ "len(train_val_df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import optuna\n",
+ "\n",
+ "def objective(trial: optuna.Trial, verbose: int = 0) -> float:\n",
+ " \n",
+ " radius = trial.suggest_int('radius', 1, 15)\n",
+ " fpsize = trial.suggest_int('fpsize', 128, 2048, step=128)\n",
+ "\n",
+ " morgan_fpgen = AllChem.GetMorganGenerator(\n",
+ " radius=radius,\n",
+ " fpSize=fpsize,\n",
+ " includeChirality=True,\n",
+ " )\n",
+ "\n",
+ " smiles2fp = {}\n",
+ " for smiles in train_val_df['Smiles'].unique().tolist():\n",
+ " smiles2fp[smiles] = pdp.get_fingerprint(smiles, morgan_fpgen)\n",
+ "\n",
+ " # Count the number of unique SMILES and the number of unique Morgan fingerprints\n",
+ " unique_fps = set([tuple(fp) for fp in smiles2fp.values()])\n",
+ " # Get the list of SMILES with overlapping fingerprints\n",
+ " overlapping_smiles = []\n",
+ " unique_fps = set()\n",
+ " for smiles, fp in smiles2fp.items():\n",
+ " if tuple(fp) in unique_fps:\n",
+ " overlapping_smiles.append(smiles)\n",
+ " else:\n",
+ " unique_fps.add(tuple(fp))\n",
+ " num_overlaps = len(train_val_df[train_val_df[\"Smiles\"].isin(overlapping_smiles)])\n",
+ " num_overlaps_tot = len(protac_df[protac_df[\"Smiles\"].isin(overlapping_smiles)])\n",
+ "\n",
+ " if verbose:\n",
+ " print(f'Radius: {radius}')\n",
+ " print(f'FP length: {fpsize}')\n",
+ " print(f'Number of unique SMILES: {len(smiles2fp)}')\n",
+ " print(f'Number of unique fingerprints: {len(unique_fps)}')\n",
+ " print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')\n",
+ " print(f'Number of overlapping SMILES in train_val_df: {num_overlaps}')\n",
+ " print(f'Number of overlapping SMILES in protac_df: {num_overlaps_tot}')\n",
+ " return num_overlaps + radius + fpsize / 100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[I 2024-04-29 11:28:05,626] A new study created in memory with name: no-name-4db5d822-6220-4ab8-bc3a-c776b0e5cac2\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "678150f59ec548bb89562e2230993989",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/50 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[I 2024-04-29 11:28:07,705] Trial 0 finished with value: 39.480000000000004 and parameters: {'radius': 6, 'fpsize': 2048}. Best is trial 0 with value: 39.480000000000004.\n",
+ "[I 2024-04-29 11:28:09,590] Trial 1 finished with value: 23.8 and parameters: {'radius': 11, 'fpsize': 1280}. Best is trial 1 with value: 23.8.\n",
+ "[I 2024-04-29 11:28:10,474] Trial 2 finished with value: 131.84 and parameters: {'radius': 3, 'fpsize': 384}. Best is trial 1 with value: 23.8.\n",
+ "[I 2024-04-29 11:28:11,978] Trial 3 finished with value: 281.92 and parameters: {'radius': 1, 'fpsize': 1792}. Best is trial 1 with value: 23.8.\n",
+ "[I 2024-04-29 11:28:13,994] Trial 4 finished with value: 25.36 and parameters: {'radius': 10, 'fpsize': 1536}. Best is trial 1 with value: 23.8.\n",
+ "[I 2024-04-29 11:28:15,642] Trial 5 finished with value: 284.48 and parameters: {'radius': 1, 'fpsize': 2048}. Best is trial 1 with value: 23.8.\n",
+ "[I 2024-04-29 11:28:17,154] Trial 6 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:18,057] Trial 7 finished with value: 131.84 and parameters: {'radius': 3, 'fpsize': 384}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:19,570] Trial 8 finished with value: 41.519999999999996 and parameters: {'radius': 5, 'fpsize': 1152}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:20,860] Trial 9 finished with value: 23.4 and parameters: {'radius': 7, 'fpsize': 640}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:22,631] Trial 10 finished with value: 22.68 and parameters: {'radius': 15, 'fpsize': 768}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:24,427] Trial 11 finished with value: 22.68 and parameters: {'radius': 15, 'fpsize': 768}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:25,756] Trial 12 finished with value: 92.28 and parameters: {'radius': 15, 'fpsize': 128}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:27,466] Trial 13 finished with value: 20.96 and parameters: {'radius': 12, 'fpsize': 896}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:29,156] Trial 14 finished with value: 20.96 and parameters: {'radius': 12, 'fpsize': 896}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:30,727] Trial 15 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:31,842] Trial 16 finished with value: 22.28 and parameters: {'radius': 9, 'fpsize': 128}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:33,365] Trial 17 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+ "[I 2024-04-29 11:28:34,801] Trial 18 finished with value: 16.84 and parameters: {'radius': 13, 'fpsize': 384}. Best is trial 18 with value: 16.84.\n",
+ "[I 2024-04-29 11:28:35,986] Trial 19 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+ "[I 2024-04-29 11:28:37,122] Trial 20 finished with value: 14.56 and parameters: {'radius': 8, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+ "[I 2024-04-29 11:28:38,175] Trial 21 finished with value: 30.28 and parameters: {'radius': 8, 'fpsize': 128}. Best is trial 19 with value: 13.56.\n",
+ "[I 2024-04-29 11:28:39,406] Trial 22 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+ "[I 2024-04-29 11:28:40,649] Trial 23 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+ "[I 2024-04-29 11:28:41,868] Trial 24 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:43,109] Trial 25 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:44,587] Trial 26 finished with value: 16.4 and parameters: {'radius': 10, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:46,599] Trial 27 finished with value: 25.08 and parameters: {'radius': 11, 'fpsize': 1408}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:48,015] Trial 28 finished with value: 31.96 and parameters: {'radius': 6, 'fpsize': 896}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:49,347] Trial 29 finished with value: 23.4 and parameters: {'radius': 7, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:51,503] Trial 30 finished with value: 27.64 and parameters: {'radius': 11, 'fpsize': 1664}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:52,657] Trial 31 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:53,840] Trial 32 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:55,159] Trial 33 finished with value: 13.84 and parameters: {'radius': 10, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:56,140] Trial 34 finished with value: 39.28 and parameters: {'radius': 7, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:57,508] Trial 35 finished with value: 14.84 and parameters: {'radius': 11, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:28:58,900] Trial 36 finished with value: 15.120000000000001 and parameters: {'radius': 10, 'fpsize': 512}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:00,203] Trial 37 finished with value: 14.56 and parameters: {'radius': 12, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:02,225] Trial 38 finished with value: 49.2 and parameters: {'radius': 5, 'fpsize': 1920}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:03,942] Trial 39 finished with value: 22.52 and parameters: {'radius': 8, 'fpsize': 1152}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:05,240] Trial 40 finished with value: 13.84 and parameters: {'radius': 10, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:06,396] Trial 41 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:07,422] Trial 42 finished with value: 30.28 and parameters: {'radius': 8, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:08,590] Trial 43 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:09,949] Trial 44 finished with value: 14.84 and parameters: {'radius': 11, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:11,378] Trial 45 finished with value: 15.120000000000001 and parameters: {'radius': 10, 'fpsize': 512}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:12,637] Trial 46 finished with value: 26.4 and parameters: {'radius': 6, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:14,232] Trial 47 finished with value: 18.68 and parameters: {'radius': 11, 'fpsize': 768}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:14,904] Trial 48 finished with value: 214.28 and parameters: {'radius': 2, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+ "[I 2024-04-29 11:29:16,323] Trial 49 finished with value: 16.56 and parameters: {'radius': 14, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sampler = optuna.samplers.TPESampler(seed=42)\n",
+ "study = optuna.create_study(sampler=sampler, direction='minimize')\n",
+ "study.optimize(objective, n_trials=50, show_progress_bar=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Radius: 10\n",
+ "FP length: 256\n",
+ "Number of unique SMILES: 532\n",
+ "Number of unique fingerprints: 532\n",
+ "Number of SMILES with overlapping fingerprints: 0\n",
+ "Number of overlapping SMILES in train_val_df: 0\n",
+ "Number of overlapping SMILES in protac_df: 0\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "12.56"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Run objective with best params and verbose\n",
+ "objective(study.best_trial, verbose=1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
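
The objective in the notebook above scores a fingerprint configuration by the number of colliding SMILES (distinct molecules that map to the same bit vector), plus small tie-breaking penalties on the radius and on the fingerprint size; the best trial settles on radius 10 and 256 bits with zero collisions. Below is a minimal, self-contained sketch of that collision check using plain RDKit calls in place of the repository's pdp.get_fingerprint helper; smiles_list is a hypothetical stand-in for train_val_df['Smiles'].

# Minimal sketch of the notebook's collision check (not the repository's helper).
# Assumes: smiles_list is an iterable of SMILES strings.
from rdkit import Chem
from rdkit.Chem import AllChem

def count_fingerprint_collisions(smiles_list, radius=10, fp_size=256):
    # Morgan generator configured like the best Optuna trial above
    fpgen = AllChem.GetMorganGenerator(radius=radius, fpSize=fp_size, includeChirality=True)
    seen, colliding = set(), []
    for smi in set(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # skip unparsable SMILES
        bits = tuple(fpgen.GetFingerprint(mol))
        if bits in seen:
            colliding.append(smi)  # same bit vector as a previously seen molecule
        else:
            seen.add(bits)
    return len(colliding)

With the tuned settings, the notebook reports 532 unique SMILES mapping to 532 unique fingerprints, i.e. a collision count of 0 on the training/validation set.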
protac_degradation_predictor/config.py CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 @dataclass(frozen=True)
 class Config:
 # Embeddings information
- morgan_radius: int = 15
+ morgan_radius: int = 10 # 15
 fingerprint_size: int = 256 # 224
 protein_embedding_size: int = 1024
 cell_embedding_size: int = 768
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -173,13 +173,11 @@ def pytorch_model_objective(
 disabled_embeddings=disabled_embeddings,
 )
 if test_df is not None:
- _, trainer, metrics, val_pred, test_pred = ret
+ _, _, metrics, val_pred, test_pred = ret
 test_preds.append(test_pred)
 else:
- _, trainer, metrics, val_pred = ret
- train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+ _, _, metrics, val_pred = ret
 stats.update(metrics)
- stats.update(train_metrics)
 report.append(stats.copy())
 val_preds.append(val_pred)
 
@@ -252,7 +250,7 @@ def hyperparameter_tuning_and_training(
 batch_size_options = [4, 8, 16, 32, 64, 128]
 learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
 smote_k_neighbors_options = list(range(3, 16))
- dropout_options = (0.1, 0.9)
+ dropout_options = (0.2, 0.9)
 
 # Set the verbosity of Optuna
 optuna.logging.set_verbosity(optuna.logging.WARNING)
@@ -325,13 +323,6 @@ def hyperparameter_tuning_and_training(
 metrics['test_model_id'] = i
 metrics.update(dfs_stats)
 
- # Add the training metrics
- train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
- logging.info(f'Training metrics: {train_metrics}')
- logging.info(f'Training trainer.logged_metrics: {trainer.logged_metrics}')
- logging.info(f'Training trainer.callback_metrics: {trainer.callback_metrics}')
-
- metrics.update(train_metrics)
 test_report.append(metrics.copy())
 test_preds.append(test_pred)
 test_report = pd.DataFrame(test_report)
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -207,8 +207,8 @@ class PROTAC_Model(pl.LightningModule):
 'precision': Precision(task='binary'),
 'recall': Recall(task='binary'),
 'f1_score': F1Score(task='binary'),
- 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
- 'hp_metric': Accuracy(task='binary'),
+ # 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
+ # 'hp_metric': Accuracy(task='binary'),
 }, prefix=s.replace('metrics', '')) for s in stages})
 
 # Misc settings
@@ -314,8 +314,8 @@ class PROTAC_Model(pl.LightningModule):
 'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(
 optimizer=optimizer,
 mode='min',
- factor=0.5,
- patience=2,
+ factor=0.1,
+ patience=0,
 ),
 'interval': 'step', # or 'epoch'
 'frequency': 1,
@@ -508,6 +508,7 @@ def train_model(
 logger=loggers if use_logger else False,
 callbacks=callbacks,
 max_epochs=max_epochs,
+ val_check_interval=0.5,
 fast_dev_run=fast_dev_run,
 enable_model_summary=False,
 enable_checkpointing=enable_checkpointing,
@@ -534,11 +535,22 @@ def train_model(
 with warnings.catch_warnings():
 warnings.simplefilter("ignore")
 trainer.fit(model)
- metrics = trainer.validate(model, verbose=False)[0]
+ metrics = {}
+ # Add train metrics
+ train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+ metrics.update(train_metrics)
+ # Add validation metrics
+ val_metrics = trainer.validate(model, verbose=False)[0]
+ val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
+ metrics.update(val_metrics)
+
 # Add test metrics to metrics
 if test_df is not None:
 test_metrics = trainer.test(model, verbose=False)[0]
+ test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
 metrics.update(test_metrics)
+
+ # Return predictions
 if return_predictions:
 val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
  val_pred = trainer.predict(model, val_dl)
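
For reference, the train_model change above replaces the single validate() call with a merged dictionary of per-stage metrics: training metrics are read back from trainer.callback_metrics after fit(), while validation and test metrics come from validate()/test() and are filtered by their prefixes. A condensed sketch of that pattern (a hypothetical helper, not the repository's actual function) is shown below.

# Condensed sketch of the metric-collection pattern added above.
# Assumes a LightningModule that logs metrics prefixed with 'train_', 'val_' and 'test_'.
import pytorch_lightning as pl

def collect_stage_metrics(trainer: pl.Trainer, model: pl.LightningModule, run_test: bool = False) -> dict:
    metrics = {}
    # Training metrics are cached by fit() in trainer.callback_metrics as tensors
    metrics.update({m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m})
    # Validation metrics come from an explicit validate() pass, keeping only 'val_*' keys
    metrics.update({m: v for m, v in trainer.validate(model, verbose=False)[0].items() if 'val' in m})
    if run_test:
        # Optional test pass, keeping only 'test_*' keys
        metrics.update({m: v for m, v in trainer.test(model, verbose=False)[0].items() if 'test' in m})
    return metrics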
src/plot_experiment_results.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
 palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
 
 
- def plot_metrics(df, title):
+ def plot_training_curves(df, split_type):
 # Clean the data
 df = df.dropna(how='all', axis=1)
 
@@ -37,6 +37,10 @@ def plot_metrics(df, title):
 ax[1].set_ylabel('Accuracy')
 ax[1].legend(loc='lower right')
 ax[1].grid(axis='both', alpha=0.5)
+ # Set limit to y-axis
+ ax[1].set_ylim(0, 1.0)
+ # Set y-axis to percentage
+ ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
 
 # Plot training ROC-AUC
 ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
@@ -44,16 +48,16 @@ def plot_metrics(df, title):
 ax[2].set_ylabel('ROC-AUC')
 ax[2].legend(loc='lower right')
 ax[2].grid(axis='both', alpha=0.5)
-
+ # Set limit to y-axis
+ ax[2].set_ylim(0, 1.0)
 # Set x-axis label
 ax[2].set_xlabel('Epoch')
 
- plt.title(title)
 plt.tight_layout()
- plt.savefig(f'plots/{title}_metrics.pdf', bbox_inches='tight')
+ plt.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
 
 
- def plot_report(df_cv, df_test, title=None):
+ def plot_performance_metrics(df_cv, df_test, title=None):
 
 # Extract and prepare CV data
 cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
@@ -114,7 +118,13 @@ def plot_report(df_cv, df_test, title=None):
 
 # Plotting
 plt.figure(figsize=(12, 6))
- sns.barplot(data=combined_data, x='Metric', y='Score', hue='Split Type', errorbar='sd', palette=palette)
+ sns.barplot(
+ data=combined_data,
+ x='Metric',
+ y='Score',
+ hue='Split Type',
+ errorbar=('sd', 1),
+ palette=palette)
 plt.title('')
 plt.ylabel('')
 plt.xlabel('')
@@ -134,9 +144,9 @@ def plot_report(df_cv, df_test, title=None):
 if p.get_height() < 0.01:
 continue
 if i % 2 == 0:
- value = '{:.1f}%'.format(100 * p.get_height())
+ value = f'{p.get_height():.1%}'
 else:
- value = '{:.2f}'.format(p.get_height())
+ value = f'{p.get_height():.3f}'
 
 print(f'Plotting value: {p.get_height()} -> {value}')
 x = p.get_x() + p.get_width() / 2
@@ -146,6 +156,120 @@ def plot_report(df_cv, df_test, title=None):
 plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
 
 
+ def plot_ablation_study(report):
+ # Define the ablation study combinations
+ ablation_study_combinations = [
+ 'disabled smiles',
+ 'disabled poi',
+ 'disabled e3',
+ 'disabled cell',
+ 'disabled poi e3 smiles',
+ 'disabled poi e3 cell',
+ ]
+
+ for group in report['split_type'].unique():
+ baseline = report[report['disabled_embeddings'].isna()].copy()
+ baseline = baseline[baseline['split_type'] == group]
+ baseline['disabled_embeddings'] = 'all embeddings enabled'
+ # metrics_to_show = ['val_acc', 'test_acc']
+ metrics_to_show = ['test_acc']
+ # baseline = baseline.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+ baseline = baseline.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+
+ print(f'Group: {group}, avg: {(0.755814 + 0.720930 + 0.732558) / 3:.1%}')
+ print(f'Group: {group}, avg: {(0.7558139562606812 + 0.7209302186965942 + 0.7325581312179565) / 3:.1%}')
+ print(baseline)
+
+ ablation_dfs = []
+ for disabled_embeddings in ablation_study_combinations:
+ if pd.isnull(disabled_embeddings):
+ continue
+ tmp = report[report['disabled_embeddings'] == disabled_embeddings].copy()
+ tmp = tmp[tmp['split_type'] == group]
+ # tmp = tmp.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+ tmp = tmp.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+ ablation_dfs.append(tmp)
+ ablation_df = pd.concat(ablation_dfs)
+
+ # dummy_val_df = pd.DataFrame()
+ # tmp = report[report['split_type'] == group]
+ # dummy_val_df['score'] = tmp[['val_active_perc', 'val_inactive_perc']].max(axis=1)
+ # dummy_val_df['metric'] = 'val_acc'
+ # dummy_val_df['disabled_embeddings'] = 'dummy'
+
+ dummy_test_df = pd.DataFrame()
+ tmp = report[report['split_type'] == group]
+ dummy_test_df['score'] = tmp[['test_active_perc', 'test_inactive_perc']].max(axis=1)
+ dummy_test_df['metric'] = 'test_acc'
+ dummy_test_df['disabled_embeddings'] = 'dummy'
+
+ # dummy_df = pd.concat([dummy_val_df, dummy_test_df])
+ dummy_df = dummy_test_df
+
+ final_df = pd.concat([dummy_df, baseline, ablation_df])
+
+ final_df['metric'] = final_df['metric'].map({
+ 'val_acc': 'Validation Accuracy',
+ 'test_acc': 'Test Accuracy',
+ 'val_roc_auc': 'Val ROC-AUC',
+ 'test_roc_auc': 'Test ROC-AUC',
+ })
+
+ final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map({
+ 'all embeddings enabled': 'All embeddings enabled',
+ 'dummy': 'Dummy model',
+ 'disabled smiles': 'Disabled compound information',
+ 'disabled e3': 'Disabled E3 information',
+ 'disabled poi': 'Disabled target information',
+ 'disabled cell': 'Disabled cell information',
+ 'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
+ 'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
+ })
+
+ # Print final_df to latex
+ tmp = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3)
+ # Remove fold column to tmp
+ tmp = tmp.reset_index() #.drop('fold', axis=1)
+
+ # fig, ax = plt.subplots(figsize=(5, 5))
+ fig, ax = plt.subplots()
+
+ sns.barplot(data=final_df,
+ y='disabled_embeddings',
+ x='score',
+ hue='metric',
+ ax=ax,
+ errorbar=('sd', 1),
+ palette=sns.color_palette(palette, len(palette)),
+ saturation=1,
+ )
+
+ # ax.set_title(f'{group.replace("random", "standard")} CV split')
+ ax.grid(axis='x', alpha=0.5)
+ ax.tick_params(axis='y', rotation=0)
+ ax.set_xlim(0, 1.0)
+ ax.xaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
+ ax.set_ylabel('')
+ ax.set_xlabel('')
+ # Set the legend outside the plot and below
+ # ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=2)
+ # Set the legend in the upper right corner
+ ax.legend(loc='upper right')
+
+ # For each bar, add the rotated value (as percentage), inside the bar
+ for i, p in enumerate(plt.gca().patches):
+ # TODO: For some reasons, there is an additional bar being added at
+ # the end of the plot... it's not in the dataframe
+ if i == len(plt.gca().patches) - 1:
+ continue
+ value = '{:.1f}%'.format(100 * p.get_width())
+ y = p.get_y() + p.get_height() / 2
+ x = 0.4 # p.get_height() - p.get_height() / 2
+ plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, alpha=0.8)
+
+ plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
+
+
 def main():
 active_col = 'Active (Dmax 0.6, pDC50 6.0)'
 test_split = 0.1
@@ -156,28 +280,50 @@ def main():
 
 # Load the data
 reports = {
- 'cv_train': pd.read_csv(f'reports/report_cv_train_{report_base_name}.csv'),
- 'test': pd.read_csv(f'reports/report_test_{report_base_name}.csv'),
- 'ablation': pd.read_csv(f'reports/report_ablation_{report_base_name}.csv'),
- 'hparam': pd.read_csv(f'reports/report_hparam_{report_base_name}.csv'),
+ 'cv_train': pd.concat([
+ pd.read_csv(f'reports/cv_report_{report_base_name}_random.csv'),
+ pd.read_csv(f'reports/cv_report_{report_base_name}_uniprot.csv'),
+ pd.read_csv(f'reports/cv_report_{report_base_name}_tanimoto.csv'),
+ ]),
+ 'test': pd.concat([
+ pd.read_csv(f'reports/test_report_{report_base_name}_random.csv'),
+ pd.read_csv(f'reports/test_report_{report_base_name}_uniprot.csv'),
+ pd.read_csv(f'reports/test_report_{report_base_name}_tanimoto.csv'),
+ ]),
+ 'ablation': pd.concat([
+ pd.read_csv(f'reports/ablation_report_{report_base_name}_random.csv'),
+ pd.read_csv(f'reports/ablation_report_{report_base_name}_uniprot.csv'),
+ pd.read_csv(f'reports/ablation_report_{report_base_name}_tanimoto.csv'),
+ ]),
+ 'hparam': pd.concat([
+ pd.read_csv(f'reports/hparam_report_{report_base_name}_random.csv'),
+ pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
+ pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
+ ]),
 }
 
 
- # metrics = {}
- # for i in range(n_models_for_test):
- # for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
- # logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
- # metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
- # metrics[f'{split_type}_{i}']['model_id'] = i
- # # Rename 'val_' columns to 'test_' columns
- # metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
+ metrics = {}
+ for i in range(n_models_for_test):
+ for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
+ logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
+ metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
+ metrics[f'{split_type}_{i}']['model_id'] = i
+ # Rename 'val_' columns to 'test_' columns
+ metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
 
- # plot_metrics(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
+ plot_training_curves(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
 
 
 df_val = reports['cv_train']
 df_test = reports['test']
- plot_report(df_val, df_test, title=f'{active_name}_metrics')
+ plot_performance_metrics(df_val, df_test, title=f'{active_name}_metrics')
+
+ reports['test']['disabled_embeddings'] = pd.NA
+ plot_ablation_study(pd.concat([
+ reports['ablation'],
+ reports['test'],
+ ]))
 
 
 if __name__ == '__main__':