ribesstefano committed
Commit fda7af7 • 1 Parent(s): 62ccb16
Added FP search + Added train metrics to logs
notebooks/best_fingerprint_search.ipynb
ADDED
@@ -0,0 +1,275 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "from collections import defaultdict\n",
+    "import warnings\n",
+    "import logging\n",
+    "from typing import Literal\n",
+    "\n",
+    "sys.path.append('~/PROTAC-Degradation-Predictor/protac_degradation_predictor')\n",
+    "import protac_degradation_predictor as pdp\n",
+    "\n",
+    "import pytorch_lightning as pl\n",
+    "from rdkit import Chem\n",
+    "from rdkit.Chem import AllChem\n",
+    "from rdkit import DataStructs\n",
+    "from jsonargparse import CLI\n",
+    "import pandas as pd\n",
+    "# Import tqdm for notebook\n",
+    "from tqdm.notebook import tqdm\n",
+    "import numpy as np\n",
+    "from sklearn.preprocessing import OrdinalEncoder\n",
+    "from sklearn.model_selection import (\n",
+    "    StratifiedKFold,\n",
+    "    StratifiedGroupKFold,\n",
+    ")\n",
+    "\n",
+    "\n",
+    "active_col = 'Active (Dmax 0.6, pDC50 6.0)'\n",
+    "pDC50_threshold = 6.0\n",
+    "Dmax_threshold = 0.6\n",
+    "\n",
+    "protac_df = pd.read_csv('~/PROTAC-Degradation-Predictor/data/PROTAC-Degradation-DB.csv')\n",
+    "protac_df['E3 Ligase'] = protac_df['E3 Ligase'].str.replace('Iap', 'IAP')\n",
+    "protac_df[active_col] = protac_df.apply(\n",
+    "    lambda x: pdp.is_active(x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold), axis=1\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "771"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:\n",
+    "    \"\"\" Get the indices of the test set using a random split.\n",
+    "    \n",
+    "    Args:\n",
+    "        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.\n",
+    "        test_split (float): The percentage of the active PROTACs to use as the test set.\n",
+    "    \n",
+    "    Returns:\n",
+    "        pd.Index: The indices of the test set.\n",
+    "    \"\"\"\n",
+    "    test_df = active_df.sample(frac=test_split, random_state=42)\n",
+    "    return test_df.index\n",
+    "\n",
+    "active_df = protac_df[protac_df[active_col].notna()].copy()\n",
+    "test_split = 0.1\n",
+    "test_indices = get_random_split_indices(active_df, test_split)\n",
+    "train_val_df = active_df[~active_df.index.isin(test_indices)].copy()\n",
+    "len(train_val_df)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import optuna\n",
+    "\n",
+    "def objective(trial: optuna.Trial, verbose: int = 0) -> float:\n",
+    "    \n",
+    "    radius = trial.suggest_int('radius', 1, 15)\n",
+    "    fpsize = trial.suggest_int('fpsize', 128, 2048, step=128)\n",
+    "\n",
+    "    morgan_fpgen = AllChem.GetMorganGenerator(\n",
+    "        radius=radius,\n",
+    "        fpSize=fpsize,\n",
+    "        includeChirality=True,\n",
+    "    )\n",
+    "\n",
+    "    smiles2fp = {}\n",
+    "    for smiles in train_val_df['Smiles'].unique().tolist():\n",
+    "        smiles2fp[smiles] = pdp.get_fingerprint(smiles, morgan_fpgen)\n",
+    "\n",
+    "    # Count the number of unique SMILES and the number of unique Morgan fingerprints\n",
+    "    unique_fps = set([tuple(fp) for fp in smiles2fp.values()])\n",
+    "    # Get the list of SMILES with overlapping fingerprints\n",
+    "    overlapping_smiles = []\n",
+    "    unique_fps = set()\n",
+    "    for smiles, fp in smiles2fp.items():\n",
+    "        if tuple(fp) in unique_fps:\n",
+    "            overlapping_smiles.append(smiles)\n",
+    "        else:\n",
+    "            unique_fps.add(tuple(fp))\n",
+    "    num_overlaps = len(train_val_df[train_val_df[\"Smiles\"].isin(overlapping_smiles)])\n",
+    "    num_overlaps_tot = len(protac_df[protac_df[\"Smiles\"].isin(overlapping_smiles)])\n",
+    "\n",
+    "    if verbose:\n",
+    "        print(f'Radius: {radius}')\n",
+    "        print(f'FP length: {fpsize}')\n",
+    "        print(f'Number of unique SMILES: {len(smiles2fp)}')\n",
+    "        print(f'Number of unique fingerprints: {len(unique_fps)}')\n",
+    "        print(f'Number of SMILES with overlapping fingerprints: {len(overlapping_smiles)}')\n",
+    "        print(f'Number of overlapping SMILES in train_val_df: {num_overlaps}')\n",
+    "        print(f'Number of overlapping SMILES in protac_df: {num_overlaps_tot}')\n",
+    "    return num_overlaps + radius + fpsize / 100"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[I 2024-04-29 11:28:05,626] A new study created in memory with name: no-name-4db5d822-6220-4ab8-bc3a-c776b0e5cac2\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "678150f59ec548bb89562e2230993989",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/50 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[I 2024-04-29 11:28:07,705] Trial 0 finished with value: 39.480000000000004 and parameters: {'radius': 6, 'fpsize': 2048}. Best is trial 0 with value: 39.480000000000004.\n",
+      "[I 2024-04-29 11:28:09,590] Trial 1 finished with value: 23.8 and parameters: {'radius': 11, 'fpsize': 1280}. Best is trial 1 with value: 23.8.\n",
+      "[I 2024-04-29 11:28:10,474] Trial 2 finished with value: 131.84 and parameters: {'radius': 3, 'fpsize': 384}. Best is trial 1 with value: 23.8.\n",
+      "[I 2024-04-29 11:28:11,978] Trial 3 finished with value: 281.92 and parameters: {'radius': 1, 'fpsize': 1792}. Best is trial 1 with value: 23.8.\n",
+      "[I 2024-04-29 11:28:13,994] Trial 4 finished with value: 25.36 and parameters: {'radius': 10, 'fpsize': 1536}. Best is trial 1 with value: 23.8.\n",
+      "[I 2024-04-29 11:28:15,642] Trial 5 finished with value: 284.48 and parameters: {'radius': 1, 'fpsize': 2048}. Best is trial 1 with value: 23.8.\n",
+      "[I 2024-04-29 11:28:17,154] Trial 6 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:18,057] Trial 7 finished with value: 131.84 and parameters: {'radius': 3, 'fpsize': 384}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:19,570] Trial 8 finished with value: 41.519999999999996 and parameters: {'radius': 5, 'fpsize': 1152}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:20,860] Trial 9 finished with value: 23.4 and parameters: {'radius': 7, 'fpsize': 640}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:22,631] Trial 10 finished with value: 22.68 and parameters: {'radius': 15, 'fpsize': 768}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:24,427] Trial 11 finished with value: 22.68 and parameters: {'radius': 15, 'fpsize': 768}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:25,756] Trial 12 finished with value: 92.28 and parameters: {'radius': 15, 'fpsize': 128}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:27,466] Trial 13 finished with value: 20.96 and parameters: {'radius': 12, 'fpsize': 896}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:29,156] Trial 14 finished with value: 20.96 and parameters: {'radius': 12, 'fpsize': 896}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:30,727] Trial 15 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:31,842] Trial 16 finished with value: 22.28 and parameters: {'radius': 9, 'fpsize': 128}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:33,365] Trial 17 finished with value: 18.12 and parameters: {'radius': 13, 'fpsize': 512}. Best is trial 6 with value: 18.12.\n",
+      "[I 2024-04-29 11:28:34,801] Trial 18 finished with value: 16.84 and parameters: {'radius': 13, 'fpsize': 384}. Best is trial 18 with value: 16.84.\n",
+      "[I 2024-04-29 11:28:35,986] Trial 19 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+      "[I 2024-04-29 11:28:37,122] Trial 20 finished with value: 14.56 and parameters: {'radius': 8, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+      "[I 2024-04-29 11:28:38,175] Trial 21 finished with value: 30.28 and parameters: {'radius': 8, 'fpsize': 128}. Best is trial 19 with value: 13.56.\n",
+      "[I 2024-04-29 11:28:39,406] Trial 22 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+      "[I 2024-04-29 11:28:40,649] Trial 23 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 19 with value: 13.56.\n",
+      "[I 2024-04-29 11:28:41,868] Trial 24 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:43,109] Trial 25 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:44,587] Trial 26 finished with value: 16.4 and parameters: {'radius': 10, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:46,599] Trial 27 finished with value: 25.08 and parameters: {'radius': 11, 'fpsize': 1408}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:48,015] Trial 28 finished with value: 31.96 and parameters: {'radius': 6, 'fpsize': 896}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:49,347] Trial 29 finished with value: 23.4 and parameters: {'radius': 7, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:51,503] Trial 30 finished with value: 27.64 and parameters: {'radius': 11, 'fpsize': 1664}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:52,657] Trial 31 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:53,840] Trial 32 finished with value: 12.56 and parameters: {'radius': 10, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:55,159] Trial 33 finished with value: 13.84 and parameters: {'radius': 10, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:56,140] Trial 34 finished with value: 39.28 and parameters: {'radius': 7, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:57,508] Trial 35 finished with value: 14.84 and parameters: {'radius': 11, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:28:58,900] Trial 36 finished with value: 15.120000000000001 and parameters: {'radius': 10, 'fpsize': 512}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:00,203] Trial 37 finished with value: 14.56 and parameters: {'radius': 12, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:02,225] Trial 38 finished with value: 49.2 and parameters: {'radius': 5, 'fpsize': 1920}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:03,942] Trial 39 finished with value: 22.52 and parameters: {'radius': 8, 'fpsize': 1152}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:05,240] Trial 40 finished with value: 13.84 and parameters: {'radius': 10, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:06,396] Trial 41 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:07,422] Trial 42 finished with value: 30.28 and parameters: {'radius': 8, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:08,590] Trial 43 finished with value: 13.56 and parameters: {'radius': 9, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:09,949] Trial 44 finished with value: 14.84 and parameters: {'radius': 11, 'fpsize': 384}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:11,378] Trial 45 finished with value: 15.120000000000001 and parameters: {'radius': 10, 'fpsize': 512}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:12,637] Trial 46 finished with value: 26.4 and parameters: {'radius': 6, 'fpsize': 640}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:14,232] Trial 47 finished with value: 18.68 and parameters: {'radius': 11, 'fpsize': 768}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:14,904] Trial 48 finished with value: 214.28 and parameters: {'radius': 2, 'fpsize': 128}. Best is trial 24 with value: 12.56.\n",
+      "[I 2024-04-29 11:29:16,323] Trial 49 finished with value: 16.56 and parameters: {'radius': 14, 'fpsize': 256}. Best is trial 24 with value: 12.56.\n"
+     ]
+    }
+   ],
+   "source": [
+    "sampler = optuna.samplers.TPESampler(seed=42)\n",
+    "study = optuna.create_study(sampler=sampler, direction='minimize')\n",
+    "study.optimize(objective, n_trials=50, show_progress_bar=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Radius: 10\n",
+      "FP length: 256\n",
+      "Number of unique SMILES: 532\n",
+      "Number of unique fingerprints: 532\n",
+      "Number of SMILES with overlapping fingerprints: 0\n",
+      "Number of overlapping SMILES in train_val_df: 0\n",
+      "Number of overlapping SMILES in protac_df: 0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "12.56"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Run objective with best params and verbose\n",
+    "objective(study.best_trial, verbose=1)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
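For context, the objective above rewards radius/size pairs that keep every distinct SMILES mapped to a distinct Morgan fingerprint, with a small tie-breaking penalty on radius and size. Below is a minimal standalone sketch of that collision check, assuming RDKit's fingerprint-generator API; count_fingerprint_collisions and smiles_list are illustrative names, not identifiers from the repository (the notebook itself goes through pdp.get_fingerprint).

# Minimal sketch of the collision check the objective minimizes (assumes a
# recent RDKit with rdFingerprintGenerator; names here are illustrative).
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

def count_fingerprint_collisions(smiles_list, radius, fpsize):
    gen = rdFingerprintGenerator.GetMorganGenerator(
        radius=radius, fpSize=fpsize, includeChirality=True)
    seen, colliding = set(), []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # skip unparsable SMILES
        key = gen.GetFingerprint(mol).ToBitString()  # hashable bit pattern
        if key in seen:
            colliding.append(smi)  # same bits as an earlier, different molecule
        else:
            seen.add(key)
    return colliding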
protac_degradation_predictor/config.py
CHANGED
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
 @dataclass(frozen=True)
 class Config:
     # Embeddings information
-    morgan_radius: int = 15
+    morgan_radius: int = 10 # 15
     fingerprint_size: int = 256 # 224
     protein_embedding_size: int = 1024
     cell_embedding_size: int = 768
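The new default mirrors the best trial from the fingerprint search above (radius 10 with 256 bits gives zero fingerprint collisions over the 532 unique SMILES). Since Config is a frozen dataclass, these defaults are read-only at runtime; a small sketch of that behavior, showing only the two fields touched here:

# Frozen-dataclass behavior of Config (field subset only, for illustration)
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class Config:
    morgan_radius: int = 10  # 15
    fingerprint_size: int = 256  # 224

cfg = Config()
try:
    cfg.morgan_radius = 15  # assignment on a frozen dataclass raises
except FrozenInstanceError:
    print('Config is immutable; construct a new instance to change defaults')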
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -173,13 +173,11 @@ def pytorch_model_objective(
             disabled_embeddings=disabled_embeddings,
         )
         if test_df is not None:
-            _,
+            _, _, metrics, val_pred, test_pred = ret
            test_preds.append(test_pred)
         else:
-            _,
-            train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+            _, _, metrics, val_pred = ret
         stats.update(metrics)
-        stats.update(train_metrics)
         report.append(stats.copy())
         val_preds.append(val_pred)
 
@@ -252,7 +250,7 @@ def hyperparameter_tuning_and_training(
     batch_size_options = [4, 8, 16, 32, 64, 128]
     learning_rate_options = (1e-5, 1e-3) # min and max values for loguniform distribution
     smote_k_neighbors_options = list(range(3, 16))
-    dropout_options = (0.
+    dropout_options = (0.2, 0.9)
 
     # Set the verbosity of Optuna
     optuna.logging.set_verbosity(optuna.logging.WARNING)
 
@@ -325,13 +323,6 @@ def hyperparameter_tuning_and_training(
         metrics['test_model_id'] = i
         metrics.update(dfs_stats)
 
-        # Add the training metrics
-        train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
-        logging.info(f'Training metrics: {train_metrics}')
-        logging.info(f'Training trainer.logged_metrics: {trainer.logged_metrics}')
-        logging.info(f'Training trainer.callback_metrics: {trainer.callback_metrics}')
-
-        metrics.update(train_metrics)
         test_report.append(metrics.copy())
         test_preds.append(test_pred)
     test_report = pd.DataFrame(test_report)
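For orientation, the option lists and ranges above define the hyperparameter search space. A hedged sketch of how such options typically map onto Optuna suggest calls; the actual wiring lives elsewhere in optuna_utils.py, and suggest_hparams is an illustrative name:

# Illustrative mapping from option lists/ranges to Optuna suggest calls
import optuna

def suggest_hparams(trial: optuna.Trial) -> dict:
    batch_size_options = [4, 8, 16, 32, 64, 128]
    learning_rate_options = (1e-5, 1e-3)  # loguniform range
    smote_k_neighbors_options = list(range(3, 16))
    dropout_options = (0.2, 0.9)
    return {
        'batch_size': trial.suggest_categorical('batch_size', batch_size_options),
        'learning_rate': trial.suggest_float('learning_rate', *learning_rate_options, log=True),
        'smote_k_neighbors': trial.suggest_categorical('smote_k_neighbors', smote_k_neighbors_options),
        'dropout': trial.suggest_float('dropout', *dropout_options),
    }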
protac_degradation_predictor/pytorch_models.py
CHANGED
@@ -207,8 +207,8 @@ class PROTAC_Model(pl.LightningModule):
             'precision': Precision(task='binary'),
             'recall': Recall(task='binary'),
             'f1_score': F1Score(task='binary'),
-            'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
-            'hp_metric': Accuracy(task='binary'),
+            # 'opt_score': Accuracy(task='binary') + F1Score(task='binary'),
+            # 'hp_metric': Accuracy(task='binary'),
         }, prefix=s.replace('metrics', '')) for s in stages})
 
         # Misc settings
 
@@ -314,8 +314,8 @@ class PROTAC_Model(pl.LightningModule):
             'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(
                 optimizer=optimizer,
                 mode='min',
-                factor=0.
-                patience=
+                factor=0.1,
+                patience=0,
             ),
             'interval': 'step', # or 'epoch'
             'frequency': 1,
 
@@ -508,6 +508,7 @@ def train_model(
         logger=loggers if use_logger else False,
         callbacks=callbacks,
         max_epochs=max_epochs,
+        val_check_interval=0.5,
         fast_dev_run=fast_dev_run,
         enable_model_summary=False,
         enable_checkpointing=enable_checkpointing,
 
@@ -534,11 +535,22 @@ def train_model(
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         trainer.fit(model)
-    metrics =
+    metrics = {}
+    # Add train metrics
+    train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
+    metrics.update(train_metrics)
+    # Add validation metrics
+    val_metrics = trainer.validate(model, verbose=False)[0]
+    val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
+    metrics.update(val_metrics)
+
     # Add test metrics to metrics
     if test_df is not None:
         test_metrics = trainer.test(model, verbose=False)[0]
+        test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
         metrics.update(test_metrics)
+
+    # Return predictions
     if return_predictions:
         val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
         val_pred = trainer.predict(model, val_dl)
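The train_model changes above assemble one flat metrics dict from three sources. A condensed sketch of the pattern, assuming an already-fitted pytorch_lightning.Trainer; collect_metrics is an illustrative name, while validate()/test() are the Lightning API and each return a list with one metrics dict per dataloader:

# Condensed sketch of the metrics-collection pattern introduced in the diff
def collect_metrics(trainer, model, has_test_set: bool) -> dict:
    metrics = {}
    # Train metrics: filter the last logged values and unwrap the tensors
    metrics.update({m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m})
    # Validation metrics: re-run the validation loop quietly
    val_metrics = trainer.validate(model, verbose=False)[0]
    metrics.update({m: v for m, v in val_metrics.items() if 'val' in m})
    # Test metrics, only when a held-out test set exists
    if has_test_set:
        test_metrics = trainer.test(model, verbose=False)[0]
        metrics.update({m: v for m, v in test_metrics.items() if 'test' in m})
    return metrics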
src/plot_experiment_results.py
CHANGED
@@ -12,7 +12,7 @@ import numpy as np
 palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
 
 
-def
+def plot_training_curves(df, split_type):
     # Clean the data
     df = df.dropna(how='all', axis=1)
 
@@ -37,6 +37,10 @@ def plot_metrics(df, title):
     ax[1].set_ylabel('Accuracy')
     ax[1].legend(loc='lower right')
     ax[1].grid(axis='both', alpha=0.5)
+    # Set limit to y-axis
+    ax[1].set_ylim(0, 1.0)
+    # Set y-axis to percentage
+    ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
 
     # Plot training ROC-AUC
     ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
 
@@ -44,16 +48,16 @@ def plot_metrics(df, title):
     ax[2].set_ylabel('ROC-AUC')
     ax[2].legend(loc='lower right')
     ax[2].grid(axis='both', alpha=0.5)
-
+    # Set limit to y-axis
+    ax[2].set_ylim(0, 1.0)
     # Set x-axis label
     ax[2].set_xlabel('Epoch')
 
-    plt.title(title)
     plt.tight_layout()
-    plt.savefig(f'plots/{
+    plt.savefig(f'plots/training_metrics_{split_type}.pdf', bbox_inches='tight')
 
 
-def
+def plot_performance_metrics(df_cv, df_test, title=None):
 
     # Extract and prepare CV data
     cv_data = df_cv[['model_type', 'fold', 'val_acc', 'val_roc_auc', 'test_acc', 'test_roc_auc', 'split_type']]
 
@@ -114,7 +118,13 @@ def plot_report(df_cv, df_test, title=None):
 
     # Plotting
     plt.figure(figsize=(12, 6))
-    sns.barplot(
+    sns.barplot(
+        data=combined_data,
+        x='Metric',
+        y='Score',
+        hue='Split Type',
+        errorbar=('sd', 1),
+        palette=palette)
     plt.title('')
     plt.ylabel('')
     plt.xlabel('')
 
@@ -134,9 +144,9 @@ def plot_report(df_cv, df_test, title=None):
         if p.get_height() < 0.01:
             continue
         if i % 2 == 0:
-            value = '{
+            value = f'{p.get_height():.1%}'
         else:
-            value = '{
+            value = f'{p.get_height():.3f}'
 
         print(f'Plotting value: {p.get_height()} -> {value}')
         x = p.get_x() + p.get_width() / 2
 
@@ -146,6 +156,120 @@ def plot_report(df_cv, df_test, title=None):
     plt.savefig(f'plots/{title}.pdf', bbox_inches='tight')
 
 
+def plot_ablation_study(report):
+    # Define the ablation study combinations
+    ablation_study_combinations = [
+        'disabled smiles',
+        'disabled poi',
+        'disabled e3',
+        'disabled cell',
+        'disabled poi e3 smiles',
+        'disabled poi e3 cell',
+    ]
+
+    for group in report['split_type'].unique():
+        baseline = report[report['disabled_embeddings'].isna()].copy()
+        baseline = baseline[baseline['split_type'] == group]
+        baseline['disabled_embeddings'] = 'all embeddings enabled'
+        # metrics_to_show = ['val_acc', 'test_acc']
+        metrics_to_show = ['test_acc']
+        # baseline = baseline.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+        baseline = baseline.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+
+        print(f'Group: {group}, avg: {(0.755814 + 0.720930 + 0.732558) / 3:.1%}')
+        print(f'Group: {group}, avg: {(0.7558139562606812 + 0.7209302186965942 + 0.7325581312179565) / 3:.1%}')
+        print(baseline)
+
+        ablation_dfs = []
+        for disabled_embeddings in ablation_study_combinations:
+            if pd.isnull(disabled_embeddings):
+                continue
+            tmp = report[report['disabled_embeddings'] == disabled_embeddings].copy()
+            tmp = tmp[tmp['split_type'] == group]
+            # tmp = tmp.melt(id_vars=['fold', 'disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+            tmp = tmp.melt(id_vars=['disabled_embeddings'], value_vars=metrics_to_show, var_name='metric', value_name='score')
+            ablation_dfs.append(tmp)
+        ablation_df = pd.concat(ablation_dfs)
+
+        # dummy_val_df = pd.DataFrame()
+        # tmp = report[report['split_type'] == group]
+        # dummy_val_df['score'] = tmp[['val_active_perc', 'val_inactive_perc']].max(axis=1)
+        # dummy_val_df['metric'] = 'val_acc'
+        # dummy_val_df['disabled_embeddings'] = 'dummy'
+
+        dummy_test_df = pd.DataFrame()
+        tmp = report[report['split_type'] == group]
+        dummy_test_df['score'] = tmp[['test_active_perc', 'test_inactive_perc']].max(axis=1)
+        dummy_test_df['metric'] = 'test_acc'
+        dummy_test_df['disabled_embeddings'] = 'dummy'
+
+        # dummy_df = pd.concat([dummy_val_df, dummy_test_df])
+        dummy_df = dummy_test_df
+
+        final_df = pd.concat([dummy_df, baseline, ablation_df])
+
+        final_df['metric'] = final_df['metric'].map({
+            'val_acc': 'Validation Accuracy',
+            'test_acc': 'Test Accuracy',
+            'val_roc_auc': 'Val ROC-AUC',
+            'test_roc_auc': 'Test ROC-AUC',
+        })
+
+        final_df['disabled_embeddings'] = final_df['disabled_embeddings'].map({
+            'all embeddings enabled': 'All embeddings enabled',
+            'dummy': 'Dummy model',
+            'disabled smiles': 'Disabled compound information',
+            'disabled e3': 'Disabled E3 information',
+            'disabled poi': 'Disabled target information',
+            'disabled cell': 'Disabled cell information',
+            'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
+            'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
+        })
+
+        # Print final_df to latex
+        tmp = final_df.groupby(['disabled_embeddings', 'metric']).mean().round(3)
+        # Remove fold column to tmp
+        tmp = tmp.reset_index() #.drop('fold', axis=1)
+
+        # fig, ax = plt.subplots(figsize=(5, 5))
+        fig, ax = plt.subplots()
+
+        sns.barplot(data=final_df,
+            y='disabled_embeddings',
+            x='score',
+            hue='metric',
+            ax=ax,
+            errorbar=('sd', 1),
+            palette=sns.color_palette(palette, len(palette)),
+            saturation=1,
+        )
+
+        # ax.set_title(f'{group.replace("random", "standard")} CV split')
+        ax.grid(axis='x', alpha=0.5)
+        ax.tick_params(axis='y', rotation=0)
+        ax.set_xlim(0, 1.0)
+        ax.xaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
+        ax.set_ylabel('')
+        ax.set_xlabel('')
+        # Set the legend outside the plot and below
+        # ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.08), ncol=2)
+        # Set the legend in the upper right corner
+        ax.legend(loc='upper right')
+
+        # For each bar, add the rotated value (as percentage), inside the bar
+        for i, p in enumerate(plt.gca().patches):
+            # TODO: For some reasons, there is an additional bar being added at
+            # the end of the plot... it's not in the dataframe
+            if i == len(plt.gca().patches) - 1:
+                continue
+            value = '{:.1f}%'.format(100 * p.get_width())
+            y = p.get_y() + p.get_height() / 2
+            x = 0.4 # p.get_height() - p.get_height() / 2
+            plt.annotate(value, (x, y), ha='center', va='center', color='black', fontsize=10, alpha=0.8)
+
+        plt.savefig(f'plots/ablation_study_{group}.pdf', bbox_inches='tight')
+
+
 def main():
     active_col = 'Active (Dmax 0.6, pDC50 6.0)'
     test_split = 0.1
 
@@ -156,28 +280,50 @@ def main():
 
     # Load the data
     reports = {
-        'cv_train': pd.
-
-
-
+        'cv_train': pd.concat([
+            pd.read_csv(f'reports/cv_report_{report_base_name}_random.csv'),
+            pd.read_csv(f'reports/cv_report_{report_base_name}_uniprot.csv'),
+            pd.read_csv(f'reports/cv_report_{report_base_name}_tanimoto.csv'),
+        ]),
+        'test': pd.concat([
+            pd.read_csv(f'reports/test_report_{report_base_name}_random.csv'),
+            pd.read_csv(f'reports/test_report_{report_base_name}_uniprot.csv'),
+            pd.read_csv(f'reports/test_report_{report_base_name}_tanimoto.csv'),
+        ]),
+        'ablation': pd.concat([
+            pd.read_csv(f'reports/ablation_report_{report_base_name}_random.csv'),
+            pd.read_csv(f'reports/ablation_report_{report_base_name}_uniprot.csv'),
+            pd.read_csv(f'reports/ablation_report_{report_base_name}_tanimoto.csv'),
+        ]),
+        'hparam': pd.concat([
+            pd.read_csv(f'reports/hparam_report_{report_base_name}_random.csv'),
+            pd.read_csv(f'reports/hparam_report_{report_base_name}_uniprot.csv'),
+            pd.read_csv(f'reports/hparam_report_{report_base_name}_tanimoto.csv'),
+        ]),
     }
 
 
-
-
-
-
-
-
-
-
+    metrics = {}
+    for i in range(n_models_for_test):
+        for split_type in ['random', 'tanimoto', 'uniprot', 'e3_ligase']:
+            logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
+            metrics[f'{split_type}_{i}'] = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
+            metrics[f'{split_type}_{i}']['model_id'] = i
+            # Rename 'val_' columns to 'test_' columns
+            metrics[f'{split_type}_{i}'] = metrics[f'{split_type}_{i}'].rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
 
-
+            plot_training_curves(metrics[f'{split_type}_{i}'], f'{split_type}_{i}')
 
 
     df_val = reports['cv_train']
     df_test = reports['test']
-
+    plot_performance_metrics(df_val, df_test, title=f'{active_name}_metrics')
+
+    reports['test']['disabled_embeddings'] = pd.NA
+    plot_ablation_study(pd.concat([
+        reports['ablation'],
+        reports['test'],
+    ]))
 
 
 if __name__ == '__main__':
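A usage note on the PercentFormatter(1, decimals=0) calls that the new plotting functions apply: passing xmax=1 tells matplotlib that the data are fractions, so an accuracy of 0.76 is rendered on the axis as 76%. A self-contained toy example, with data made up for illustration and not drawn from the repo's reports:

# Toy demonstration of PercentFormatter on a fraction-valued axis
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

fig, ax = plt.subplots()
ax.bar(['val_acc', 'test_acc'], [0.82, 0.76])  # illustrative values
ax.set_ylim(0, 1.0)
# xmax=1: treat values as fractions of 1, so 0.76 is labeled "76%"
ax.yaxis.set_major_formatter(PercentFormatter(1, decimals=0))
fig.savefig('percent_axis_demo.pdf', bbox_inches='tight')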