melvinalves commited on
Commit
de647c2
·
verified ·
1 Parent(s): 3eb6bf9

Delete notebooks

Browse files
notebooks/Ensemble.ipynb DELETED
@@ -1,292 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "0fbbb46c-1a00-4585-9ecd-a490a46e8b99",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
- "mlb com 597 GO terms guardado em data/mlb_597.pkl\n",
15
- " ProtBERT Fmax=0.6616 Thr=0.40 AuPRC=0.7009 Smin=13.9047\n",
16
- " ProtBERT-BFD Fmax=0.6573 Thr=0.41 AuPRC=0.6925 Smin=13.7060\n",
17
- " ESM-2 Fmax=0.6375 Thr=0.39 AuPRC=0.6875 Smin=14.1194\n",
18
- " Ensemble Fmax=0.6864 Thr=0.37 AuPRC=0.7332 Smin=12.7879\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "# %%\n",
24
- "import numpy as np, joblib, math\n",
25
- "from sklearn.metrics import precision_recall_curve, auc\n",
26
- "from goatools.obo_parser import GODag\n",
27
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
28
- "\n",
29
- "GO_FILE = \"go.obo\"\n",
30
- "dag = GODag(GO_FILE)\n",
31
- "\n",
32
- "# ---------- 1. y_true + GO terms (referência ProtBERT) ----------\n",
33
- "test_pb = joblib.load(\"embeddings/test_protbert.pkl\")\n",
34
- "y_true = test_pb[\"labels\"] # (1724, 597) ← ground-truth\n",
35
- "go_ref = list(test_pb[\"go_terms\"]) # ordem exacta das colunas\n",
36
- "\n",
37
- "n_go = len(go_ref) # 597\n",
38
- "\n",
39
- "# --- Recriar o MultiLabelBinarizer com os 597 termos corretos ---\n",
40
- "mlb = MultiLabelBinarizer(classes=go_ref)\n",
41
- "mlb.fit([go_ref]) # necessário para permitir inverse_transform depois\n",
42
- "\n",
43
- "# ---------- 2. Carregar predições ----------\n",
44
- "y_pb = np.load(\"predictions/mf-protbert-pam1.npy\") # 1724×597\n",
45
- "y_bfd = np.load(\"predictions/mf-protbertbfd-pam1.npy\") # 1724×597\n",
46
- "y_esm0 = np.load(\"predictions/mf-esm2.npy\") # 1724×602\n",
47
- "\n",
48
- "# ---------- 3. Remapear ESM-2 para ordem ProtBERT ----------\n",
49
- "mlb_esm = joblib.load(\"data/mlb.pkl\") # 602 GO terms\n",
50
- "idx_map = [list(mlb_esm.classes_).index(t) for t in go_ref]\n",
51
- "y_esm = y_esm0[:, idx_map] # 1724×597\n",
52
- "\n",
53
- "# ---------- 4. Garantir shapes iguais ----------\n",
54
- "assert (y_true.shape == y_pb.shape == y_bfd.shape\n",
55
- " == y_esm.shape == (1724, n_go)), \"Ainda há desalinhamento!\"\n",
56
- "\n",
57
- "# ---------- 4. Guardar mlb (y_true) alinhado ----------\n",
58
- "joblib.dump(mlb, \"data/mlb_597.pkl\")\n",
59
- "print(\"mlb com 597 GO terms guardado em data/mlb_597.pkl\")\n",
60
- "\n",
61
- "# ---------- 5. Métricas ----------\n",
62
- "THR = np.linspace(0,1,101)\n",
63
- "def fmax(y_t,y_p):\n",
64
- " best,thr = 0,0\n",
65
- " for t in THR:\n",
66
- " y_b = (y_p>=t).astype(int)\n",
67
- " tp = (y_t*y_b).sum(1); fp=((1-y_t)*y_b).sum(1); fn=(y_t*(1-y_b)).sum(1)\n",
68
- " f1 = 2*tp/(2*tp+fp+fn+1e-8); m=f1.mean()\n",
69
- " if m>best: best,thr = m,t\n",
70
- " return best,thr\n",
71
- "\n",
72
- "def auprc(y_t,y_p):\n",
73
- " p,r,_ = precision_recall_curve(y_t.ravel(), y_p.ravel()); return auc(r,p)\n",
74
- "\n",
75
- "def smin(y_t,y_p,thr,alpha=0.5):\n",
76
- " y_b=(y_p>=thr).astype(int)\n",
77
- " ic=-(np.log((y_t+y_b).sum(0)+1e-8)-np.log((y_t+y_b).sum()+1e-8))\n",
78
- " ru=np.logical_and(y_b, np.logical_not(y_t))*ic\n",
79
- " mi=np.logical_and(y_t, np.logical_not(y_b))*ic\n",
80
- " return np.sqrt((alpha*ru.sum(1))**2 + ((1-alpha)*mi.sum(1))**2).mean()\n",
81
- "\n",
82
- "def show(name,y_p):\n",
83
- " f,thr=fmax(y_true,y_p)\n",
84
- " print(f\"{name:>13s} Fmax={f:.4f} Thr={thr:.2f} \"\n",
85
- " f\"AuPRC={auprc(y_true,y_p):.4f} Smin={smin(y_true,y_p,thr):.4f}\")\n",
86
- "\n",
87
- "show(\"ProtBERT\", y_pb)\n",
88
- "show(\"ProtBERT-BFD\", y_bfd)\n",
89
- "show(\"ESM-2\", y_esm)\n",
90
- "show(\"Ensemble\", (y_pb + y_bfd + y_esm)/3)\n",
91
- "\n"
92
- ]
93
- },
94
- {
95
- "cell_type": "code",
96
- "execution_count": 9,
97
- "id": "f1807404-c2ce-48d0-b87c-a7e0fecc1728",
98
- "metadata": {},
99
- "outputs": [
100
- {
101
- "name": "stdout",
102
- "output_type": "stream",
103
- "text": [
104
- "Epoch 1/50\n",
105
- "19/19 [==============================] - 0s 17ms/step - loss: 0.3811 - val_loss: 0.0868\n",
106
- "Epoch 2/50\n",
107
- "19/19 [==============================] - 0s 8ms/step - loss: 0.0882 - val_loss: 0.0696\n",
108
- "Epoch 3/50\n",
109
- "19/19 [==============================] - 0s 6ms/step - loss: 0.0628 - val_loss: 0.0563\n",
110
- "Epoch 4/50\n",
111
- "19/19 [==============================] - 0s 6ms/step - loss: 0.0552 - val_loss: 0.0520\n",
112
- "Epoch 5/50\n",
113
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0507 - val_loss: 0.0486\n",
114
- "Epoch 6/50\n",
115
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0473 - val_loss: 0.0455\n",
116
- "Epoch 7/50\n",
117
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0437 - val_loss: 0.0431\n",
118
- "Epoch 8/50\n",
119
- "19/19 [==============================] - 0s 8ms/step - loss: 0.0414 - val_loss: 0.0414\n",
120
- "Epoch 9/50\n",
121
- "19/19 [==============================] - 0s 7ms/step - loss: 0.0391 - val_loss: 0.0395\n",
122
- "Epoch 10/50\n",
123
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0371 - val_loss: 0.0383\n",
124
- "Epoch 11/50\n",
125
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0355 - val_loss: 0.0372\n",
126
- "Epoch 12/50\n",
127
- "19/19 [==============================] - 0s 4ms/step - loss: 0.0341 - val_loss: 0.0362\n",
128
- "Epoch 13/50\n",
129
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0329 - val_loss: 0.0352\n",
130
- "Epoch 14/50\n",
131
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0318 - val_loss: 0.0345\n",
132
- "Epoch 15/50\n",
133
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0307 - val_loss: 0.0341\n",
134
- "Epoch 16/50\n",
135
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0300 - val_loss: 0.0337\n",
136
- "Epoch 17/50\n",
137
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0291 - val_loss: 0.0334\n",
138
- "Epoch 18/50\n",
139
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0281 - val_loss: 0.0331\n",
140
- "Epoch 19/50\n",
141
- "19/19 [==============================] - 0s 4ms/step - loss: 0.0278 - val_loss: 0.0329\n",
142
- "Epoch 20/50\n",
143
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0270 - val_loss: 0.0328\n",
144
- "Epoch 21/50\n",
145
- "19/19 [==============================] - 0s 4ms/step - loss: 0.0266 - val_loss: 0.0329\n",
146
- "Epoch 22/50\n",
147
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0261 - val_loss: 0.0326\n",
148
- "Epoch 23/50\n",
149
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0257 - val_loss: 0.0325\n",
150
- "Epoch 24/50\n",
151
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0249 - val_loss: 0.0324\n",
152
- "Epoch 25/50\n",
153
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0247 - val_loss: 0.0325\n",
154
- "Epoch 26/50\n",
155
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0243 - val_loss: 0.0322\n",
156
- "Epoch 27/50\n",
157
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0239 - val_loss: 0.0325\n",
158
- "Epoch 28/50\n",
159
- "19/19 [==============================] - 0s 3ms/step - loss: 0.0235 - val_loss: 0.0325\n",
160
- "Epoch 29/50\n",
161
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0226 - val_loss: 0.0328\n",
162
- "Epoch 30/50\n",
163
- "19/19 [==============================] - 0s 5ms/step - loss: 0.0227 - val_loss: 0.0326\n",
164
- "Epoch 31/50\n",
165
- "19/19 [==============================] - 0s 4ms/step - loss: 0.0224 - val_loss: 0.0326\n",
166
- "\n",
167
- " STACKING (GPU-Keras MLP)\n",
168
- "Fmax = 0.6937\n",
169
- "Thr. = 0.34\n",
170
- "AuPRC = 0.7551\n",
171
- "Smin = 12.2407\n"
172
- ]
173
- }
174
- ],
175
- "source": [
176
- "# %%\n",
177
- "from tensorflow.keras.models import Sequential\n",
178
- "from tensorflow.keras.layers import Dense, Dropout\n",
179
- "from tensorflow.keras.optimizers import Adam\n",
180
- "from sklearn.model_selection import train_test_split\n",
181
- "from sklearn.metrics import precision_recall_curve, auc\n",
182
- "import numpy as np\n",
183
- "import math\n",
184
- "\n",
185
- "# --- Preparar dados para stacking ---\n",
186
- "# (já com y_pb, y_bfd, y_esm com shape (1724, 597))\n",
187
- "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1724, 597*3)\n",
188
- "y_stack = y_true.copy() # (1724, 597)\n",
189
- "\n",
190
- "# --- Divisão treino/validação ---\n",
191
- "X_train, X_val, y_train, y_val = train_test_split(X_stack, y_stack, test_size=0.3, random_state=42)\n",
192
- "\n",
193
- "# --- Modelo MLP (usa GPU automaticamente se disponível) ---\n",
194
- "from tensorflow.keras.callbacks import EarlyStopping\n",
195
- "\n",
196
- "model = Sequential([\n",
197
- " Dense(512, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
198
- " Dropout(0.3),\n",
199
- " Dense(256, activation=\"relu\"),\n",
200
- " Dropout(0.3),\n",
201
- " Dense(y_stack.shape[1], activation=\"sigmoid\")\n",
202
- "])\n",
203
- "\n",
204
- "model.compile(optimizer=Adam(1e-3), loss=\"binary_crossentropy\")\n",
205
- "\n",
206
- "model.fit(X_train, y_train, validation_data=(X_val, y_val),\n",
207
- " epochs=50, batch_size=64, verbose=1,\n",
208
- " callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])\n",
209
- "\n",
210
- "# --- Prever com stacking ---\n",
211
- "y_pred_stack = model.predict(X_stack, batch_size=64)\n",
212
- "\n",
213
- "# --- Métricas ---\n",
214
- "THR = np.linspace(0, 1, 101)\n",
215
- "def fmax(y_t, y_p):\n",
216
- " best, thr = 0, 0\n",
217
- " for t in THR:\n",
218
- " y_b = (y_p >= t).astype(int)\n",
219
- " tp = (y_t * y_b).sum(1); fp = ((1 - y_t) * y_b).sum(1); fn = (y_t * (1 - y_b)).sum(1)\n",
220
- " f1 = 2 * tp / (2 * tp + fp + fn + 1e-8); m = f1.mean()\n",
221
- " if m > best: best, thr = m, t\n",
222
- " return best, thr\n",
223
- "\n",
224
- "def auprc(y_t, y_p):\n",
225
- " p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel())\n",
226
- " return auc(r, p)\n",
227
- "\n",
228
- "def smin(y_t, y_p, thr, alpha=0.5):\n",
229
- " y_b = (y_p >= thr).astype(int)\n",
230
- " ic = -(np.log((y_t + y_b).sum(0) + 1e-8) - np.log((y_t + y_b).sum() + 1e-8))\n",
231
- " ru = np.logical_and(y_b, np.logical_not(y_t)) * ic\n",
232
- " mi = np.logical_and(y_t, np.logical_not(y_b)) * ic\n",
233
- " return np.sqrt((alpha * ru.sum(1))**2 + ((1 - alpha) * mi.sum(1))**2).mean()\n",
234
- "\n",
235
- "f, thr = fmax(y_stack, y_pred_stack)\n",
236
- "print(f\"\\n STACKING (GPU-Keras MLP)\")\n",
237
- "print(f\"Fmax = {f:.4f}\")\n",
238
- "print(f\"Thr. = {thr:.2f}\")\n",
239
- "print(f\"AuPRC = {auprc(y_stack, y_pred_stack):.4f}\")\n",
240
- "print(f\"Smin = {smin(y_stack, y_pred_stack, thr):.4f}\")\n"
241
- ]
242
- },
243
- {
244
- "cell_type": "code",
245
- "execution_count": 10,
246
- "id": "00695029-3d24-4803-a6e1-8ac5fd70b710",
247
- "metadata": {},
248
- "outputs": [
249
- {
250
- "name": "stdout",
251
- "output_type": "stream",
252
- "text": [
253
- "guardado em models/modelo_ensemble_stacking.keras\n"
254
- ]
255
- }
256
- ],
257
- "source": [
258
- "model.save(\"models/modelo_ensemble_stacking.keras\")\n",
259
- "print('guardado em models/modelo_ensemble_stacking.keras')"
260
- ]
261
- },
262
- {
263
- "cell_type": "code",
264
- "execution_count": null,
265
- "id": "37629e3a-1c24-4f0f-9d12-dddf48be8724",
266
- "metadata": {},
267
- "outputs": [],
268
- "source": []
269
- }
270
- ],
271
- "metadata": {
272
- "kernelspec": {
273
- "display_name": "Python 3 (ipykernel)",
274
- "language": "python",
275
- "name": "python3"
276
- },
277
- "language_info": {
278
- "codemirror_mode": {
279
- "name": "ipython",
280
- "version": 3
281
- },
282
- "file_extension": ".py",
283
- "mimetype": "text/x-python",
284
- "name": "python",
285
- "nbconvert_exporter": "python",
286
- "pygments_lexer": "ipython3",
287
- "version": "3.8.18"
288
- }
289
- },
290
- "nbformat": 4,
291
- "nbformat_minor": 5
292
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/Input.ipynb DELETED
@@ -1,157 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "9eca7d69-3f17-4306-84d0-58a0363144fa",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "A gerar embeddings por chunks...\n"
14
- ]
15
- },
16
- {
17
- "name": "stderr",
18
- "output_type": "stream",
19
- "text": [
20
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
21
- " warnings.warn(\n",
22
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
23
- " warnings.warn(\n",
24
- "Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
25
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
26
- ]
27
- },
28
- {
29
- "name": "stdout",
30
- "output_type": "stream",
31
- "text": [
32
- "A fazer predições base...\n",
33
- "\n",
34
- " GO terms com prob ≥ 0.5:\n",
35
- "('GO:0003674', 'GO:0003824', 'GO:0005488', 'GO:0016491', 'GO:0036094', 'GO:0043167')\n",
36
- "\n",
37
- " Top 10 GO terms mais prováveis:\n",
38
- "GO:0003674 : 0.9975\n",
39
- "GO:0003824 : 0.9156\n",
40
- "GO:0036094 : 0.6652\n",
41
- "GO:0043167 : 0.6336\n",
42
- "GO:0016491 : 0.6327\n",
43
- "GO:0005488 : 0.5595\n",
44
- "GO:0043169 : 0.4801\n",
45
- "GO:0140096 : 0.4790\n",
46
- "GO:0051213 : 0.4551\n",
47
- "GO:0046872 : 0.4098\n"
48
- ]
49
- }
50
- ],
51
- "source": [
52
- "# %%\n",
53
- "import numpy as np\n",
54
- "import torch\n",
55
- "from transformers import AutoTokenizer, AutoModel\n",
56
- "from tensorflow.keras.models import load_model\n",
57
- "import joblib\n",
58
- "\n",
59
- "# --- Parâmetros ---\n",
60
- "SEQ_FASTA = \"MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKEREAAAAAAEAAEATEQIVFEEEDGKALLNLFFTLRSSKTPALSRSLKVFETFEAKIHHLETRPCRKPRDSLEGLEYFVRCEVHLSDVSTLISSIKRIAEDVKTTKEVKFHWFPKKISELDRCHHLITKFDPDLDQEHPGFTDPVYRQRRKMIGDIAFRYKQGEPIPRVEYTEEEIGTWREVYSTLRDLYTTHACSEHLEAFNLLERHCGYSPENIPQLEDVSRFLRERTGFQLRPVAGLLSARDFLASLAFRVFQCTQYIRHASSPMHSPEPDCVHELLGHVPILADRVFAQFSQNIGLASLGASEEDIEKLSTLYWFTVEFGLCKQGGIVKAYGAGLLSSYGELVHALSDEPERREFDPEAAAIQPYQDQNYQSVYFVSESFTDAKEKLRSYVAGIKRPFSVRFDPYTYSIEVLDNPLKIRGGLESVKDELKMLTDALNVLA\"\n",
61
- "TOP_N = 10\n",
62
- "\n",
63
- "# --- 1. Função para dividir sequência (512 para Protbert e Protbertbfd. 1024 para ESM2) ---\n",
64
- "def slice_sequence(seq, chunk_size):\n",
65
- " return [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]\n",
66
- "\n",
67
- "# --- 2. Função para gerar embeddings médios ---\n",
68
- "def get_embedding_mean(model_name, seq, chunk_size):\n",
69
- " tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)\n",
70
- " model = AutoModel.from_pretrained(model_name)\n",
71
- " model.eval()\n",
72
- "\n",
73
- " chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]\n",
74
- " embeddings = []\n",
75
- "\n",
76
- " for chunk in chunks:\n",
77
- " seq_chunk = \" \".join(list(chunk))\n",
78
- " # tokenizar SEM truncar\n",
79
- " inputs = tokenizer(seq_chunk,\n",
80
- " return_tensors=\"pt\",\n",
81
- " truncation=False, # ≤ 512 ou 1024 já garantido\n",
82
- " padding=False)\n",
83
- " with torch.no_grad():\n",
84
- " cls = model(**inputs).last_hidden_state[:, 0, :].squeeze().numpy()\n",
85
- " embeddings.append(cls)\n",
86
- "\n",
87
- " return np.mean(embeddings, axis=0, keepdims=True) # (1, dim)\n",
88
- "\n",
89
- "print(\"A gerar embeddings por chunks...\")\n",
90
- "emb_pb = get_embedding_mean(\"Rostlab/prot_bert\", SEQ_FASTA, 512)\n",
91
- "emb_bfd = get_embedding_mean(\"Rostlab/prot_bert_bfd\", SEQ_FASTA, 512)\n",
92
- "emb_esm = get_embedding_mean(\"facebook/esm2_t33_650M_UR50D\", SEQ_FASTA, 1024)\n",
93
- "\n",
94
- "# --- 3. Carregar os MLPs base ---\n",
95
- "mlp_pb = load_model(\"models/protbert_mlp.keras\")\n",
96
- "mlp_bfd = load_model(\"models/protbertbfd_mlp.keras\")\n",
97
- "mlp_esm = load_model(\"models/esm2_mlp.keras\")\n",
98
- "\n",
99
- "# --- 4. Gerar predições base (garantir 597 colunas) ---\n",
100
- "print(\"A fazer predições base...\")\n",
101
- "y_pb = mlp_pb.predict(emb_pb)[:, :597]\n",
102
- "y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]\n",
103
- "y_esm = mlp_esm.predict(emb_esm)[:, :597]\n",
104
- "\n",
105
- "# --- 5. Concatenar para o stacking ---\n",
106
- "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)\n",
107
- "\n",
108
- "# --- 6. Carregar modelo de stacking ---\n",
109
- "stacking = load_model(\"models/modelo_ensemble_stacking.keras\")\n",
110
- "y_pred = stacking.predict(X_stack)\n",
111
- "\n",
112
- "# --- 7. Carregar binarizador (597 GO terms) ---\n",
113
- "mlb = joblib.load(\"data/mlb_597.pkl\")\n",
114
- "go_terms = mlb.classes_\n",
115
- "\n",
116
- "# --- 8. Mostrar resultados ---\n",
117
- "print(\"\\n GO terms com prob ≥ 0.5:\")\n",
118
- "predicted_terms = mlb.inverse_transform((y_pred >= 0.5).astype(int))\n",
119
- "print(predicted_terms[0] if predicted_terms[0] else \"Nenhum GO term acima de 0.5\")\n",
120
- "\n",
121
- "print(f\"\\n Top {TOP_N} GO terms mais prováveis:\")\n",
122
- "top_idx = np.argsort(-y_pred[0])[:TOP_N]\n",
123
- "for i in top_idx:\n",
124
- " print(f\"{go_terms[i]} : {y_pred[0][i]:.4f}\")\n"
125
- ]
126
- },
127
- {
128
- "cell_type": "code",
129
- "execution_count": null,
130
- "id": "e959e7d9-15ba-4533-a2bb-ddd7df2a639d",
131
- "metadata": {},
132
- "outputs": [],
133
- "source": []
134
- }
135
- ],
136
- "metadata": {
137
- "kernelspec": {
138
- "display_name": "Python 3 (ipykernel)",
139
- "language": "python",
140
- "name": "python3"
141
- },
142
- "language_info": {
143
- "codemirror_mode": {
144
- "name": "ipython",
145
- "version": 3
146
- },
147
- "file_extension": ".py",
148
- "mimetype": "text/x-python",
149
- "name": "python",
150
- "nbconvert_exporter": "python",
151
- "pygments_lexer": "ipython3",
152
- "version": "3.8.18"
153
- }
154
- },
155
- "nbformat": 4,
156
- "nbformat_minor": 5
157
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/PAM1_ESM2.ipynb DELETED
@@ -1,561 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 9,
6
- "id": "641053e3-7fec-4f9b-a75e-ddd957af03c4",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
- "✓ Dataset preparado:\n",
15
- " - Training: (31142, 3)\n",
16
- " - Validation: (1724, 3)\n",
17
- " - Test: (1724, 3)\n",
18
- " - GO terms: 602\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "# %%\n",
24
- "import pandas as pd\n",
25
- "import numpy as np\n",
26
- "from Bio import SeqIO\n",
27
- "from goatools.obo_parser import GODag\n",
28
- "from collections import Counter\n",
29
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
30
- "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
31
- "import os, random\n",
32
- "\n",
33
- "# --- 1. Carregar ficheiros principais ---\n",
34
- "FASTA = \"uniprot_sprot_exp.fasta\"\n",
35
- "ANNOT = \"uniprot_sprot_exp.txt\"\n",
36
- "GO_OBO = \"go.obo\"\n",
37
- "\n",
38
- "# --- 2. Ler sequências ---\n",
39
- "seqs, ids = [], []\n",
40
- "for record in SeqIO.parse(FASTA, \"fasta\"):\n",
41
- " ids.append(record.id)\n",
42
- " seqs.append(str(record.seq))\n",
43
- "\n",
44
- "df_seq = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
45
- "\n",
46
- "# --- 3. Ler anotações GO:MF ---\n",
47
- "df_ann = pd.read_csv(ANNOT, sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"category\"])\n",
48
- "df_ann = df_ann[df_ann[\"category\"] == \"F\"]\n",
49
- "\n",
50
- "# --- 4. Propagação hierárquica dos GO terms ---\n",
51
- "go_dag = GODag(GO_OBO)\n",
52
- "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
53
- "\n",
54
- "def propagate_terms(terms):\n",
55
- " expanded = set()\n",
56
- " for t in terms:\n",
57
- " if t in go_dag:\n",
58
- " expanded |= go_dag[t].get_all_parents()\n",
59
- " expanded.add(t)\n",
60
- " return list(expanded & mf_terms)\n",
61
- "\n",
62
- "grouped = df_ann.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
63
- "grouped[\"go_term\"] = grouped[\"go_term\"].apply(propagate_terms)\n",
64
- "\n",
65
- "# --- 5. Juntar com sequência ---\n",
66
- "df = df_seq.merge(grouped, on=\"protein_id\")\n",
67
- "df = df[df[\"go_term\"].str.len() > 0]\n",
68
- "\n",
69
- "# --- 6. Filtrar GO terms com ≥50 proteínas ---\n",
70
- "all_terms = [term for sublist in df[\"go_term\"] for term in sublist]\n",
71
- "term_counts = Counter(all_terms)\n",
72
- "valid_terms = {t for t, count in term_counts.items() if count >= 50}\n",
73
- "\n",
74
- "df[\"go_term\"] = df[\"go_term\"].apply(lambda ts: [t for t in ts if t in valid_terms])\n",
75
- "df = df[df[\"go_term\"].str.len() > 0]\n",
76
- "\n",
77
- "# --- 7. Preparar labels e dividir por proteína ---\n",
78
- "df[\"go_terms\"] = df[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
79
- "df = df[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
80
- "\n",
81
- "mlb = MultiLabelBinarizer()\n",
82
- "Y = mlb.fit_transform(df[\"go_terms\"].str.split(\";\"))\n",
83
- "X = df[[\"protein_id\", \"sequence\"]].values\n",
84
- "\n",
85
- "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
86
- "train_idx, temp_idx = next(mskf.split(X, Y))\n",
87
- "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
88
- "\n",
89
- "df_train = df.iloc[train_idx].copy()\n",
90
- "df_val = df.iloc[val_idx].copy()\n",
91
- "df_test = df.iloc[test_idx].copy()\n",
92
- "\n",
93
- "os.makedirs(\"data\", exist_ok=True)\n",
94
- "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
- "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
- "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
- "\n",
98
- "# --- 8. Guardar o binarizador ---\n",
99
- "import joblib\n",
100
- "joblib.dump(mlb, \"data/mlb.pkl\")\n",
101
- "\n",
102
- "print(\"✓ Dataset preparado:\")\n",
103
- "print(\" - Training:\", df_train.shape)\n",
104
- "print(\" - Validation:\", df_val.shape)\n",
105
- "print(\" - Test:\", df_test.shape)\n",
106
- "print(\" - GO terms:\", len(mlb.classes_))\n"
107
- ]
108
- },
109
- {
110
- "cell_type": "code",
111
- "execution_count": 10,
112
- "id": "40ba1798-daf8-4649-ae3f-bfe81df6437f",
113
- "metadata": {},
114
- "outputs": [],
115
- "source": [
116
- "# %%\n",
117
- "import random\n",
118
- "from collections import defaultdict\n",
119
- "\n",
120
- "# --- PAM1 matrix normalizada ---\n",
121
- "pam_data = {\n",
122
- " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
123
- " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
124
- " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
125
- " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
126
- " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
127
- " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
128
- " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
129
- " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
130
- " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
131
- " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
132
- " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
133
- " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
134
- " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
135
- " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
136
- " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
137
- " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
138
- " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
139
- " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
140
- " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
141
- " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
142
- "}\n",
143
- "\n",
144
- "pam_raw = pd.DataFrame(pam_data, index=pam_data.keys())\n",
145
- "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
146
- "pam_dict = {aa: pam_matrix.loc[aa].to_dict() for aa in pam_matrix.index}\n",
147
- "\n",
148
- "def pam1_substitution(aa):\n",
149
- " if aa not in pam_dict:\n",
150
- " return aa\n",
151
- " subs = list(pam_dict[aa].keys())\n",
152
- " probs = list(pam_dict[aa].values())\n",
153
- " return np.random.choice(subs, p=probs)\n",
154
- "\n",
155
- "def augment_sequence(seq, sub_prob=0.05):\n",
156
- " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
157
- "\n",
158
- "def slice_sequence(seq, win=1024):\n",
159
- " if len(seq) <= win:\n",
160
- " return [seq]\n",
161
- " return [seq[i:i+win] for i in range(0, len(seq), win)]\n",
162
- "\n",
163
- "def format_seq(seq):\n",
164
- " return \" \".join(seq)\n",
165
- "\n",
166
- "# --- Carregar labels e datasets ---\n",
167
- "import joblib\n",
168
- "mlb = joblib.load(\"data/mlb.pkl\")\n",
169
- "df_train = pd.read_csv(\"data/mf-training.csv\")\n",
170
- "df_val = pd.read_csv(\"data/mf-validation.csv\")\n",
171
- "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
172
- "\n",
173
- "# --- Slicing + augmentação no treino ---\n",
174
- "X_train, y_train = [], []\n",
175
- "\n",
176
- "for _, row in df_train.iterrows():\n",
177
- " seq_aug = augment_sequence(row[\"sequence\"], sub_prob=0.05)\n",
178
- " slices = slice_sequence(seq_aug, win=1024)\n",
179
- " label = mlb.transform([row[\"go_terms\"].split(\";\")])[0]\n",
180
- " for sl in slices:\n",
181
- " X_train.append(format_seq(sl))\n",
182
- " y_train.append(label)\n",
183
- "\n",
184
- "# --- Sem slicing no val/test ---\n",
185
- "X_val = [format_seq(seq) for seq in df_val[\"sequence\"]]\n",
186
- "X_test = [format_seq(seq) for seq in df_test[\"sequence\"]]\n",
187
- "\n",
188
- "y_val = mlb.transform(df_val[\"go_terms\"].str.split(\";\"))\n",
189
- "y_test = mlb.transform(df_test[\"go_terms\"].str.split(\";\"))\n",
190
- "\n",
191
- "np.save(\"embeddings/y_test.npy\", y_test)"
192
- ]
193
- },
194
- {
195
- "cell_type": "code",
196
- "execution_count": 11,
197
- "id": "80d5c1fb-9c84-463d-8d8c-bfcc2982afc9",
198
- "metadata": {},
199
- "outputs": [
200
- {
201
- "name": "stderr",
202
- "output_type": "stream",
203
- "text": [
204
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
205
- " warnings.warn(\n",
206
- "Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n",
207
- "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
208
- "100%|██████████| 2189/2189 [1:17:26<00:00, 2.12s/it]\n",
209
- "100%|██████████| 108/108 [03:43<00:00, 2.07s/it]\n",
210
- "100%|██████████| 108/108 [03:56<00:00, 2.19s/it]\n"
211
- ]
212
- }
213
- ],
214
- "source": [
215
- "# %%\n",
216
- "from transformers import AutoTokenizer, AutoModel\n",
217
- "import torch\n",
218
- "from tqdm import tqdm\n",
219
- "import numpy as np\n",
220
- "import os\n",
221
- "\n",
222
- "# --- Configurações ---\n",
223
- "MODEL_NAME = \"facebook/esm2_t33_650M_UR50D\"\n",
224
- "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
225
- "CHUNK_SIZE = 16\n",
226
- "\n",
227
- "# --- Carregar modelo ---\n",
228
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
229
- "model = AutoModel.from_pretrained(MODEL_NAME)\n",
230
- "model.to(DEVICE)\n",
231
- "model.eval()\n",
232
- "\n",
233
- "def extract_embeddings(texts):\n",
234
- " embeddings = []\n",
235
- " for i in tqdm(range(0, len(texts), CHUNK_SIZE)):\n",
236
- " batch = texts[i:i+CHUNK_SIZE]\n",
237
- " with torch.no_grad():\n",
238
- " inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=1024)\n",
239
- " inputs = {k: v.to(DEVICE) for k, v in inputs.items()}\n",
240
- " outputs = model(**inputs).last_hidden_state\n",
241
- " cls_tokens = outputs[:, 0, :] # token CLS\n",
242
- " embeddings.append(cls_tokens.cpu().numpy())\n",
243
- " return np.vstack(embeddings)\n",
244
- "\n",
245
- "# --- Extrair e guardar embeddings ---\n",
246
- "os.makedirs(\"embeddings\", exist_ok=True)\n",
247
- "\n",
248
- "emb_train = extract_embeddings(X_train)\n",
249
- "emb_val = extract_embeddings(X_val)\n",
250
- "emb_test = extract_embeddings(X_test)\n",
251
- "\n",
252
- "np.save(\"embeddings/esm2_train.npy\", emb_train)\n",
253
- "np.save(\"embeddings/esm2_val.npy\", emb_val)\n",
254
- "np.save(\"embeddings/esm2_test.npy\", emb_test)\n",
255
- "\n",
256
- "np.save(\"embeddings/y_train.npy\", np.array(y_train))\n",
257
- "np.save(\"embeddings/y_val.npy\", np.array(y_val))\n"
258
- ]
259
- },
260
- {
261
- "cell_type": "code",
262
- "execution_count": 1,
263
- "id": "592e4f6c-b871-4f0b-b84c-f3918c698544",
264
- "metadata": {},
265
- "outputs": [
266
- {
267
- "name": "stdout",
268
- "output_type": "stream",
269
- "text": [
270
- "Epoch 1/100\n",
271
- "1095/1095 [==============================] - 4s 2ms/step - loss: 0.0552 - val_loss: 0.0455\n",
272
- "Epoch 2/100\n",
273
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0445 - val_loss: 0.0424\n",
274
- "Epoch 3/100\n",
275
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0419 - val_loss: 0.0394\n",
276
- "Epoch 4/100\n",
277
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0403 - val_loss: 0.0381\n",
278
- "Epoch 5/100\n",
279
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0392 - val_loss: 0.0373\n",
280
- "Epoch 6/100\n",
281
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0383 - val_loss: 0.0362\n",
282
- "Epoch 7/100\n",
283
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0374 - val_loss: 0.0358\n",
284
- "Epoch 8/100\n",
285
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0368 - val_loss: 0.0351\n",
286
- "Epoch 9/100\n",
287
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0362 - val_loss: 0.0348\n",
288
- "Epoch 10/100\n",
289
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0357 - val_loss: 0.0344\n",
290
- "Epoch 11/100\n",
291
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0353 - val_loss: 0.0340\n",
292
- "Epoch 12/100\n",
293
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0349 - val_loss: 0.0335\n",
294
- "Epoch 13/100\n",
295
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0344 - val_loss: 0.0334\n",
296
- "Epoch 14/100\n",
297
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0342 - val_loss: 0.0330\n",
298
- "Epoch 15/100\n",
299
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0338 - val_loss: 0.0325\n",
300
- "Epoch 16/100\n",
301
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0337 - val_loss: 0.0327\n",
302
- "Epoch 17/100\n",
303
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0333 - val_loss: 0.0325\n",
304
- "Epoch 18/100\n",
305
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0330 - val_loss: 0.0322\n",
306
- "Epoch 19/100\n",
307
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0328 - val_loss: 0.0321\n",
308
- "Epoch 20/100\n",
309
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0326 - val_loss: 0.0322\n",
310
- "Epoch 21/100\n",
311
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0323 - val_loss: 0.0320\n",
312
- "Epoch 22/100\n",
313
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0322 - val_loss: 0.0317\n",
314
- "Epoch 23/100\n",
315
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0320 - val_loss: 0.0318\n",
316
- "Epoch 24/100\n",
317
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0317 - val_loss: 0.0315\n",
318
- "Epoch 25/100\n",
319
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0316 - val_loss: 0.0317\n",
320
- "Epoch 26/100\n",
321
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0314 - val_loss: 0.0313\n",
322
- "Epoch 27/100\n",
323
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0313 - val_loss: 0.0320\n",
324
- "Epoch 28/100\n",
325
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0311 - val_loss: 0.0315\n",
326
- "Epoch 29/100\n",
327
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0310 - val_loss: 0.0313\n",
328
- "Epoch 30/100\n",
329
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0309 - val_loss: 0.0313\n",
330
- "Epoch 31/100\n",
331
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0307 - val_loss: 0.0310\n",
332
- "Epoch 32/100\n",
333
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0306 - val_loss: 0.0310\n",
334
- "Epoch 33/100\n",
335
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0304 - val_loss: 0.0310\n",
336
- "Epoch 34/100\n",
337
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0303 - val_loss: 0.0312\n",
338
- "Epoch 35/100\n",
339
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0302 - val_loss: 0.0309\n",
340
- "Epoch 36/100\n",
341
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0300 - val_loss: 0.0310\n",
342
- "Epoch 37/100\n",
343
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0299 - val_loss: 0.0313\n",
344
- "Epoch 38/100\n",
345
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0298 - val_loss: 0.0312\n",
346
- "Epoch 39/100\n",
347
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0307\n",
348
- "Epoch 40/100\n",
349
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0306\n",
350
- "Epoch 41/100\n",
351
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0295 - val_loss: 0.0310\n",
352
- "Epoch 42/100\n",
353
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0304\n",
354
- "Epoch 43/100\n",
355
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0308\n",
356
- "Epoch 44/100\n",
357
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0293 - val_loss: 0.0306\n",
358
- "Epoch 45/100\n",
359
- "1095/1095 [==============================] - 3s 3ms/step - loss: 0.0292 - val_loss: 0.0307\n",
360
- "Epoch 46/100\n",
361
- "1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305\n",
362
- "Epoch 47/100\n",
363
- "1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305\n",
364
- "Modelo guardado em models/esm2_mlp.keras\n",
365
- " Predições do ESM-2 salvas com forma: (1724, 602)\n"
366
- ]
367
- }
368
- ],
369
- "source": [
370
- "# %%\n",
371
- "import numpy as np\n",
372
- "import tensorflow as tf\n",
373
- "from tensorflow.keras.models import Sequential\n",
374
- "from tensorflow.keras.layers import Dense, Dropout\n",
375
- "from tensorflow.keras.callbacks import EarlyStopping\n",
376
- "from sklearn.metrics import average_precision_score\n",
377
- "\n",
378
- "# --- Carregar os embeddings e labels ---\n",
379
- "X_train = np.load(\"embeddings/esm2_train.npy\")\n",
380
- "X_val = np.load(\"embeddings/esm2_val.npy\")\n",
381
- "X_test = np.load(\"embeddings/esm2_test.npy\")\n",
382
- "\n",
383
- "y_train = np.load(\"embeddings/y_train.npy\")\n",
384
- "y_val = np.load(\"embeddings/y_val.npy\")\n",
385
- "y_test = np.load(\"embeddings/y_test.npy\")\n",
386
- "\n",
387
- "# --- Definir o modelo ---\n",
388
- "model = Sequential([\n",
389
- " Dense(1024, activation='relu', input_shape=(X_train.shape[1],)),\n",
390
- " Dropout(0.3),\n",
391
- " Dense(512, activation='relu'),\n",
392
- " Dropout(0.3),\n",
393
- " Dense(y_train.shape[1], activation='sigmoid')\n",
394
- "])\n",
395
- "\n",
396
- "model.compile(optimizer='adam', loss='binary_crossentropy')\n",
397
- "\n",
398
- "# --- Treinar ---\n",
399
- "early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
400
- "\n",
401
- "history = model.fit(\n",
402
- " X_train, y_train,\n",
403
- " validation_data=(X_val, y_val),\n",
404
- " epochs=100,\n",
405
- " batch_size=32,\n",
406
- " callbacks=[early_stop],\n",
407
- " verbose=1\n",
408
- ")\n",
409
- "\n",
410
- "# --- Salvar o modelo ---\n",
411
- "model.save(\"models/esm2_mlp.keras\")\n",
412
- "print(\"Modelo guardado em models/esm2_mlp.keras\")\n",
413
- "\n",
414
- "# --- Fazer predições no conjunto de teste ---\n",
415
- "y_prob = model.predict(X_test)\n",
416
- "np.save(\"predictions/mf-esm2.npy\", y_prob)\n",
417
- "\n",
418
- "print(\" Predições do ESM-2 salvas com forma:\", y_prob.shape)\n"
419
- ]
420
- },
421
- {
422
- "cell_type": "code",
423
- "execution_count": 15,
424
- "id": "3dddb0df-3ea5-4e32-8cf0-45e90be8ba66",
425
- "metadata": {},
426
- "outputs": [
427
- {
428
- "name": "stdout",
429
- "output_type": "stream",
430
- "text": [
431
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
432
- "✓ Dados carregados: (1724, 602) proteínas × 602 GO terms\n",
433
- "\n",
434
- " Resultados finais (ESM-2 + PAM1 + propagação):\n",
435
- "Fmax = 0.6439\n",
436
- "Thr. = 0.34\n",
437
- "AuPRC = 0.6948\n",
438
- "Smin = 14.1500\n"
439
- ]
440
- }
441
- ],
442
- "source": [
443
- "# %%\n",
444
- "import numpy as np\n",
445
- "import joblib\n",
446
- "import math\n",
447
- "from goatools.obo_parser import GODag\n",
448
- "from sklearn.metrics import precision_recall_curve, auc\n",
449
- "\n",
450
- "# --- 1. Carregar dados e parâmetros ---\n",
451
- "GO_FILE = \"go.obo\"\n",
452
- "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
453
- "ALPHA = 0.5\n",
454
- "\n",
455
- "mlb = joblib.load(\"data/mlb.pkl\")\n",
456
- "y_true = np.load(\"embeddings/y_test.npy\")\n",
457
- "y_prob = np.load(\"predictions/mf-esm2.npy\")\n",
458
- "terms = mlb.classes_\n",
459
- "go_dag = GODag(GO_FILE)\n",
460
- "\n",
461
- "print(f\"✓ Dados carregados: {y_true.shape} proteínas × {len(terms)} GO terms\")\n",
462
- "\n",
463
- "# --- 2. Fmax ---\n",
464
- "def compute_fmax(y_true, y_prob, thresholds):\n",
465
- " fmax, best_thr = 0, 0\n",
466
- " for t in thresholds:\n",
467
- " y_pred = (y_prob >= t).astype(int)\n",
468
- " tp = (y_true * y_pred).sum(axis=1)\n",
469
- " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
470
- " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
471
- " precision = tp / (tp + fp + 1e-8)\n",
472
- " recall = tp / (tp + fn + 1e-8)\n",
473
- " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
474
- " avg_f1 = np.mean(f1)\n",
475
- " if avg_f1 > fmax:\n",
476
- " fmax, best_thr = avg_f1, t\n",
477
- " return fmax, best_thr\n",
478
- "\n",
479
- "# --- 3. AuPRC (micro) ---\n",
480
- "def compute_auprc(y_true, y_prob):\n",
481
- " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
482
- " return auc(recall, precision)\n",
483
- "\n",
484
- "# --- 4. Smin ---\n",
485
- "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
486
- " y_pred = (y_prob >= threshold).astype(int)\n",
487
- "\n",
488
- " # Informação semântica: IC (Information Content)\n",
489
- " ic = {}\n",
490
- " total = (y_true + y_pred).sum(axis=0).sum()\n",
491
- " for i, term in enumerate(terms):\n",
492
- " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
493
- " ic[term] = -np.log((freq + 1e-8) / total)\n",
494
- "\n",
495
- " # Para cada proteína, calcular RU e MI\n",
496
- " s_values = []\n",
497
- " for true_vec, pred_vec in zip(y_true, y_pred):\n",
498
- " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
499
- " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
500
- "\n",
501
- " anc_true = set()\n",
502
- " for t in true_terms:\n",
503
- " if t in go_dag:\n",
504
- " anc_true |= go_dag[t].get_all_parents()\n",
505
- " anc_pred = set()\n",
506
- " for t in pred_terms:\n",
507
- " if t in go_dag:\n",
508
- " anc_pred |= go_dag[t].get_all_parents()\n",
509
- "\n",
510
- " ru = pred_terms - true_terms\n",
511
- " mi = true_terms - pred_terms\n",
512
- " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
513
- " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
514
- " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
515
- " s_values.append(s)\n",
516
- "\n",
517
- " return np.mean(s_values)\n",
518
- "\n",
519
- "# --- 5. Avaliação ---\n",
520
- "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
521
- "auprc = compute_auprc(y_true, y_prob)\n",
522
- "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
523
- "\n",
524
- "print(f\"\\n Resultados finais (ESM-2 + PAM1 + propagação):\")\n",
525
- "print(f\"Fmax = {fmax:.4f}\")\n",
526
- "print(f\"Thr. = {thr:.2f}\")\n",
527
- "print(f\"AuPRC = {auprc:.4f}\")\n",
528
- "print(f\"Smin = {smin:.4f}\")\n"
529
- ]
530
- },
531
- {
532
- "cell_type": "code",
533
- "execution_count": null,
534
- "id": "1a1ea084-01de-4dc4-88da-e7ffeb8c94c9",
535
- "metadata": {},
536
- "outputs": [],
537
- "source": []
538
- }
539
- ],
540
- "metadata": {
541
- "kernelspec": {
542
- "display_name": "Python 3 (ipykernel)",
543
- "language": "python",
544
- "name": "python3"
545
- },
546
- "language_info": {
547
- "codemirror_mode": {
548
- "name": "ipython",
549
- "version": 3
550
- },
551
- "file_extension": ".py",
552
- "mimetype": "text/x-python",
553
- "name": "python",
554
- "nbconvert_exporter": "python",
555
- "pygments_lexer": "ipython3",
556
- "version": "3.8.18"
557
- }
558
- },
559
- "nbformat": 4,
560
- "nbformat_minor": 5
561
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/PAM1_protbert.ipynb DELETED
@@ -1,935 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 2,
6
- "id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
- "✓ Ficheiros criados:\n",
15
- " - data/mf-training.csv : (31142, 3)\n",
16
- " - data/mf-validation.csv: (1724, 3)\n",
17
- " - data/mf-test.csv : (1724, 3)\n",
18
- "GO terms únicos (após propagação e filtro): 602\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "import pandas as pd\n",
24
- "from Bio import SeqIO\n",
25
- "from collections import Counter\n",
26
- "from goatools.obo_parser import GODag\n",
27
- "from sklearn.model_selection import train_test_split\n",
28
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
29
- "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
30
- "import numpy as np\n",
31
- "import os\n",
32
- "\n",
33
- "# --- 1. Carregar GO anotações ------------------------------------------\n",
34
- "annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
35
- "annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
36
- "\n",
37
- "# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
38
- "# propagação hierárquica\n",
39
- "# https://geneontology.org/docs/download-ontology/\n",
40
- "go_dag = GODag(\"go.obo\")\n",
41
- "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
42
- "\n",
43
- "def propagate_terms(term_list):\n",
44
- " full = set()\n",
45
- " for t in term_list:\n",
46
- " if t not in go_dag:\n",
47
- " continue\n",
48
- " full.add(t)\n",
49
- " full.update(go_dag[t].get_all_parents())\n",
50
- " return list(full & mf_terms)\n",
51
- "\n",
52
- "# --- 3. Carregar sequências --------------------------------------------\n",
53
- "seqs, ids = [], []\n",
54
- "for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
55
- " ids.append(record.id)\n",
56
- " seqs.append(str(record.seq))\n",
57
- "\n",
58
- "seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
59
- "\n",
60
- "# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
61
- "grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
62
- "data = seq_df.merge(grouped, on=\"protein_id\")\n",
63
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
64
- "data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
65
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
66
- "\n",
67
- "# --- 5. Filtrar GO terms raros -----------------------------------------\n",
68
- "# todos os terms com menos de 50 proteinas associadas\n",
69
- "all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
70
- "term_counts = Counter(all_terms)\n",
71
- "valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
72
- "data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
73
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
74
- "\n",
75
- "# --- 6. Preparar dataset final -----------------------------------------\n",
76
- "data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
77
- "data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
78
- "\n",
79
- "# --- 7. Binarizar labels e dividir -------------------------------------\n",
80
- "mlb = MultiLabelBinarizer()\n",
81
- "Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
82
- "X = data[[\"protein_id\", \"sequence\"]].values\n",
83
- "\n",
84
- "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
85
- "train_idx, temp_idx = next(mskf.split(X, Y))\n",
86
- "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
87
- "\n",
88
- "df_train = data.iloc[train_idx].copy()\n",
89
- "df_val = data.iloc[val_idx].copy()\n",
90
- "df_test = data.iloc[test_idx].copy()\n",
91
- "\n",
92
- "# --- 8. Guardar em CSV -------------------------------------------------\n",
93
- "os.makedirs(\"data\", exist_ok=True)\n",
94
- "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
- "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
- "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
- "\n",
98
- "# --- 9. Confirmar ------------------------------------------------------\n",
99
- "print(\"✓ Ficheiros criados:\")\n",
100
- "print(\" - data/mf-training.csv :\", df_train.shape)\n",
101
- "print(\" - data/mf-validation.csv:\", df_val.shape)\n",
102
- "print(\" - data/mf-test.csv :\", df_test.shape)\n",
103
- "print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 2,
109
- "id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "name": "stderr",
114
- "output_type": "stream",
115
- "text": [
116
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
117
- " from .autonotebook import tqdm as notebook_tqdm\n",
118
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
119
- " _torch_pytree._register_pytree_node(\n",
120
- "100%|██████████| 31142/31142 [00:24<00:00, 1262.18it/s]\n",
121
- "100%|██████████| 1724/1724 [00:00<00:00, 2628.24it/s]\n",
122
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
123
- " warnings.warn(\n"
124
- ]
125
- },
126
- {
127
- "name": "stdout",
128
- "output_type": "stream",
129
- "text": [
130
- "preprocessing train...\n",
131
- "language: de\n",
132
- "train sequence lengths:\n",
133
- "\tmean : 423\n",
134
- "\t95percentile : 604\n",
135
- "\t99percentile : 715\n"
136
- ]
137
- },
138
- {
139
- "data": {
140
- "text/html": [
141
- "\n",
142
- "<style>\n",
143
- " /* Turns off some styling */\n",
144
- " progress {\n",
145
- " /* gets rid of default border in Firefox and Opera. */\n",
146
- " border: none;\n",
147
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
148
- " background-size: auto;\n",
149
- " }\n",
150
- " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
151
- " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
152
- " }\n",
153
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
154
- " background: #F44336;\n",
155
- " }\n",
156
- "</style>\n"
157
- ],
158
- "text/plain": [
159
- "<IPython.core.display.HTML object>"
160
- ]
161
- },
162
- "metadata": {},
163
- "output_type": "display_data"
164
- },
165
- {
166
- "data": {
167
- "text/html": [],
168
- "text/plain": [
169
- "<IPython.core.display.HTML object>"
170
- ]
171
- },
172
- "metadata": {},
173
- "output_type": "display_data"
174
- },
175
- {
176
- "name": "stdout",
177
- "output_type": "stream",
178
- "text": [
179
- "Is Multi-Label? True\n",
180
- "preprocessing test...\n",
181
- "language: de\n",
182
- "test sequence lengths:\n",
183
- "\tmean : 408\n",
184
- "\t95percentile : 603\n",
185
- "\t99percentile : 714\n"
186
- ]
187
- },
188
- {
189
- "data": {
190
- "text/html": [
191
- "\n",
192
- "<style>\n",
193
- " /* Turns off some styling */\n",
194
- " progress {\n",
195
- " /* gets rid of default border in Firefox and Opera. */\n",
196
- " border: none;\n",
197
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
198
- " background-size: auto;\n",
199
- " }\n",
200
- " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
201
- " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
202
- " }\n",
203
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
204
- " background: #F44336;\n",
205
- " }\n",
206
- "</style>\n"
207
- ],
208
- "text/plain": [
209
- "<IPython.core.display.HTML object>"
210
- ]
211
- },
212
- "metadata": {},
213
- "output_type": "display_data"
214
- },
215
- {
216
- "data": {
217
- "text/html": [],
218
- "text/plain": [
219
- "<IPython.core.display.HTML object>"
220
- ]
221
- },
222
- "metadata": {},
223
- "output_type": "display_data"
224
- },
225
- {
226
- "name": "stderr",
227
- "output_type": "stream",
228
- "text": [
229
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:1093: UserWarning: Could not load a Tensorflow version of model. (If this worked before, it might be an out-of-memory issue.) Attempting to download/load PyTorch version as TensorFlow model using from_pt=True. You will need PyTorch installed for this.\n",
230
- " warnings.warn(\n"
231
- ]
232
- },
233
- {
234
- "name": "stdout",
235
- "output_type": "stream",
236
- "text": [
237
- "\n",
238
- "\n",
239
- "begin training using triangular learning rate policy with max lr of 1e-05...\n",
240
- "Epoch 1/10\n",
241
- "40995/40995 [==============================] - 13053s 318ms/step - loss: 0.0745 - binary_accuracy: 0.9866 - val_loss: 0.0582 - val_binary_accuracy: 0.9859\n",
242
- "Epoch 2/10\n",
243
- "40995/40995 [==============================] - 14484s 353ms/step - loss: 0.0504 - binary_accuracy: 0.9873 - val_loss: 0.0499 - val_binary_accuracy: 0.9867\n",
244
- "Epoch 3/10\n",
245
- "40995/40995 [==============================] - 14472s 353ms/step - loss: 0.0450 - binary_accuracy: 0.9879 - val_loss: 0.0449 - val_binary_accuracy: 0.9873\n",
246
- "Epoch 4/10\n",
247
- "40995/40995 [==============================] - 14445s 352ms/step - loss: 0.0407 - binary_accuracy: 0.9884 - val_loss: 0.0413 - val_binary_accuracy: 0.9878\n",
248
- "Epoch 5/10\n",
249
- "40995/40995 [==============================] - 12524s 305ms/step - loss: 0.0378 - binary_accuracy: 0.9888 - val_loss: 0.0394 - val_binary_accuracy: 0.9881\n",
250
- "Epoch 6/10\n",
251
- "40995/40995 [==============================] - 14737s 359ms/step - loss: 0.0359 - binary_accuracy: 0.9891 - val_loss: 0.0383 - val_binary_accuracy: 0.9883\n",
252
- "Epoch 7/10\n",
253
- "40995/40995 [==============================] - 20317s 495ms/step - loss: 0.0343 - binary_accuracy: 0.9894 - val_loss: 0.0371 - val_binary_accuracy: 0.9885\n",
254
- "Epoch 8/10\n",
255
- "40995/40995 [==============================] - 9073s 221ms/step - loss: 0.0331 - binary_accuracy: 0.9896 - val_loss: 0.0364 - val_binary_accuracy: 0.9887\n",
256
- "Epoch 9/10\n",
257
- "40995/40995 [==============================] - 9001s 219ms/step - loss: 0.0320 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
258
- "Epoch 10/10\n",
259
- "40995/40995 [==============================] - 8980s 219ms/step - loss: 0.0311 - binary_accuracy: 0.9900 - val_loss: 0.0356 - val_binary_accuracy: 0.9890\n"
260
- ]
261
- },
262
- {
263
- "ename": "RuntimeError",
264
- "evalue": "Can't decrement id ref count (unable to extend file properly)",
265
- "output_type": "error",
266
- "traceback": [
267
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
268
- "\u001b[1;31mOSError\u001b[0m Traceback (most recent call last)",
269
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:2252\u001b[0m, in \u001b[0;36mModel.save_weights\u001b[1;34m(self, filepath, overwrite, save_format, options)\u001b[0m\n\u001b[0;32m 2251\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m h5py\u001b[38;5;241m.\u001b[39mFile(filepath, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m-> 2252\u001b[0m \u001b[43mhdf5_format\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_weights_to_hdf5_group\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
270
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\saving\\hdf5_format.py:646\u001b[0m, in \u001b[0;36msave_weights_to_hdf5_group\u001b[1;34m(f, layers)\u001b[0m\n\u001b[0;32m 645\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 646\u001b[0m param_dset[:] \u001b[38;5;241m=\u001b[39m val\n",
271
- "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
272
- "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
273
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\dataset.py:999\u001b[0m, in \u001b[0;36mDataset.__setitem__\u001b[1;34m(self, args, val)\u001b[0m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fspace \u001b[38;5;129;01min\u001b[39;00m selection\u001b[38;5;241m.\u001b[39mbroadcast(mshape):\n\u001b[1;32m--> 999\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdxpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dxpl\u001b[49m\u001b[43m)\u001b[49m\n",
274
- "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
275
- "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
276
- "File \u001b[1;32mh5py\\\\h5d.pyx:282\u001b[0m, in \u001b[0;36mh5py.h5d.DatasetID.write\u001b[1;34m()\u001b[0m\n",
277
- "File \u001b[1;32mh5py\\\\_proxy.pyx:115\u001b[0m, in \u001b[0;36mh5py._proxy.dset_rw\u001b[1;34m()\u001b[0m\n",
278
- "\u001b[1;31mOSError\u001b[0m: [Errno 28] Can't write data (file write failed: time = Wed May 7 10:48:36 2025\n, filename = 'mf-fine-tuned-protbert\\weights-10-0.04.hdf5', file descriptor = 4, errno = 28, error message = 'No space left on device', buf = 000002CC552FF040, total write size = 4194304, bytes this sub-write = 4194304, bytes actually written = 18446744073709551615, offset = 1180551864)",
279
- "\nDuring handling of the above exception, another exception occurred:\n",
280
- "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
281
- "Cell \u001b[1;32mIn[2], line 119\u001b[0m\n\u001b[0;32m 113\u001b[0m model \u001b[38;5;241m=\u001b[39m t\u001b[38;5;241m.\u001b[39mget_classifier()\n\u001b[0;32m 114\u001b[0m learner \u001b[38;5;241m=\u001b[39m ktrain\u001b[38;5;241m.\u001b[39mget_learner(model,\n\u001b[0;32m 115\u001b[0m train_data\u001b[38;5;241m=\u001b[39mtrn,\n\u001b[0;32m 116\u001b[0m val_data\u001b[38;5;241m=\u001b[39mval,\n\u001b[0;32m 117\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mBATCH_SIZE)\n\u001b[1;32m--> 119\u001b[0m \u001b[43mlearner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautofit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-5\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 120\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 121\u001b[0m \u001b[43m \u001b[49m\u001b[43mearly_stopping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmf-fine-tuned-protbert\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
282
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\core.py:1239\u001b[0m, in \u001b[0;36mLearner.autofit\u001b[1;34m(self, lr, epochs, early_stopping, reduce_on_plateau, reduce_factor, cycle_momentum, max_momentum, min_momentum, monitor, checkpoint_folder, class_weight, callbacks, steps_per_epoch, verbose)\u001b[0m\n\u001b[0;32m 1234\u001b[0m policy \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtriangular learning rate\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1235\u001b[0m U\u001b[38;5;241m.\u001b[39mvprint(\n\u001b[0;32m 1236\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbegin training using \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m policy with max lr of \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m...\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (policy, lr),\n\u001b[0;32m 1237\u001b[0m verbose\u001b[38;5;241m=\u001b[39mverbose,\n\u001b[0;32m 1238\u001b[0m )\n\u001b[1;32m-> 1239\u001b[0m hist \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1240\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1241\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1242\u001b[0m \u001b[43m \u001b[49m\u001b[43mearly_stopping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mearly_stopping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1243\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1244\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1245\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1246\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1247\u001b[0m \u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1248\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1249\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m clr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 1250\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miterations\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m clr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miterations\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
283
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\core.py:1650\u001b[0m, in \u001b[0;36mGenLearner.fit\u001b[1;34m(self, lr, n_cycles, cycle_len, cycle_mult, lr_decay, checkpoint_folder, early_stopping, class_weight, callbacks, steps_per_epoch, verbose)\u001b[0m\n\u001b[0;32m 1648\u001b[0m warnings\u001b[38;5;241m.\u001b[39mfilterwarnings(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mignore\u001b[39m\u001b[38;5;124m\"\u001b[39m, message\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.*Check your callbacks.*\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 1649\u001b[0m fit_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mfit\n\u001b[1;32m-> 1650\u001b[0m hist \u001b[38;5;241m=\u001b[39m \u001b[43mfit_fn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1651\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_data\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1652\u001b[0m \u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1653\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1654\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1655\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1656\u001b[0m \u001b[43m \u001b[49m\u001b[43mworkers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mworkers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1657\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_multiprocessing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muse_multiprocessing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1662\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1663\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m sgdr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 1664\u001b[0m hist\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m sgdr\u001b[38;5;241m.\u001b[39mhistory[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
284
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:1230\u001b[0m, in \u001b[0;36mModel.fit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[0;32m 1227\u001b[0m val_logs \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m name: val \u001b[38;5;28;01mfor\u001b[39;00m name, val \u001b[38;5;129;01min\u001b[39;00m val_logs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m 1228\u001b[0m epoch_logs\u001b[38;5;241m.\u001b[39mupdate(val_logs)\n\u001b[1;32m-> 1230\u001b[0m \u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_epoch_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch_logs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1231\u001b[0m training_logs \u001b[38;5;241m=\u001b[39m epoch_logs\n\u001b[0;32m 1232\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstop_training:\n",
285
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:413\u001b[0m, in \u001b[0;36mCallbackList.on_epoch_end\u001b[1;34m(self, epoch, logs)\u001b[0m\n\u001b[0;32m 411\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_logs(logs)\n\u001b[0;32m 412\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m callback \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallbacks:\n\u001b[1;32m--> 413\u001b[0m \u001b[43mcallback\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_epoch_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogs\u001b[49m\u001b[43m)\u001b[49m\n",
286
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:1368\u001b[0m, in \u001b[0;36mModelCheckpoint.on_epoch_end\u001b[1;34m(self, epoch, logs)\u001b[0m\n\u001b[0;32m 1366\u001b[0m \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n\u001b[0;32m 1367\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_freq \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m-> 1368\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlogs\u001b[49m\u001b[43m)\u001b[49m\n",
287
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\callbacks.py:1431\u001b[0m, in \u001b[0;36mModelCheckpoint._save_model\u001b[1;34m(self, epoch, batch, logs)\u001b[0m\n\u001b[0;32m 1429\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m%05d\u001b[39;00m\u001b[38;5;124m: saving model to \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m%\u001b[39m (epoch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, filepath))\n\u001b[0;32m 1430\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_weights_only:\n\u001b[1;32m-> 1431\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_weights\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1432\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilepath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverwrite\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_options\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1433\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 1434\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39msave(filepath, overwrite\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, options\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_options)\n",
288
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\keras\\engine\\training.py:2252\u001b[0m, in \u001b[0;36mModel.save_weights\u001b[1;34m(self, filepath, overwrite, save_format, options)\u001b[0m\n\u001b[0;32m 2250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m save_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mh5\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 2251\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m h5py\u001b[38;5;241m.\u001b[39mFile(filepath, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m-> 2252\u001b[0m hdf5_format\u001b[38;5;241m.\u001b[39msave_weights_to_hdf5_group(f, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers)\n\u001b[0;32m 2253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 2254\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tf\u001b[38;5;241m.\u001b[39mexecuting_eagerly():\n",
289
- "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
290
- "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
291
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\files.py:599\u001b[0m, in \u001b[0;36mFile.__exit__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 596\u001b[0m \u001b[38;5;129m@with_phil\u001b[39m\n\u001b[0;32m 597\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__exit__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs):\n\u001b[0;32m 598\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid:\n\u001b[1;32m--> 599\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
292
- "File \u001b[1;32m~\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\h5py\\_hl\\files.py:581\u001b[0m, in \u001b[0;36mFile.close\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 575\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39mvalid:\n\u001b[0;32m 576\u001b[0m \u001b[38;5;66;03m# We have to explicitly murder all open objects related to the file\u001b[39;00m\n\u001b[0;32m 577\u001b[0m \n\u001b[0;32m 578\u001b[0m \u001b[38;5;66;03m# Close file-resident objects first, then the files.\u001b[39;00m\n\u001b[0;32m 579\u001b[0m \u001b[38;5;66;03m# Otherwise we get errors in MPI mode.\u001b[39;00m\n\u001b[0;32m 580\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39m_close_open_objects(h5f\u001b[38;5;241m.\u001b[39mOBJ_LOCAL \u001b[38;5;241m|\u001b[39m \u001b[38;5;241m~\u001b[39mh5f\u001b[38;5;241m.\u001b[39mOBJ_FILE)\n\u001b[1;32m--> 581\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_close_open_objects\u001b[49m\u001b[43m(\u001b[49m\u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mOBJ_LOCAL\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m|\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mh5f\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mOBJ_FILE\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 583\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;241m.\u001b[39mclose()\n\u001b[0;32m 584\u001b[0m _objects\u001b[38;5;241m.\u001b[39mnonlocal_close()\n",
293
- "File \u001b[1;32mh5py\\\\_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
294
- "File \u001b[1;32mh5py\\\\_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[1;34m()\u001b[0m\n",
295
- "File \u001b[1;32mh5py\\\\h5f.pyx:355\u001b[0m, in \u001b[0;36mh5py.h5f.FileID._close_open_objects\u001b[1;34m()\u001b[0m\n",
296
- "\u001b[1;31mRuntimeError\u001b[0m: Can't decrement id ref count (unable to extend file properly)"
297
- ]
298
- }
299
- ],
300
- "source": [
301
- "import pandas as pd\n",
302
- "import numpy as np\n",
303
- "from tqdm import tqdm\n",
304
- "import random\n",
305
- "import os\n",
306
- "import ktrain\n",
307
- "from ktrain import text\n",
308
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
309
- "\n",
310
- "\n",
311
- "# PAM1\n",
312
- "# PAM matrix model of protein evolution\n",
313
- "# DOI:10.1093/oxfordjournals.molbev.a040360\n",
314
- "pam_data = {\n",
315
- " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
316
- " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
317
- " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
318
- " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
319
- " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
320
- " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
321
- " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
322
- " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
323
- " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
324
- " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
325
- " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
326
- " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
327
- " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
328
- " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
329
- " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
330
- " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
331
- " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
332
- " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
333
- " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
334
- " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
335
- "}\n",
336
- "pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
337
- "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
338
- "list_amino = pam_raw.columns.tolist()\n",
339
- "pam_dict = {\n",
340
- " aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
341
- " for aa in list_amino\n",
342
- "}\n",
343
- "\n",
344
- "def pam1_substitution(aa):\n",
345
- " if aa not in pam_dict:\n",
346
- " return aa\n",
347
- " subs = list(pam_dict[aa].keys())\n",
348
- " probs = list(pam_dict[aa].values())\n",
349
- " return np.random.choice(subs, p=probs)\n",
350
- "\n",
351
- "def augment_sequence(seq, sub_prob=0.05):\n",
352
- " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
353
- "\n",
354
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
355
- " if len(seq) <= win:\n",
356
- " return [seq]\n",
357
- " slices, start = [], 0\n",
358
- " while start + win <= len(seq):\n",
359
- " slices.append(seq[start:start+win])\n",
360
- " start += win\n",
361
- " leftover = seq[start:]\n",
362
- " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
363
- " extra = slices[-1][-min_overlap:] + leftover\n",
364
- " slices.append(extra)\n",
365
- " return slices\n",
366
- "\n",
367
- "def generate_data(df, augment=False):\n",
368
- " X, y = [], []\n",
369
- " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
370
- " for _, row in tqdm(df.iterrows(), total=len(df)):\n",
371
- " seq = row[\"sequence\"]\n",
372
- " if augment:\n",
373
- " seq = augment_sequence(seq)\n",
374
- " seq_slices = slice_sequence(seq)\n",
375
- " X.extend(seq_slices)\n",
376
- " lbl = row[label_cols].values.astype(int)\n",
377
- " y.extend([lbl] * len(seq_slices))\n",
378
- " return X, np.array(y), label_cols\n",
379
- "\n",
380
- "def format_sequence(seq): return \" \".join(list(seq))\n",
381
- "\n",
382
- "# Função para carregar e binarizar\n",
383
- "def load_and_binarize(csv_path, mlb=None):\n",
384
- " df = pd.read_csv(csv_path)\n",
385
- " df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
386
- " if mlb is None:\n",
387
- " mlb = MultiLabelBinarizer()\n",
388
- " labels = mlb.fit_transform(df[\"go_terms\"])\n",
389
- " else:\n",
390
- " labels = mlb.transform(df[\"go_terms\"])\n",
391
- " labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
392
- " df = df.reset_index(drop=True).join(labels_df)\n",
393
- " return df, mlb\n",
394
- "\n",
395
- "# Carregar os dados\n",
396
- "df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
397
- "df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
398
- "\n",
399
- "# Gerar com augmentation no treino\n",
400
- "X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
401
- "X_val, y_val, _ = generate_data(df_val, augment=False)\n",
402
- "\n",
403
- "# Preparar texto para tokenizer\n",
404
- "X_train_fmt = list(map(format_sequence, X_train))\n",
405
- "X_val_fmt = list(map(format_sequence, X_val))\n",
406
- "\n",
407
- "# Fine-tune ProtBERT\n",
408
- "# https://huggingface.co/Rostlab/prot_bert\n",
409
- "# https://doi.org/10.1093/bioinformatics/btac020\n",
410
- "# dados de treino-> UniRef100 (216 milhões de sequências)\n",
411
- "MODEL_NAME = \"Rostlab/prot_bert\"\n",
412
- "MAX_LEN = 512\n",
413
- "BATCH_SIZE = 1\n",
414
- "\n",
415
- "t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
416
- "trn = t.preprocess_train(X_train_fmt, y_train)\n",
417
- "val = t.preprocess_test(X_val_fmt, y_val)\n",
418
- "\n",
419
- "model = t.get_classifier()\n",
420
- "learner = ktrain.get_learner(model,\n",
421
- " train_data=trn,\n",
422
- " val_data=val,\n",
423
- " batch_size=BATCH_SIZE)\n",
424
- "\n",
425
- "learner.autofit(lr=1e-5,\n",
426
- " epochs=10,\n",
427
- " early_stopping=1,\n",
428
- " checkpoint_folder=\"mf-fine-tuned-protbert\")\n"
429
- ]
430
- },
431
- {
432
- "cell_type": "code",
433
- "execution_count": 7,
434
- "id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
435
- "metadata": {},
436
- "outputs": [
437
- {
438
- "name": "stdout",
439
- "output_type": "stream",
440
- "text": [
441
- "✅ Existe: weights/mf-fine-tuned-protbert-epoch10\n",
442
- "📁 Conteúdo:\n",
443
- " - config.json\n",
444
- " - tf_model.h5\n"
445
- ]
446
- }
447
- ],
448
- "source": [
449
- "import os\n",
450
- "\n",
451
- "path = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
452
- "\n",
453
- "if os.path.exists(path):\n",
454
- " print(f\"✅ Existe: {path}\")\n",
455
- " print(\"📁 Conteúdo:\")\n",
456
- " for f in os.listdir(path):\n",
457
- " print(\" -\", f)\n",
458
- "else:\n",
459
- " print(f\"❌ Não existe: {path}\")\n",
460
- "\n"
461
- ]
462
- },
463
- {
464
- "cell_type": "code",
465
- "execution_count": 19,
466
- "id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
467
- "metadata": {},
468
- "outputs": [
469
- {
470
- "name": "stderr",
471
- "output_type": "stream",
472
- "text": [
473
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
474
- " from .autonotebook import tqdm as notebook_tqdm\n",
475
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
476
- " _torch_pytree._register_pytree_node(\n",
477
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
478
- " warnings.warn(\n",
479
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:309: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
480
- " _torch_pytree._register_pytree_node(\n",
481
- "Some layers from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10 were not used when initializing TFBertModel: ['classifier', 'dropout_183']\n",
482
- "- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
483
- "- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
484
- "All the layers of TFBertModel were initialized from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10.\n",
485
- "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
486
- ]
487
- },
488
- {
489
- "name": "stdout",
490
- "output_type": "stream",
491
- "text": [
492
- "✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
493
- ]
494
- },
495
- {
496
- "name": "stderr",
497
- "output_type": "stream",
498
- "text": [
499
- "Processando data/mf-training.csv: 0%| | 25/31142 [00:06<2:23:28, 3.61it/s]\n"
500
- ]
501
- },
502
- {
503
- "ename": "KeyboardInterrupt",
504
- "evalue": "",
505
- "output_type": "error",
506
- "traceback": [
507
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
508
- "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
509
- "Cell \u001b[1;32mIn[19], line 78\u001b[0m\n\u001b[0;32m 75\u001b[0m \u001b[38;5;66;03m# --- 4. Aplicar -----------------------------------------------------------\u001b[39;00m\n\u001b[0;32m 76\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(OUT_DIR, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m---> 78\u001b[0m \u001b[43mprocess_split\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdata/mf-training.csv\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mOUT_DIR\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain_protbert.pkl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 79\u001b[0m process_split(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/mf-validation.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(OUT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval_protbert.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[0;32m 80\u001b[0m process_split(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata/mf-test.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(OUT_DIR, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_protbert.pkl\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
510
- "Cell \u001b[1;32mIn[19], line 61\u001b[0m, in \u001b[0;36mprocess_split\u001b[1;34m(csv_path, out_path)\u001b[0m\n\u001b[0;32m 59\u001b[0m embeds\u001b[38;5;241m.\u001b[39mappend(prot_embed\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32))\n\u001b[0;32m 60\u001b[0m labels\u001b[38;5;241m.\u001b[39mappend(row[label_cols]\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mint8))\n\u001b[1;32m---> 61\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[0;32m 63\u001b[0m embeds \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvstack(embeds)\n\u001b[0;32m 64\u001b[0m labels \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvstack(labels)\n",
511
- "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
512
- ]
513
- }
514
- ],
515
- "source": [
516
- "import os\n",
517
- "import pandas as pd\n",
518
- "import numpy as np\n",
519
- "from tqdm import tqdm\n",
520
- "import joblib\n",
521
- "import gc\n",
522
- "from transformers import AutoTokenizer, TFAutoModel\n",
523
- "\n",
524
- "# --- 1. Parâmetros --------------------------------------------------------\n",
525
- "MODEL_DIR = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
526
- "BASE_MODEL = \"Rostlab/prot_bert\"\n",
527
- "OUT_DIR = \"embeddings\"\n",
528
- "BATCH_TOK = 16\n",
529
- "\n",
530
- "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, do_lower_case=False)\n",
531
- "model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
532
- "\n",
533
- "print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
534
- "\n",
535
- "# --- 3. Funções auxiliares ------------------------------------------------\n",
536
- "def format_sequence(seq):\n",
537
- " return \" \".join(list(seq))\n",
538
- "\n",
539
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
540
- " if len(seq) <= win:\n",
541
- " return [seq]\n",
542
- " slices, start = [], 0\n",
543
- " while start + win <= len(seq):\n",
544
- " slices.append(seq[start:start+win])\n",
545
- " start += win\n",
546
- " leftover = seq[start:]\n",
547
- " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
548
- " extra = slices[-1][-min_overlap:] + leftover\n",
549
- " slices.append(extra)\n",
550
- " return slices\n",
551
- "\n",
552
- "def get_embeddings(batch, tokenizer, model):\n",
553
- " tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
554
- " output = model(**tokens)\n",
555
- " return output.last_hidden_state[:, 0, :].numpy()\n",
556
- "\n",
557
- "def process_split(csv_path, out_path):\n",
558
- " df = pd.read_csv(csv_path)\n",
559
- " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
560
- " prot_ids, embeds, labels = [], [], []\n",
561
- "\n",
562
- " for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
563
- " slices = slice_sequence(row[\"sequence\"])\n",
564
- " slices_fmt = list(map(format_sequence, slices))\n",
565
- "\n",
566
- " slice_embeds = []\n",
567
- " for i in range(0, len(slices_fmt), BATCH_TOK):\n",
568
- " batch = slices_fmt[i:i+BATCH_TOK]\n",
569
- " slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
570
- " slice_embeds = np.vstack(slice_embeds)\n",
571
- "\n",
572
- " prot_embed = slice_embeds.mean(axis=0)\n",
573
- " prot_ids.append(row[\"protein_id\"])\n",
574
- " embeds.append(prot_embed.astype(np.float32))\n",
575
- " labels.append(row[label_cols].values.astype(np.int8))\n",
576
- " gc.collect()\n",
577
- "\n",
578
- " embeds = np.vstack(embeds)\n",
579
- " labels = np.vstack(labels)\n",
580
- "\n",
581
- " joblib.dump({\n",
582
- " \"protein_ids\": prot_ids,\n",
583
- " \"embeddings\": embeds,\n",
584
- " \"labels\": labels,\n",
585
- " \"go_terms\": label_cols\n",
586
- " }, out_path, compress=3)\n",
587
- "\n",
588
- " print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
589
- "\n",
590
- "# --- 4. Aplicar -----------------------------------------------------------\n",
591
- "os.makedirs(OUT_DIR, exist_ok=True)\n",
592
- "\n",
593
- "process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbert.pkl\"))\n",
594
- "process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbert.pkl\"))\n",
595
- "process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbert.pkl\"))\n"
596
- ]
597
- },
598
- {
599
- "cell_type": "code",
600
- "execution_count": 27,
601
- "id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
602
- "metadata": {},
603
- "outputs": [
604
- {
605
- "name": "stdout",
606
- "output_type": "stream",
607
- "text": [
608
- "✓ Corrigido: embeddings/train_protbert.pkl — 31142 exemplos, 597 GO terms\n",
609
- "✓ Corrigido: embeddings/val_protbert.pkl — 1724 exemplos, 597 GO terms\n",
610
- "✓ Corrigido: embeddings/test_protbert.pkl — 1724 exemplos, 597 GO terms\n"
611
- ]
612
- }
613
- ],
614
- "source": [
615
- "import pandas as pd\n",
616
- "import joblib\n",
617
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
618
- "\n",
619
- "# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
620
- "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
621
- "test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
622
- "\n",
623
- "# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
624
- "def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
625
- " df = pd.read_csv(csv_path)\n",
626
- " terms_split = df[\"go_terms\"].str.split(\";\")\n",
627
- " \n",
628
- " # Apenas termos presentes nos common_terms\n",
629
- " terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
630
- " \n",
631
- " mlb = MultiLabelBinarizer(classes=common_terms)\n",
632
- " Y = mlb.fit_transform(terms_filtered)\n",
633
- "\n",
634
- " data = joblib.load(pkl_path)\n",
635
- " data[\"labels\"] = Y\n",
636
- " data[\"go_terms\"] = mlb.classes_.tolist()\n",
637
- " \n",
638
- " joblib.dump(data, pkl_path, compress=3)\n",
639
- " print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
640
- "\n",
641
- "# --- 3. Aplicar às 3 partições --------------------------------------------\n",
642
- "patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbert.pkl\", test_terms)\n",
643
- "patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbert.pkl\", test_terms)\n",
644
- "patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbert.pkl\", test_terms)\n"
645
- ]
646
- },
647
- {
648
- "cell_type": "code",
649
- "execution_count": 1,
650
- "id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
651
- "metadata": {},
652
- "outputs": [],
653
- "source": [
654
- "import joblib\n",
655
- "train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
656
- "val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
657
- "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
658
- "\n",
659
- "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
660
- "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
661
- "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
662
- ]
663
- },
664
- {
665
- "cell_type": "code",
666
- "execution_count": 3,
667
- "id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
668
- "metadata": {},
669
- "outputs": [
670
- {
671
- "name": "stdout",
672
- "output_type": "stream",
673
- "text": [
674
- "✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
675
- "Epoch 1/100\n",
676
- "974/974 [==============================] - 4s 3ms/step - loss: 0.0358 - binary_accuracy: 0.9893 - val_loss: 0.0336 - val_binary_accuracy: 0.9901\n",
677
- "Epoch 2/100\n",
678
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0276 - binary_accuracy: 0.9914 - val_loss: 0.0331 - val_binary_accuracy: 0.9902\n",
679
- "Epoch 3/100\n",
680
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0268 - binary_accuracy: 0.9916 - val_loss: 0.0330 - val_binary_accuracy: 0.9902\n",
681
- "Epoch 4/100\n",
682
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0264 - binary_accuracy: 0.9917 - val_loss: 0.0320 - val_binary_accuracy: 0.9904\n",
683
- "Epoch 5/100\n",
684
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0260 - binary_accuracy: 0.9917 - val_loss: 0.0319 - val_binary_accuracy: 0.9904\n",
685
- "Epoch 6/100\n",
686
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0256 - binary_accuracy: 0.9918 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
687
- "Epoch 7/100\n",
688
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0255 - binary_accuracy: 0.9918 - val_loss: 0.0317 - val_binary_accuracy: 0.9903\n",
689
- "Epoch 8/100\n",
690
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9919 - val_loss: 0.0320 - val_binary_accuracy: 0.9905\n",
691
- "Epoch 9/100\n",
692
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0251 - binary_accuracy: 0.9919 - val_loss: 0.0316 - val_binary_accuracy: 0.9904\n",
693
- "Epoch 10/100\n",
694
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0250 - binary_accuracy: 0.9920 - val_loss: 0.0314 - val_binary_accuracy: 0.9905\n",
695
- "Epoch 11/100\n",
696
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0248 - binary_accuracy: 0.9920 - val_loss: 0.0317 - val_binary_accuracy: 0.9905\n",
697
- "Epoch 12/100\n",
698
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0247 - binary_accuracy: 0.9920 - val_loss: 0.0315 - val_binary_accuracy: 0.9905\n",
699
- "Epoch 13/100\n",
700
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0246 - binary_accuracy: 0.9920 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
701
- "Epoch 14/100\n",
702
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0245 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9905\n",
703
- "Epoch 15/100\n",
704
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9906\n",
705
- "Previsões guardadas em mf-protbert-pam1.npy\n",
706
- "Modelo guardado em models/protbert_mlp.keras\n"
707
- ]
708
- }
709
- ],
710
- "source": [
711
- "import tensorflow as tf\n",
712
- "import joblib\n",
713
- "import numpy as np\n",
714
- "from tensorflow.keras.models import Sequential\n",
715
- "from tensorflow.keras.layers import Dense, Dropout\n",
716
- "from tensorflow.keras.callbacks import EarlyStopping\n",
717
- "\n",
718
- "# --- 1. Carregar embeddings ----------------------------------------------\n",
719
- "train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
720
- "val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
721
- "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
722
- "\n",
723
- "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
724
- "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
725
- "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
726
- "\n",
727
- "print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
728
- "\n",
729
- "# --- 2. Garantir consistência de classes ---------------------------------\n",
730
- "max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
731
- "\n",
732
- "def pad_labels(y, target_dim=max_classes):\n",
733
- " if y.shape[1] < target_dim:\n",
734
- " padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
735
- " return np.hstack([y, padding])\n",
736
- " return y\n",
737
- "\n",
738
- "y_val = pad_labels(y_val)\n",
739
- "y_test = pad_labels(y_test)\n",
740
- "\n",
741
- "# --- 3. Modelo MLP ------------------------------------------------------\n",
742
- "model = Sequential([\n",
743
- " Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
744
- " Dropout(0.3),\n",
745
- " Dense(512, activation=\"relu\"),\n",
746
- " Dropout(0.3),\n",
747
- " Dense(max_classes, activation=\"sigmoid\")\n",
748
- "])\n",
749
- "\n",
750
- "model.compile(loss=\"binary_crossentropy\",\n",
751
- " optimizer=\"adam\",\n",
752
- " metrics=[\"binary_accuracy\"])\n",
753
- "\n",
754
- "# --- 4. Early stopping e treino -----------------------------------------\n",
755
- "callbacks = [\n",
756
- " EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
757
- "]\n",
758
- "\n",
759
- "model.fit(X_train, y_train,\n",
760
- " validation_data=(X_val, y_val),\n",
761
- " epochs=100,\n",
762
- " batch_size=32,\n",
763
- " callbacks=callbacks,\n",
764
- " verbose=1)\n",
765
- "\n",
766
- "# --- 5. Previsões --------------------------------------------------------\n",
767
- "y_prob = model.predict(X_test)\n",
768
- "np.save(\"predictions/mf-protbert-pam1.npy\", y_prob)\n",
769
- "print(\"Previsões guardadas em mf-protbert-pam1.npy\")\n",
770
- "\n",
771
- "# --- 6. Modelo ----------------------------------------------------------\n",
772
- "model.save(\"models/protbert_mlp.keras\")\n",
773
- "print(\"Modelo guardado em models/protbert_mlp.keras\")"
774
- ]
775
- },
776
- {
777
- "cell_type": "code",
778
- "execution_count": 30,
779
- "id": "fdb66630-76dc-43a0-bd56-45052175fdba",
780
- "metadata": {},
781
- "outputs": [
782
- {
783
- "name": "stdout",
784
- "output_type": "stream",
785
- "text": [
786
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
787
- "✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
788
- "\n",
789
- "📊 Resultados finais (ProtBERT + PAM1 + propagação):\n",
790
- "Fmax = 0.6666\n",
791
- "Thr. = 0.50\n",
792
- "AuPRC = 0.7028\n",
793
- "Smin = 13.1745\n"
794
- ]
795
- }
796
- ],
797
- "source": [
798
- "import numpy as np\n",
799
- "from sklearn.metrics import precision_recall_curve, auc\n",
800
- "from goatools.obo_parser import GODag\n",
801
- "import joblib\n",
802
- "import math\n",
803
- "\n",
804
- "# --- 1. Parâmetros -------------------------------------------------------\n",
805
- "GO_FILE = \"go.obo\"\n",
806
- "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
807
- "ALPHA = 0.5\n",
808
- "\n",
809
- "# --- 2. Carregar dados ---------------------------------------------------\n",
810
- "test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
811
- "y_true = test[\"labels\"]\n",
812
- "terms = test[\"go_terms\"]\n",
813
- "y_prob = np.load(\"predictions/mf-protbert-pam1.npy\")\n",
814
- "go_dag = GODag(GO_FILE)\n",
815
- "\n",
816
- "print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
817
- "\n",
818
- "# --- 3. Fmax -------------------------------------------------------------\n",
819
- "def compute_fmax(y_true, y_prob, thresholds):\n",
820
- " fmax, best_thr = 0, 0\n",
821
- " for t in thresholds:\n",
822
- " y_pred = (y_prob >= t).astype(int)\n",
823
- " tp = (y_true * y_pred).sum(axis=1)\n",
824
- " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
825
- " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
826
- " precision = tp / (tp + fp + 1e-8)\n",
827
- " recall = tp / (tp + fn + 1e-8)\n",
828
- " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
829
- " avg_f1 = np.mean(f1)\n",
830
- " if avg_f1 > fmax:\n",
831
- " fmax, best_thr = avg_f1, t\n",
832
- " return fmax, best_thr\n",
833
- "\n",
834
- "# --- 4. AuPRC micro ------------------------------------------------------\n",
835
- "def compute_auprc(y_true, y_prob):\n",
836
- " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
837
- " return auc(recall, precision)\n",
838
- "\n",
839
- "# --- 5. Smin -------------------------------------------------------------\n",
840
- "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
841
- " y_pred = (y_prob >= threshold).astype(int)\n",
842
- " ic = {}\n",
843
- " total = (y_true + y_pred).sum(axis=0).sum()\n",
844
- " for i, term in enumerate(terms):\n",
845
- " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
846
- " ic[term] = -np.log((freq + 1e-8) / total)\n",
847
- "\n",
848
- " s_values = []\n",
849
- " for true_vec, pred_vec in zip(y_true, y_pred):\n",
850
- " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
851
- " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
852
- "\n",
853
- " anc_true = set()\n",
854
- " for t in true_terms:\n",
855
- " if t in go_dag:\n",
856
- " anc_true |= go_dag[t].get_all_parents()\n",
857
- " anc_pred = set()\n",
858
- " for t in pred_terms:\n",
859
- " if t in go_dag:\n",
860
- " anc_pred |= go_dag[t].get_all_parents()\n",
861
- "\n",
862
- " ru = pred_terms - true_terms\n",
863
- " mi = true_terms - pred_terms\n",
864
- " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
865
- " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
866
- " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
867
- " s_values.append(s)\n",
868
- "\n",
869
- " return np.mean(s_values)\n",
870
- "\n",
871
- "# --- 6. Avaliar ----------------------------------------------------------\n",
872
- "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
873
- "auprc = compute_auprc(y_true, y_prob)\n",
874
- "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
875
- "\n",
876
- "print(f\"\\n📊 Resultados finais (ProtBERT + PAM1 + propagação):\")\n",
877
- "print(f\"Fmax = {fmax:.4f}\")\n",
878
- "print(f\"Thr. = {thr:.2f}\")\n",
879
- "print(f\"AuPRC = {auprc:.4f}\")\n",
880
- "print(f\"Smin = {smin:.4f}\")\n"
881
- ]
882
- },
883
- {
884
- "cell_type": "code",
885
- "execution_count": 3,
886
- "id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
887
- "metadata": {},
888
- "outputs": [
889
- {
890
- "data": {
891
- "text/plain": [
892
- "['data/mlb_protbert.pkl']"
893
- ]
894
- },
895
- "execution_count": 3,
896
- "metadata": {},
897
- "output_type": "execute_result"
898
- }
899
- ],
900
- "source": [
901
- "import joblib, pickle\n",
902
- "joblib.dump(mlb, \"data/mlb_protbert.pkl\")"
903
- ]
904
- },
905
- {
906
- "cell_type": "code",
907
- "execution_count": null,
908
- "id": "9f89c3bc-6b78-4a4c-8ddd-b69c7d3d0e65",
909
- "metadata": {},
910
- "outputs": [],
911
- "source": []
912
- }
913
- ],
914
- "metadata": {
915
- "kernelspec": {
916
- "display_name": "Python 3 (ipykernel)",
917
- "language": "python",
918
- "name": "python3"
919
- },
920
- "language_info": {
921
- "codemirror_mode": {
922
- "name": "ipython",
923
- "version": 3
924
- },
925
- "file_extension": ".py",
926
- "mimetype": "text/x-python",
927
- "name": "python",
928
- "nbconvert_exporter": "python",
929
- "pygments_lexer": "ipython3",
930
- "version": "3.8.18"
931
- }
932
- },
933
- "nbformat": 4,
934
- "nbformat_minor": 5
935
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/PAM1_protbertBFD.ipynb DELETED
@@ -1,871 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
14
- "✓ Ficheiros criados:\n",
15
- " - data/mf-training.csv : (31142, 3)\n",
16
- " - data/mf-validation.csv: (1724, 3)\n",
17
- " - data/mf-test.csv : (1724, 3)\n",
18
- "GO terms únicos (após propagação e filtro): 602\n"
19
- ]
20
- }
21
- ],
22
- "source": [
23
- "import pandas as pd\n",
24
- "from Bio import SeqIO\n",
25
- "from collections import Counter\n",
26
- "from goatools.obo_parser import GODag\n",
27
- "from sklearn.model_selection import train_test_split\n",
28
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
29
- "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
30
- "import numpy as np\n",
31
- "import os\n",
32
- "\n",
33
- "# --- 1. Carregar GO anotações ------------------------------------------\n",
34
- "annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
35
- "annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
36
- "\n",
37
- "# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
38
- "# propagação hierárquica\n",
39
- "# https://geneontology.org/docs/download-ontology/\n",
40
- "go_dag = GODag(\"go.obo\")\n",
41
- "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
42
- "\n",
43
- "def propagate_terms(term_list):\n",
44
- " full = set()\n",
45
- " for t in term_list:\n",
46
- " if t not in go_dag:\n",
47
- " continue\n",
48
- " full.add(t)\n",
49
- " full.update(go_dag[t].get_all_parents())\n",
50
- " return list(full & mf_terms)\n",
51
- "\n",
52
- "# --- 3. Carregar sequências --------------------------------------------\n",
53
- "seqs, ids = [], []\n",
54
- "for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
55
- " ids.append(record.id)\n",
56
- " seqs.append(str(record.seq))\n",
57
- "\n",
58
- "seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
59
- "\n",
60
- "# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
61
- "grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
62
- "data = seq_df.merge(grouped, on=\"protein_id\")\n",
63
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
64
- "data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
65
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
66
- "\n",
67
- "# --- 5. Filtrar GO terms raros -----------------------------------------\n",
68
- "# todos os terms com menos de 50 proteinas associadas\n",
69
- "all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
70
- "term_counts = Counter(all_terms)\n",
71
- "valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
72
- "data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
73
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
74
- "\n",
75
- "# --- 6. Preparar dataset final -----------------------------------------\n",
76
- "data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
77
- "data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
78
- "\n",
79
- "# --- 7. Binarizar labels e dividir -------------------------------------\n",
80
- "mlb = MultiLabelBinarizer()\n",
81
- "Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
82
- "X = data[[\"protein_id\", \"sequence\"]].values\n",
83
- "\n",
84
- "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
85
- "train_idx, temp_idx = next(mskf.split(X, Y))\n",
86
- "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
87
- "\n",
88
- "df_train = data.iloc[train_idx].copy()\n",
89
- "df_val = data.iloc[val_idx].copy()\n",
90
- "df_test = data.iloc[test_idx].copy()\n",
91
- "\n",
92
- "# --- 8. Guardar em CSV -------------------------------------------------\n",
93
- "os.makedirs(\"data\", exist_ok=True)\n",
94
- "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
95
- "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
96
- "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
97
- "\n",
98
- "# --- 9. Confirmar ------------------------------------------------------\n",
99
- "print(\"✓ Ficheiros criados:\")\n",
100
- "print(\" - data/mf-training.csv :\", df_train.shape)\n",
101
- "print(\" - data/mf-validation.csv:\", df_val.shape)\n",
102
- "print(\" - data/mf-test.csv :\", df_test.shape)\n",
103
- "print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 2,
109
- "id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
110
- "metadata": {},
111
- "outputs": [
112
- {
113
- "name": "stderr",
114
- "output_type": "stream",
115
- "text": [
116
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
117
- " from .autonotebook import tqdm as notebook_tqdm\n",
118
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
119
- " _torch_pytree._register_pytree_node(\n",
120
- "100%|██████████| 31142/31142 [00:26<00:00, 1192.86it/s]\n",
121
- "100%|██████████| 1724/1724 [00:00<00:00, 2570.68it/s]\n",
122
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
123
- " warnings.warn(\n"
124
- ]
125
- },
126
- {
127
- "name": "stdout",
128
- "output_type": "stream",
129
- "text": [
130
- "preprocessing train...\n",
131
- "language: en\n",
132
- "train sequence lengths:\n",
133
- "\tmean : 423\n",
134
- "\t95percentile : 604\n",
135
- "\t99percentile : 715\n"
136
- ]
137
- },
138
- {
139
- "data": {
140
- "text/html": [
141
- "\n",
142
- "<style>\n",
143
- " /* Turns off some styling */\n",
144
- " progress {\n",
145
- " /* gets rid of default border in Firefox and Opera. */\n",
146
- " border: none;\n",
147
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
148
- " background-size: auto;\n",
149
- " }\n",
150
- " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
151
- " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
152
- " }\n",
153
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
154
- " background: #F44336;\n",
155
- " }\n",
156
- "</style>\n"
157
- ],
158
- "text/plain": [
159
- "<IPython.core.display.HTML object>"
160
- ]
161
- },
162
- "metadata": {},
163
- "output_type": "display_data"
164
- },
165
- {
166
- "data": {
167
- "text/html": [],
168
- "text/plain": [
169
- "<IPython.core.display.HTML object>"
170
- ]
171
- },
172
- "metadata": {},
173
- "output_type": "display_data"
174
- },
175
- {
176
- "name": "stdout",
177
- "output_type": "stream",
178
- "text": [
179
- "Is Multi-Label? True\n",
180
- "preprocessing test...\n",
181
- "language: en\n",
182
- "test sequence lengths:\n",
183
- "\tmean : 408\n",
184
- "\t95percentile : 603\n",
185
- "\t99percentile : 714\n"
186
- ]
187
- },
188
- {
189
- "data": {
190
- "text/html": [
191
- "\n",
192
- "<style>\n",
193
- " /* Turns off some styling */\n",
194
- " progress {\n",
195
- " /* gets rid of default border in Firefox and Opera. */\n",
196
- " border: none;\n",
197
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
198
- " background-size: auto;\n",
199
- " }\n",
200
- " progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
201
- " background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
202
- " }\n",
203
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
204
- " background: #F44336;\n",
205
- " }\n",
206
- "</style>\n"
207
- ],
208
- "text/plain": [
209
- "<IPython.core.display.HTML object>"
210
- ]
211
- },
212
- "metadata": {},
213
- "output_type": "display_data"
214
- },
215
- {
216
- "data": {
217
- "text/html": [],
218
- "text/plain": [
219
- "<IPython.core.display.HTML object>"
220
- ]
221
- },
222
- "metadata": {},
223
- "output_type": "display_data"
224
- },
225
- {
226
- "name": "stdout",
227
- "output_type": "stream",
228
- "text": [
229
- "\n",
230
- "\n",
231
- "begin training using triangular learning rate policy with max lr of 1e-05...\n",
232
- "Epoch 1/10\n",
233
- "40995/40995 [==============================] - 9020s 219ms/step - loss: 0.0740 - binary_accuracy: 0.9869 - val_loss: 0.0526 - val_binary_accuracy: 0.9866\n",
234
- "Epoch 2/10\n",
235
- "40995/40995 [==============================] - 8939s 218ms/step - loss: 0.0464 - binary_accuracy: 0.9877 - val_loss: 0.0457 - val_binary_accuracy: 0.9871\n",
236
- "Epoch 3/10\n",
237
- "40995/40995 [==============================] - 8881s 217ms/step - loss: 0.0413 - binary_accuracy: 0.9883 - val_loss: 0.0418 - val_binary_accuracy: 0.9877\n",
238
- "Epoch 4/10\n",
239
- "40995/40995 [==============================] - 10277s 251ms/step - loss: 0.0380 - binary_accuracy: 0.9888 - val_loss: 0.0396 - val_binary_accuracy: 0.9881\n",
240
- "Epoch 5/10\n",
241
- "40995/40995 [==============================] - 10565s 258ms/step - loss: 0.0357 - binary_accuracy: 0.9892 - val_loss: 0.0380 - val_binary_accuracy: 0.9883\n",
242
- "Epoch 6/10\n",
243
- "40995/40995 [==============================] - 10693s 261ms/step - loss: 0.0338 - binary_accuracy: 0.9895 - val_loss: 0.0369 - val_binary_accuracy: 0.9885\n",
244
- "Epoch 7/10\n",
245
- "40995/40995 [==============================] - 12055s 294ms/step - loss: 0.0323 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
246
- "Epoch 8/10\n",
247
- "40995/40995 [==============================] - 10225s 249ms/step - loss: 0.0309 - binary_accuracy: 0.9901 - val_loss: 0.0353 - val_binary_accuracy: 0.9890\n",
248
- "Epoch 9/10\n",
249
- "40995/40995 [==============================] - 10308s 251ms/step - loss: 0.0297 - binary_accuracy: 0.9904 - val_loss: 0.0347 - val_binary_accuracy: 0.9891\n",
250
- "Epoch 10/10\n",
251
- "40995/40995 [==============================] - 10275s 251ms/step - loss: 0.0286 - binary_accuracy: 0.9907 - val_loss: 0.0346 - val_binary_accuracy: 0.9893\n",
252
- "Weights from best epoch have been loaded into model.\n"
253
- ]
254
- },
255
- {
256
- "data": {
257
- "text/plain": [
258
- "<keras.callbacks.History at 0x2b644b84fd0>"
259
- ]
260
- },
261
- "execution_count": 2,
262
- "metadata": {},
263
- "output_type": "execute_result"
264
- }
265
- ],
266
- "source": [
267
- "import pandas as pd\n",
268
- "import numpy as np\n",
269
- "from tqdm import tqdm\n",
270
- "import random\n",
271
- "import os\n",
272
- "import ktrain\n",
273
- "from ktrain import text\n",
274
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
275
- "\n",
276
- "\n",
277
- "# PAM1\n",
278
- "# PAM matrix model of protein evolution\n",
279
- "# DOI:10.1093/oxfordjournals.molbev.a040360\n",
280
- "pam_data = {\n",
281
- " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
282
- " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
283
- " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
284
- " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
285
- " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
286
- " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
287
- " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
288
- " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
289
- " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
290
- " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
291
- " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
292
- " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
293
- " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
294
- " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
295
- " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
296
- " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
297
- " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
298
- " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
299
- " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
300
- " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
301
- "}\n",
302
- "pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
303
- "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
304
- "list_amino = pam_raw.columns.tolist()\n",
305
- "pam_dict = {\n",
306
- " aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
307
- " for aa in list_amino\n",
308
- "}\n",
309
- "\n",
310
- "def pam1_substitution(aa):\n",
311
- " if aa not in pam_dict:\n",
312
- " return aa\n",
313
- " subs = list(pam_dict[aa].keys())\n",
314
- " probs = list(pam_dict[aa].values())\n",
315
- " return np.random.choice(subs, p=probs)\n",
316
- "\n",
317
- "def augment_sequence(seq, sub_prob=0.05):\n",
318
- " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
319
- "\n",
320
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
321
- " if len(seq) <= win:\n",
322
- " return [seq]\n",
323
- " slices, start = [], 0\n",
324
- " while start + win <= len(seq):\n",
325
- " slices.append(seq[start:start+win])\n",
326
- " start += win\n",
327
- " leftover = seq[start:]\n",
328
- " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
329
- " extra = slices[-1][-min_overlap:] + leftover\n",
330
- " slices.append(extra)\n",
331
- " return slices\n",
332
- "\n",
333
- "def generate_data(df, augment=False):\n",
334
- " X, y = [], []\n",
335
- " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
336
- " for _, row in tqdm(df.iterrows(), total=len(df)):\n",
337
- " seq = row[\"sequence\"]\n",
338
- " if augment:\n",
339
- " seq = augment_sequence(seq)\n",
340
- " seq_slices = slice_sequence(seq)\n",
341
- " X.extend(seq_slices)\n",
342
- " lbl = row[label_cols].values.astype(int)\n",
343
- " y.extend([lbl] * len(seq_slices))\n",
344
- " return X, np.array(y), label_cols\n",
345
- "\n",
346
- "def format_sequence(seq): return \" \".join(list(seq))\n",
347
- "\n",
348
- "# Função para carregar e binarizar\n",
349
- "def load_and_binarize(csv_path, mlb=None):\n",
350
- " df = pd.read_csv(csv_path)\n",
351
- " df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
352
- " if mlb is None:\n",
353
- " mlb = MultiLabelBinarizer()\n",
354
- " labels = mlb.fit_transform(df[\"go_terms\"])\n",
355
- " else:\n",
356
- " labels = mlb.transform(df[\"go_terms\"])\n",
357
- " labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
358
- " df = df.reset_index(drop=True).join(labels_df)\n",
359
- " return df, mlb\n",
360
- "\n",
361
- "# Carregar os dados\n",
362
- "df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
363
- "df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
364
- "\n",
365
- "# Gerar com augmentation no treino\n",
366
- "X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
367
- "X_val, y_val, _ = generate_data(df_val, augment=False)\n",
368
- "\n",
369
- "# Preparar texto para tokenizer\n",
370
- "X_train_fmt = list(map(format_sequence, X_train))\n",
371
- "X_val_fmt = list(map(format_sequence, X_val))\n",
372
- "\n",
373
- "# Fine-tune ProtBERT\n",
374
- "# https://huggingface.co/Rostlab/prot_bert\n",
375
- "# https://doi.org/10.1093/bioinformatics/btac020\n",
376
- "# dados de treino-> UniRef100 (216 milhões de sequências)\n",
377
- "MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
378
- "MAX_LEN = 512\n",
379
- "BATCH_SIZE = 1\n",
380
- "\n",
381
- "t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
382
- "trn = t.preprocess_train(X_train_fmt, y_train)\n",
383
- "val = t.preprocess_test(X_val_fmt, y_val)\n",
384
- "\n",
385
- "model = t.get_classifier()\n",
386
- "learner = ktrain.get_learner(model,\n",
387
- " train_data=trn,\n",
388
- " val_data=val,\n",
389
- " batch_size=BATCH_SIZE)\n",
390
- "\n",
391
- "learner.autofit(lr=1e-5,\n",
392
- " epochs=10,\n",
393
- " early_stopping=1,\n",
394
- " checkpoint_folder=\"mf-fine-tuned-protbertbfd\")\n"
395
- ]
396
- },
397
- {
398
- "cell_type": "code",
399
- "execution_count": 6,
400
- "id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
401
- "metadata": {},
402
- "outputs": [
403
- {
404
- "name": "stdout",
405
- "output_type": "stream",
406
- "text": [
407
- "✅ Existe: weights/mf-fine-tuned-protbertbfd\n",
408
- "📁 Conteúdo:\n",
409
- " - config.json\n",
410
- " - tf_model.h5\n"
411
- ]
412
- }
413
- ],
414
- "source": [
415
- "import os\n",
416
- "learner.save_model('weights/mf-fine-tuned-protbertbfd')\n",
417
- "path = \"weights/mf-fine-tuned-protbertbfd\"\n",
418
- "\n",
419
- "if os.path.exists(path):\n",
420
- " print(f\"✅ Existe: {path}\")\n",
421
- " print(\"📁 Conteúdo:\")\n",
422
- " for f in os.listdir(path):\n",
423
- " print(\" -\", f)\n",
424
- "else:\n",
425
- " print(f\"❌ Não existe: {path}\")\n",
426
- "\n"
427
- ]
428
- },
429
- {
430
- "cell_type": "code",
431
- "execution_count": 8,
432
- "id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
433
- "metadata": {},
434
- "outputs": [
435
- {
436
- "name": "stdout",
437
- "output_type": "stream",
438
- "text": [
439
- "✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
440
- ]
441
- },
442
- {
443
- "name": "stderr",
444
- "output_type": "stream",
445
- "text": [
446
- "Processando data/mf-training.csv: 100%|██████████| 31142/31142 [5:17:56<00:00, 1.63it/s] \n"
447
- ]
448
- },
449
- {
450
- "name": "stdout",
451
- "output_type": "stream",
452
- "text": [
453
- "✓ Guardado embeddings\\train_protbertbfd.pkl — 31142 proteínas\n"
454
- ]
455
- },
456
- {
457
- "name": "stderr",
458
- "output_type": "stream",
459
- "text": [
460
- "Processando data/mf-validation.csv: 100%|██████████| 1724/1724 [19:15<00:00, 1.49it/s]\n"
461
- ]
462
- },
463
- {
464
- "name": "stdout",
465
- "output_type": "stream",
466
- "text": [
467
- "✓ Guardado embeddings\\val_protbertbfd.pkl — 1724 proteínas\n"
468
- ]
469
- },
470
- {
471
- "name": "stderr",
472
- "output_type": "stream",
473
- "text": [
474
- "Processando data/mf-test.csv: 100%|██████████| 1724/1724 [17:15<00:00, 1.66it/s]\n"
475
- ]
476
- },
477
- {
478
- "name": "stdout",
479
- "output_type": "stream",
480
- "text": [
481
- "✓ Guardado embeddings\\test_protbertbfd.pkl — 1724 proteínas\n"
482
- ]
483
- }
484
- ],
485
- "source": [
486
- "import os\n",
487
- "import pandas as pd\n",
488
- "import numpy as np\n",
489
- "from tqdm import tqdm\n",
490
- "import joblib\n",
491
- "import gc\n",
492
- "from transformers import AutoTokenizer, TFAutoModel\n",
493
- "\n",
494
- "# --- 1. Parâmetros --------------------------------------------------------\n",
495
- "MODEL_DIR = \"weights/mf-fine-tuned-protbertbfd\"\n",
496
- "MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
497
- "OUT_DIR = \"embeddings\"\n",
498
- "BATCH_TOK = 16\n",
499
- "\n",
500
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
501
- "model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
502
- "\n",
503
- "print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
504
- "\n",
505
- "# --- 3. Funções auxiliares ------------------------------------------------\n",
506
- "def format_sequence(seq):\n",
507
- " return \" \".join(list(seq))\n",
508
- "\n",
509
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
510
- " if len(seq) <= win:\n",
511
- " return [seq]\n",
512
- " slices, start = [], 0\n",
513
- " while start + win <= len(seq):\n",
514
- " slices.append(seq[start:start+win])\n",
515
- " start += win\n",
516
- " leftover = seq[start:]\n",
517
- " if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
518
- " extra = slices[-1][-min_overlap:] + leftover\n",
519
- " slices.append(extra)\n",
520
- " return slices\n",
521
- "\n",
522
- "def get_embeddings(batch, tokenizer, model):\n",
523
- " tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
524
- " output = model(**tokens)\n",
525
- " return output.last_hidden_state[:, 0, :].numpy()\n",
526
- "\n",
527
- "def process_split(csv_path, out_path):\n",
528
- " df = pd.read_csv(csv_path)\n",
529
- " label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
530
- " prot_ids, embeds, labels = [], [], []\n",
531
- "\n",
532
- " for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
533
- " slices = slice_sequence(row[\"sequence\"])\n",
534
- " slices_fmt = list(map(format_sequence, slices))\n",
535
- "\n",
536
- " slice_embeds = []\n",
537
- " for i in range(0, len(slices_fmt), BATCH_TOK):\n",
538
- " batch = slices_fmt[i:i+BATCH_TOK]\n",
539
- " slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
540
- " slice_embeds = np.vstack(slice_embeds)\n",
541
- "\n",
542
- " prot_embed = slice_embeds.mean(axis=0)\n",
543
- " prot_ids.append(row[\"protein_id\"])\n",
544
- " embeds.append(prot_embed.astype(np.float32))\n",
545
- " labels.append(row[label_cols].values.astype(np.int8))\n",
546
- " gc.collect()\n",
547
- "\n",
548
- " embeds = np.vstack(embeds)\n",
549
- " labels = np.vstack(labels)\n",
550
- "\n",
551
- " joblib.dump({\n",
552
- " \"protein_ids\": prot_ids,\n",
553
- " \"embeddings\": embeds,\n",
554
- " \"labels\": labels,\n",
555
- " \"go_terms\": label_cols\n",
556
- " }, out_path, compress=3)\n",
557
- "\n",
558
- " print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
559
- "\n",
560
- "# --- 4. Aplicar -----------------------------------------------------------\n",
561
- "os.makedirs(OUT_DIR, exist_ok=True)\n",
562
- "\n",
563
- "process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbertbfd.pkl\"))\n",
564
- "process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbertbfd.pkl\"))\n",
565
- "process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbertbfd.pkl\"))\n"
566
- ]
567
- },
568
- {
569
- "cell_type": "code",
570
- "execution_count": 9,
571
- "id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
572
- "metadata": {},
573
- "outputs": [
574
- {
575
- "name": "stdout",
576
- "output_type": "stream",
577
- "text": [
578
- "✓ Corrigido: embeddings/train_protbertbfd.pkl — 31142 exemplos, 597 GO terms\n",
579
- "✓ Corrigido: embeddings/val_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n",
580
- "✓ Corrigido: embeddings/test_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n"
581
- ]
582
- }
583
- ],
584
- "source": [
585
- "import pandas as pd\n",
586
- "import joblib\n",
587
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
588
- "\n",
589
- "# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
590
- "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
591
- "test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
592
- "\n",
593
- "# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
594
- "def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
595
- " df = pd.read_csv(csv_path)\n",
596
- " terms_split = df[\"go_terms\"].str.split(\";\")\n",
597
- " \n",
598
- " # Apenas termos presentes nos common_terms\n",
599
- " terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
600
- " \n",
601
- " mlb = MultiLabelBinarizer(classes=common_terms)\n",
602
- " Y = mlb.fit_transform(terms_filtered)\n",
603
- "\n",
604
- " data = joblib.load(pkl_path)\n",
605
- " data[\"labels\"] = Y\n",
606
- " data[\"go_terms\"] = mlb.classes_.tolist()\n",
607
- " \n",
608
- " joblib.dump(data, pkl_path, compress=3)\n",
609
- " print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
610
- "\n",
611
- "# --- 3. Aplicar às 3 partições --------------------------------------------\n",
612
- "patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbertbfd.pkl\", test_terms)\n",
613
- "patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbertbfd.pkl\", test_terms)\n",
614
- "patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbertbfd.pkl\", test_terms)\n"
615
- ]
616
- },
617
- {
618
- "cell_type": "code",
619
- "execution_count": 2,
620
- "id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
621
- "metadata": {},
622
- "outputs": [],
623
- "source": [
624
- "import joblib\n",
625
- "train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
626
- "val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
627
- "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
628
- "\n",
629
- "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
630
- "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
631
- "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
632
- ]
633
- },
634
- {
635
- "cell_type": "code",
636
- "execution_count": 5,
637
- "id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
638
- "metadata": {},
639
- "outputs": [
640
- {
641
- "name": "stdout",
642
- "output_type": "stream",
643
- "text": [
644
- "✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
645
- "Epoch 1/100\n",
646
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0337 - binary_accuracy: 0.9901 - val_loss: 0.0331 - val_binary_accuracy: 0.9905\n",
647
- "Epoch 2/100\n",
648
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9921 - val_loss: 0.0326 - val_binary_accuracy: 0.9905\n",
649
- "Epoch 3/100\n",
650
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9924 - val_loss: 0.0330 - val_binary_accuracy: 0.9905\n",
651
- "Epoch 4/100\n",
652
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0240 - binary_accuracy: 0.9925 - val_loss: 0.0322 - val_binary_accuracy: 0.9907\n",
653
- "Epoch 5/100\n",
654
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0236 - binary_accuracy: 0.9925 - val_loss: 0.0328 - val_binary_accuracy: 0.9907\n",
655
- "Epoch 6/100\n",
656
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0232 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9908\n",
657
- "Epoch 7/100\n",
658
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0231 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9907\n",
659
- "Epoch 8/100\n",
660
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0228 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
661
- "Epoch 9/100\n",
662
- "974/974 [==============================] - 3s 3ms/step - loss: 0.0226 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
663
- "Previsões guardadas em mf-protbertbfd-pam1.npy\n",
664
- "Modelo guardado em models/protbertbfd_mlp.keras\n"
665
- ]
666
- }
667
- ],
668
- "source": [
669
- "import tensorflow as tf\n",
670
- "import joblib\n",
671
- "import numpy as np\n",
672
- "from tensorflow.keras.models import Sequential\n",
673
- "from tensorflow.keras.layers import Dense, Dropout\n",
674
- "from tensorflow.keras.callbacks import EarlyStopping\n",
675
- "\n",
676
- "# --- 1. Carregar embeddings ----------------------------------------------\n",
677
- "train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
678
- "val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
679
- "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
680
- "\n",
681
- "X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
682
- "X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
683
- "X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
684
- "\n",
685
- "print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
686
- "\n",
687
- "# --- 2. Garantir consistência de classes ---------------------------------\n",
688
- "max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
689
- "\n",
690
- "def pad_labels(y, target_dim=max_classes):\n",
691
- " if y.shape[1] < target_dim:\n",
692
- " padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
693
- " return np.hstack([y, padding])\n",
694
- " return y\n",
695
- "\n",
696
- "y_val = pad_labels(y_val)\n",
697
- "y_test = pad_labels(y_test)\n",
698
- "\n",
699
- "# --- 3. Modelo MLP ------------------------------------------------------\n",
700
- "model = Sequential([\n",
701
- " Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
702
- " Dropout(0.3),\n",
703
- " Dense(512, activation=\"relu\"),\n",
704
- " Dropout(0.3),\n",
705
- " Dense(max_classes, activation=\"sigmoid\")\n",
706
- "])\n",
707
- "\n",
708
- "model.compile(loss=\"binary_crossentropy\",\n",
709
- " optimizer=\"adam\",\n",
710
- " metrics=[\"binary_accuracy\"])\n",
711
- "\n",
712
- "# --- 4. Early stopping e treino -----------------------------------------\n",
713
- "callbacks = [\n",
714
- " EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
715
- "]\n",
716
- "\n",
717
- "model.fit(X_train, y_train,\n",
718
- " validation_data=(X_val, y_val),\n",
719
- " epochs=100,\n",
720
- " batch_size=32,\n",
721
- " callbacks=callbacks,\n",
722
- " verbose=1)\n",
723
- "\n",
724
- "# --- 5. Previsões --------------------------------------------------------\n",
725
- "y_prob = model.predict(X_test)\n",
726
- "np.save(\"predictions/mf-protbertbfd-pam1.npy\", y_prob)\n",
727
- "print(\"Previsões guardadas em mf-protbertbfd-pam1.npy\")\n",
728
- "\n",
729
- "# --- 6. Modelo ----------------------------------------------------------\n",
730
- "model.save(\"models/protbertbfd_mlp.keras\")\n",
731
- "print(\"Modelo guardado em models/protbertbfd_mlp.keras\")"
732
- ]
733
- },
734
- {
735
- "cell_type": "code",
736
- "execution_count": 12,
737
- "id": "fdb66630-76dc-43a0-bd56-45052175fdba",
738
- "metadata": {},
739
- "outputs": [
740
- {
741
- "name": "stdout",
742
- "output_type": "stream",
743
- "text": [
744
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
745
- "✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
746
- "\n",
747
- "📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\n",
748
- "Fmax = 0.6570\n",
749
- "Thr. = 0.41\n",
750
- "AuPRC = 0.6929\n",
751
- "Smin = 13.8114\n"
752
- ]
753
- }
754
- ],
755
- "source": [
756
- "import numpy as np\n",
757
- "from sklearn.metrics import precision_recall_curve, auc\n",
758
- "from goatools.obo_parser import GODag\n",
759
- "import joblib\n",
760
- "import math\n",
761
- "\n",
762
- "# --- 1. Parâmetros -------------------------------------------------------\n",
763
- "GO_FILE = \"go.obo\"\n",
764
- "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
765
- "ALPHA = 0.5\n",
766
- "\n",
767
- "# --- 2. Carregar dados ---------------------------------------------------\n",
768
- "test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
769
- "y_true = test[\"labels\"]\n",
770
- "terms = test[\"go_terms\"]\n",
771
- "y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
772
- "go_dag = GODag(GO_FILE)\n",
773
- "\n",
774
- "print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
775
- "\n",
776
- "# --- 3. Fmax -------------------------------------------------------------\n",
777
- "def compute_fmax(y_true, y_prob, thresholds):\n",
778
- " fmax, best_thr = 0, 0\n",
779
- " for t in thresholds:\n",
780
- " y_pred = (y_prob >= t).astype(int)\n",
781
- " tp = (y_true * y_pred).sum(axis=1)\n",
782
- " fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
783
- " fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
784
- " precision = tp / (tp + fp + 1e-8)\n",
785
- " recall = tp / (tp + fn + 1e-8)\n",
786
- " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
787
- " avg_f1 = np.mean(f1)\n",
788
- " if avg_f1 > fmax:\n",
789
- " fmax, best_thr = avg_f1, t\n",
790
- " return fmax, best_thr\n",
791
- "\n",
792
- "# --- 4. AuPRC micro ------------------------------------------------------\n",
793
- "def compute_auprc(y_true, y_prob):\n",
794
- " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
795
- " return auc(recall, precision)\n",
796
- "\n",
797
- "# --- 5. Smin -------------------------------------------------------------\n",
798
- "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
799
- " y_pred = (y_prob >= threshold).astype(int)\n",
800
- " ic = {}\n",
801
- " total = (y_true + y_pred).sum(axis=0).sum()\n",
802
- " for i, term in enumerate(terms):\n",
803
- " freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
804
- " ic[term] = -np.log((freq + 1e-8) / total)\n",
805
- "\n",
806
- " s_values = []\n",
807
- " for true_vec, pred_vec in zip(y_true, y_pred):\n",
808
- " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
809
- " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
810
- "\n",
811
- " anc_true = set()\n",
812
- " for t in true_terms:\n",
813
- " if t in go_dag:\n",
814
- " anc_true |= go_dag[t].get_all_parents()\n",
815
- " anc_pred = set()\n",
816
- " for t in pred_terms:\n",
817
- " if t in go_dag:\n",
818
- " anc_pred |= go_dag[t].get_all_parents()\n",
819
- "\n",
820
- " ru = pred_terms - true_terms\n",
821
- " mi = true_terms - pred_terms\n",
822
- " dist_ru = sum(ic.get(t, 0) for t in ru)\n",
823
- " dist_mi = sum(ic.get(t, 0) for t in mi)\n",
824
- " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
825
- " s_values.append(s)\n",
826
- "\n",
827
- " return np.mean(s_values)\n",
828
- "\n",
829
- "# --- 6. Avaliar ----------------------------------------------------------\n",
830
- "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
831
- "auprc = compute_auprc(y_true, y_prob)\n",
832
- "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
833
- "\n",
834
- "print(f\"\\n📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\")\n",
835
- "print(f\"Fmax = {fmax:.4f}\")\n",
836
- "print(f\"Thr. = {thr:.2f}\")\n",
837
- "print(f\"AuPRC = {auprc:.4f}\")\n",
838
- "print(f\"Smin = {smin:.4f}\")\n"
839
- ]
840
- },
841
- {
842
- "cell_type": "code",
843
- "execution_count": null,
844
- "id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
845
- "metadata": {},
846
- "outputs": [],
847
- "source": []
848
- }
849
- ],
850
- "metadata": {
851
- "kernelspec": {
852
- "display_name": "Python 3 (ipykernel)",
853
- "language": "python",
854
- "name": "python3"
855
- },
856
- "language_info": {
857
- "codemirror_mode": {
858
- "name": "ipython",
859
- "version": 3
860
- },
861
- "file_extension": ".py",
862
- "mimetype": "text/x-python",
863
- "name": "python",
864
- "nbconvert_exporter": "python",
865
- "pygments_lexer": "ipython3",
866
- "version": "3.8.18"
867
- }
868
- },
869
- "nbformat": 4,
870
- "nbformat_minor": 5
871
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
notebooks/keras_models_fix.ipynb DELETED
@@ -1,94 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 3,
6
- "id": "39935741-50be-4766-873c-99f3c3f14e55",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "A converter modelos .keras para .h5...\n",
14
- "\n",
15
- "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
16
- "Guardado com sucesso\n",
17
- "\n",
18
- "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
19
- "Guardado com sucesso\n",
20
- "\n",
21
- "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
22
- "Guardado com sucesso\n",
23
- "\n",
24
- "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
25
- "Guardado com sucesso\n",
26
- "\n",
27
- "Conversão concluída.\n"
28
- ]
29
- }
30
- ],
31
- "source": [
32
- "import os\n",
33
- "from tensorflow.keras.models import load_model\n",
34
- "\n",
35
- "MODELS_DIR = \"models\"\n",
36
- "\n",
37
- "# Modelos a converter: (nome original .keras → novo nome .h5)\n",
38
- "modelos = [\n",
39
- " (\"protbert_mlp.keras\", \"mlp_protbert.h5\"),\n",
40
- " (\"protbertbfd_mlp.keras\", \"mlp_protbertbfd.h5\"),\n",
41
- " (\"esm2_mlp.keras\", \"mlp_esm2.h5\"),\n",
42
- " (\"modelo_ensemble_stacking.keras\", \"modelo_ensemble_stack.h5\"),\n",
43
- "]\n",
44
- "\n",
45
- "print(\"A converter modelos .keras para .h5...\\n\")\n",
46
- "\n",
47
- "for origem, destino in modelos:\n",
48
- " origem_path = os.path.join(MODELS_DIR, origem)\n",
49
- " destino_path = os.path.join(MODELS_DIR, destino)\n",
50
- "\n",
51
- " if not os.path.exists(origem_path):\n",
52
- " print(f\"Ficheiro não encontrado: {origem_path}\")\n",
53
- " continue\n",
54
- "\n",
55
- "\n",
56
- " model = load_model(origem_path, compile=False)\n",
57
- " model.save(destino_path)\n",
58
- "\n",
59
- " print(\"Guardado com sucesso\\n\")\n",
60
- "\n",
61
- "print(\"Conversão concluída.\")\n"
62
- ]
63
- },
64
- {
65
- "cell_type": "code",
66
- "execution_count": null,
67
- "id": "c0f2a4b7-d3eb-48a7-b97e-213b58b2b2ca",
68
- "metadata": {},
69
- "outputs": [],
70
- "source": []
71
- }
72
- ],
73
- "metadata": {
74
- "kernelspec": {
75
- "display_name": "Python 3 (ipykernel)",
76
- "language": "python",
77
- "name": "python3"
78
- },
79
- "language_info": {
80
- "codemirror_mode": {
81
- "name": "ipython",
82
- "version": 3
83
- },
84
- "file_extension": ".py",
85
- "mimetype": "text/x-python",
86
- "name": "python",
87
- "nbconvert_exporter": "python",
88
- "pygments_lexer": "ipython3",
89
- "version": "3.8.18"
90
- }
91
- },
92
- "nbformat": 4,
93
- "nbformat_minor": 5
94
- }