Delete notebooks
- notebooks/Ensemble.ipynb +0 -292
- notebooks/Input.ipynb +0 -157
- notebooks/PAM1_ESM2.ipynb +0 -561
- notebooks/PAM1_protbert.ipynb +0 -935
- notebooks/PAM1_protbertBFD.ipynb +0 -871
- notebooks/keras_models_fix.ipynb +0 -94
notebooks/Ensemble.ipynb
DELETED
@@ -1,292 +0,0 @@

[cell 1: code]
# %%
import numpy as np, joblib, math
from sklearn.metrics import precision_recall_curve, auc
from goatools.obo_parser import GODag
from sklearn.preprocessing import MultiLabelBinarizer

GO_FILE = "go.obo"
dag = GODag(GO_FILE)

# ---------- 1. y_true + GO terms (ProtBERT reference) ----------
test_pb = joblib.load("embeddings/test_protbert.pkl")
y_true = test_pb["labels"]            # (1724, 597) ← ground truth
go_ref = list(test_pb["go_terms"])    # exact column order

n_go = len(go_ref)                    # 597

# --- Recreate the MultiLabelBinarizer with the correct 597 terms ---
mlb = MultiLabelBinarizer(classes=go_ref)
mlb.fit([go_ref])  # needed so inverse_transform works later

# ---------- 2. Load the predictions ----------
y_pb   = np.load("predictions/mf-protbert-pam1.npy")     # 1724×597
y_bfd  = np.load("predictions/mf-protbertbfd-pam1.npy")  # 1724×597
y_esm0 = np.load("predictions/mf-esm2.npy")              # 1724×602

# ---------- 3. Remap ESM-2 to the ProtBERT column order ----------
mlb_esm = joblib.load("data/mlb.pkl")  # 602 GO terms
idx_map = [list(mlb_esm.classes_).index(t) for t in go_ref]
y_esm = y_esm0[:, idx_map]             # 1724×597

# ---------- 4. Make sure all shapes match ----------
assert (y_true.shape == y_pb.shape == y_bfd.shape
        == y_esm.shape == (1724, n_go)), "Still misaligned!"

# ---------- 5. Save the aligned mlb (y_true) ----------
joblib.dump(mlb, "data/mlb_597.pkl")
print("mlb with 597 GO terms saved to data/mlb_597.pkl")

# ---------- 6. Metrics ----------
THR = np.linspace(0, 1, 101)

def fmax(y_t, y_p):
    best, thr = 0, 0
    for t in THR:
        y_b = (y_p >= t).astype(int)
        tp = (y_t*y_b).sum(1); fp = ((1-y_t)*y_b).sum(1); fn = (y_t*(1-y_b)).sum(1)
        f1 = 2*tp/(2*tp+fp+fn+1e-8); m = f1.mean()
        if m > best: best, thr = m, t
    return best, thr

def auprc(y_t, y_p):
    p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel()); return auc(r, p)

def smin(y_t, y_p, thr, alpha=0.5):
    y_b = (y_p >= thr).astype(int)
    ic = -(np.log((y_t+y_b).sum(0)+1e-8) - np.log((y_t+y_b).sum()+1e-8))
    ru = np.logical_and(y_b, np.logical_not(y_t))*ic
    mi = np.logical_and(y_t, np.logical_not(y_b))*ic
    return np.sqrt((alpha*ru.sum(1))**2 + ((1-alpha)*mi.sum(1))**2).mean()

def show(name, y_p):
    f, thr = fmax(y_true, y_p)
    print(f"{name:>13s} Fmax={f:.4f} Thr={thr:.2f} "
          f"AuPRC={auprc(y_true,y_p):.4f} Smin={smin(y_true,y_p,thr):.4f}")

show("ProtBERT", y_pb)
show("ProtBERT-BFD", y_bfd)
show("ESM-2", y_esm)
show("Ensemble", (y_pb + y_bfd + y_esm)/3)

[cell 1: stdout]
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
mlb with 597 GO terms saved to data/mlb_597.pkl
     ProtBERT Fmax=0.6616 Thr=0.40 AuPRC=0.7009 Smin=13.9047
 ProtBERT-BFD Fmax=0.6573 Thr=0.41 AuPRC=0.6925 Smin=13.7060
        ESM-2 Fmax=0.6375 Thr=0.39 AuPRC=0.6875 Smin=14.1194
     Ensemble Fmax=0.6864 Thr=0.37 AuPRC=0.7332 Smin=12.7879
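For reference, the fmax helper in the deleted cell implements the protein-centric maximum F-measure familiar from CAFA-style evaluation. With tp_i(τ), fp_i(τ), fn_i(τ) counted per protein i at threshold τ, it computes

    F_{\max} = \max_{\tau \in [0,1]} \frac{1}{n} \sum_{i=1}^{n} \frac{2\,tp_i(\tau)}{2\,tp_i(\tau) + fp_i(\tau) + fn_i(\tau)}

and the reported "Thr" is the maximizing τ. Note that the smin helper then evaluates its semantic distance at that same threshold rather than minimizing over thresholds.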

[cell 2: code]
# %%
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, auc
import numpy as np
import math

# --- Prepare the data for stacking ---
# (y_pb, y_bfd, y_esm already have shape (1724, 597))
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1724, 597*3)
y_stack = y_true.copy()                                 # (1724, 597)

# --- Train/validation split ---
X_train, X_val, y_train, y_val = train_test_split(X_stack, y_stack, test_size=0.3, random_state=42)

# --- MLP model (automatically uses the GPU if available) ---
model = Sequential([
    Dense(512, activation="relu", input_shape=(X_train.shape[1],)),
    Dropout(0.3),
    Dense(256, activation="relu"),
    Dropout(0.3),
    Dense(y_stack.shape[1], activation="sigmoid")
])

model.compile(optimizer=Adam(1e-3), loss="binary_crossentropy")

model.fit(X_train, y_train, validation_data=(X_val, y_val),
          epochs=50, batch_size=64, verbose=1,
          callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])

# --- Predict with the stacked model ---
y_pred_stack = model.predict(X_stack, batch_size=64)

# --- Metrics ---
THR = np.linspace(0, 1, 101)

def fmax(y_t, y_p):
    best, thr = 0, 0
    for t in THR:
        y_b = (y_p >= t).astype(int)
        tp = (y_t * y_b).sum(1); fp = ((1 - y_t) * y_b).sum(1); fn = (y_t * (1 - y_b)).sum(1)
        f1 = 2 * tp / (2 * tp + fp + fn + 1e-8); m = f1.mean()
        if m > best: best, thr = m, t
    return best, thr

def auprc(y_t, y_p):
    p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel())
    return auc(r, p)

def smin(y_t, y_p, thr, alpha=0.5):
    y_b = (y_p >= thr).astype(int)
    ic = -(np.log((y_t + y_b).sum(0) + 1e-8) - np.log((y_t + y_b).sum() + 1e-8))
    ru = np.logical_and(y_b, np.logical_not(y_t)) * ic
    mi = np.logical_and(y_t, np.logical_not(y_b)) * ic
    return np.sqrt((alpha * ru.sum(1))**2 + ((1 - alpha) * mi.sum(1))**2).mean()

f, thr = fmax(y_stack, y_pred_stack)
print(f"\n STACKING (GPU-Keras MLP)")
print(f"Fmax = {f:.4f}")
print(f"Thr. = {thr:.2f}")
print(f"AuPRC = {auprc(y_stack, y_pred_stack):.4f}")
print(f"Smin = {smin(y_stack, y_pred_stack, thr):.4f}")

[cell 2: stdout]
Epoch 1/50
19/19 [==============================] - 0s 17ms/step - loss: 0.3811 - val_loss: 0.0868
Epoch 2/50
19/19 [==============================] - 0s 8ms/step - loss: 0.0882 - val_loss: 0.0696
Epoch 3/50
19/19 [==============================] - 0s 6ms/step - loss: 0.0628 - val_loss: 0.0563
Epoch 4/50
19/19 [==============================] - 0s 6ms/step - loss: 0.0552 - val_loss: 0.0520
Epoch 5/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0507 - val_loss: 0.0486
Epoch 6/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0473 - val_loss: 0.0455
Epoch 7/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0437 - val_loss: 0.0431
Epoch 8/50
19/19 [==============================] - 0s 8ms/step - loss: 0.0414 - val_loss: 0.0414
Epoch 9/50
19/19 [==============================] - 0s 7ms/step - loss: 0.0391 - val_loss: 0.0395
Epoch 10/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0371 - val_loss: 0.0383
Epoch 11/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0355 - val_loss: 0.0372
Epoch 12/50
19/19 [==============================] - 0s 4ms/step - loss: 0.0341 - val_loss: 0.0362
Epoch 13/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0329 - val_loss: 0.0352
Epoch 14/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0318 - val_loss: 0.0345
Epoch 15/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0307 - val_loss: 0.0341
Epoch 16/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0300 - val_loss: 0.0337
Epoch 17/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0291 - val_loss: 0.0334
Epoch 18/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0281 - val_loss: 0.0331
Epoch 19/50
19/19 [==============================] - 0s 4ms/step - loss: 0.0278 - val_loss: 0.0329
Epoch 20/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0270 - val_loss: 0.0328
Epoch 21/50
19/19 [==============================] - 0s 4ms/step - loss: 0.0266 - val_loss: 0.0329
Epoch 22/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0261 - val_loss: 0.0326
Epoch 23/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0257 - val_loss: 0.0325
Epoch 24/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0249 - val_loss: 0.0324
Epoch 25/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0247 - val_loss: 0.0325
Epoch 26/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0243 - val_loss: 0.0322
Epoch 27/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0239 - val_loss: 0.0325
Epoch 28/50
19/19 [==============================] - 0s 3ms/step - loss: 0.0235 - val_loss: 0.0325
Epoch 29/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0226 - val_loss: 0.0328
Epoch 30/50
19/19 [==============================] - 0s 5ms/step - loss: 0.0227 - val_loss: 0.0326
Epoch 31/50
19/19 [==============================] - 0s 4ms/step - loss: 0.0224 - val_loss: 0.0326

 STACKING (GPU-Keras MLP)
Fmax = 0.6937
Thr. = 0.34
AuPRC = 0.7551
Smin = 12.2407
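One caveat when reading those numbers: the cell scores y_pred_stack = model.predict(X_stack) on the full 1724-protein set, i.e. including the 70% the stacking MLP was just trained on, so the stacking metrics are optimistic relative to the base models' held-out scores. A minimal sketch (reusing model, X_val, y_val and the metric helpers defined in the deleted cell) that scores the held-out 30% only:

y_pred_val = model.predict(X_val, batch_size=64)
f_val, thr_val = fmax(y_val, y_pred_val)
print(f"held-out Fmax = {f_val:.4f} (thr = {thr_val:.2f}), "
      f"AuPRC = {auprc(y_val, y_pred_val):.4f}")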

[cell 3: code]
model.save("models/modelo_ensemble_stacking.keras")
print('saved to models/modelo_ensemble_stacking.keras')

[cell 3: stdout]
saved to models/modelo_ensemble_stacking.keras

[cell 4: empty]

[notebook metadata: kernel "Python 3 (ipykernel)", Python 3.8.18, nbformat 4/5]

notebooks/Input.ipynb
DELETED
@@ -1,157 +0,0 @@

[cell 1: code]
# %%
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from tensorflow.keras.models import load_model
import joblib

# --- Parameters ---
SEQ_FASTA = "MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKEREAAAAAAEAAEATEQIVFEEEDGKALLNLFFTLRSSKTPALSRSLKVFETFEAKIHHLETRPCRKPRDSLEGLEYFVRCEVHLSDVSTLISSIKRIAEDVKTTKEVKFHWFPKKISELDRCHHLITKFDPDLDQEHPGFTDPVYRQRRKMIGDIAFRYKQGEPIPRVEYTEEEIGTWREVYSTLRDLYTTHACSEHLEAFNLLERHCGYSPENIPQLEDVSRFLRERTGFQLRPVAGLLSARDFLASLAFRVFQCTQYIRHASSPMHSPEPDCVHELLGHVPILADRVFAQFSQNIGLASLGASEEDIEKLSTLYWFTVEFGLCKQGGIVKAYGAGLLSSYGELVHALSDEPERREFDPEAAAIQPYQDQNYQSVYFVSESFTDAKEKLRSYVAGIKRPFSVRFDPYTYSIEVLDNPLKIRGGLESVKDELKMLTDALNVLA"
TOP_N = 10

# --- 1. Helper to slice the sequence (512 for ProtBERT and ProtBERT-BFD, 1024 for ESM-2) ---
def slice_sequence(seq, chunk_size):
    return [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]

# --- 2. Helper to compute mean chunk embeddings ---
def get_embedding_mean(model_name, seq, chunk_size):
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    model = AutoModel.from_pretrained(model_name)
    model.eval()

    chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
    embeddings = []

    for chunk in chunks:
        seq_chunk = " ".join(list(chunk))
        # tokenize WITHOUT truncation
        inputs = tokenizer(seq_chunk,
                           return_tensors="pt",
                           truncation=False,  # ≤ 512 or 1024 already guaranteed
                           padding=False)
        with torch.no_grad():
            cls = model(**inputs).last_hidden_state[:, 0, :].squeeze().numpy()
        embeddings.append(cls)

    return np.mean(embeddings, axis=0, keepdims=True)  # (1, dim)

print("Generating embeddings in chunks...")
emb_pb = get_embedding_mean("Rostlab/prot_bert", SEQ_FASTA, 512)
emb_bfd = get_embedding_mean("Rostlab/prot_bert_bfd", SEQ_FASTA, 512)
emb_esm = get_embedding_mean("facebook/esm2_t33_650M_UR50D", SEQ_FASTA, 1024)

# --- 3. Load the base MLPs ---
mlp_pb = load_model("models/protbert_mlp.keras")
mlp_bfd = load_model("models/protbertbfd_mlp.keras")
mlp_esm = load_model("models/esm2_mlp.keras")

# --- 4. Base predictions (keep 597 columns) ---
print("Making base predictions...")
y_pb = mlp_pb.predict(emb_pb)[:, :597]
y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
y_esm = mlp_esm.predict(emb_esm)[:, :597]

# --- 5. Concatenate for stacking ---
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)

# --- 6. Load the stacking model ---
stacking = load_model("models/modelo_ensemble_stacking.keras")
y_pred = stacking.predict(X_stack)

# --- 7. Load the binarizer (597 GO terms) ---
mlb = joblib.load("data/mlb_597.pkl")
go_terms = mlb.classes_

# --- 8. Show the results ---
print("\n GO terms with prob ≥ 0.5:")
predicted_terms = mlb.inverse_transform((y_pred >= 0.5).astype(int))
print(predicted_terms[0] if predicted_terms[0] else "No GO term above 0.5")

print(f"\n Top {TOP_N} most likely GO terms:")
top_idx = np.argsort(-y_pred[0])[:TOP_N]
for i in top_idx:
    print(f"{go_terms[i]} : {y_pred[0][i]:.4f}")

[cell 1: stdout]
Generating embeddings in chunks...

[cell 1: stderr]
C:\Users\Melvin\anaconda3\envs\projeto_proteina2\lib\site-packages\huggingface_hub\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
C:\Users\Melvin\anaconda3\envs\projeto_proteina2\lib\site-packages\huggingface_hub\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

[cell 1: stdout]
Making base predictions...

 GO terms with prob ≥ 0.5:
('GO:0003674', 'GO:0003824', 'GO:0005488', 'GO:0016491', 'GO:0036094', 'GO:0043167')

 Top 10 most likely GO terms:
GO:0003674 : 0.9975
GO:0003824 : 0.9156
GO:0036094 : 0.6652
GO:0043167 : 0.6336
GO:0016491 : 0.6327
GO:0005488 : 0.5595
GO:0043169 : 0.4801
GO:0140096 : 0.4790
GO:0051213 : 0.4551
GO:0046872 : 0.4098
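One detail to be aware of in step 4 above: slicing [:, :597] keeps the first 597 output columns of each base MLP and implicitly assumes they coincide with the 597-term order of data/mlb_597.pkl. For ESM-2, which was trained against the 602-term binarizer data/mlb.pkl, Ensemble.ipynb instead remapped columns explicitly; a hedged sketch of doing the same here (mlp_esm and emb_esm are the names from the cell above, and it assumes, as Ensemble.ipynb did, that every 597-term class appears among the 602):

import joblib
mlb_602 = joblib.load("data/mlb.pkl")      # 602-term order used by the base MLP
mlb_597 = joblib.load("data/mlb_597.pkl")  # 597-term order used by the stacker
idx_map = [list(mlb_602.classes_).index(t) for t in mlb_597.classes_]
y_esm = mlp_esm.predict(emb_esm)[:, idx_map]  # explicit alignment instead of [:, :597]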

[cell 2: empty]

[notebook metadata: kernel "Python 3 (ipykernel)", Python 3.8.18, nbformat 4/5]

notebooks/PAM1_ESM2.ipynb
DELETED
@@ -1,561 +0,0 @@

[cell 1: code]
# %%
import pandas as pd
import numpy as np
from Bio import SeqIO
from goatools.obo_parser import GODag
from collections import Counter
from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import os, random

# --- 1. Main input files ---
FASTA = "uniprot_sprot_exp.fasta"
ANNOT = "uniprot_sprot_exp.txt"
GO_OBO = "go.obo"

# --- 2. Read the sequences ---
seqs, ids = [], []
for record in SeqIO.parse(FASTA, "fasta"):
    ids.append(record.id)
    seqs.append(str(record.seq))

df_seq = pd.DataFrame({"protein_id": ids, "sequence": seqs})

# --- 3. Read the GO:MF annotations ---
df_ann = pd.read_csv(ANNOT, sep="\t", names=["protein_id", "go_term", "category"])
df_ann = df_ann[df_ann["category"] == "F"]

# --- 4. Hierarchical propagation of the GO terms ---
go_dag = GODag(GO_OBO)
mf_terms = {t for t, o in go_dag.items() if o.namespace == "molecular_function"}

def propagate_terms(terms):
    expanded = set()
    for t in terms:
        if t in go_dag:
            expanded |= go_dag[t].get_all_parents()
            expanded.add(t)
    return list(expanded & mf_terms)

grouped = df_ann.groupby("protein_id")["go_term"].apply(list).reset_index()
grouped["go_term"] = grouped["go_term"].apply(propagate_terms)

# --- 5. Join with the sequences ---
df = df_seq.merge(grouped, on="protein_id")
df = df[df["go_term"].str.len() > 0]

# --- 6. Keep GO terms with ≥50 proteins ---
all_terms = [term for sublist in df["go_term"] for term in sublist]
term_counts = Counter(all_terms)
valid_terms = {t for t, count in term_counts.items() if count >= 50}

df["go_term"] = df["go_term"].apply(lambda ts: [t for t in ts if t in valid_terms])
df = df[df["go_term"].str.len() > 0]

# --- 7. Prepare the labels and split by protein ---
df["go_terms"] = df["go_term"].apply(lambda x: ';'.join(sorted(set(x))))
df = df[["protein_id", "sequence", "go_terms"]].drop_duplicates()

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df["go_terms"].str.split(";"))
X = df[["protein_id", "sequence"]].values

mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)
train_idx, temp_idx = next(mskf.split(X, Y))
val_idx, test_idx = np.array_split(temp_idx, 2)

df_train = df.iloc[train_idx].copy()
df_val = df.iloc[val_idx].copy()
df_test = df.iloc[test_idx].copy()

os.makedirs("data", exist_ok=True)
df_train.to_csv("data/mf-training.csv", index=False)
df_val.to_csv("data/mf-validation.csv", index=False)
df_test.to_csv("data/mf-test.csv", index=False)

# --- 8. Save the binarizer ---
import joblib
joblib.dump(mlb, "data/mlb.pkl")

print("✓ Dataset prepared:")
print("  - Training:", df_train.shape)
print("  - Validation:", df_val.shape)
print("  - Test:", df_test.shape)
print("  - GO terms:", len(mlb.classes_))

[cell 1: stdout]
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
✓ Dataset prepared:
  - Training: (31142, 3)
  - Validation: (1724, 3)
  - Test: (1724, 3)
  - GO terms: 602
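Note how the split works: next(mskf.split(X, Y)) takes the first of ten stratified folds, so ~90% of proteins go to training and the remaining ~10% are halved into validation and test, matching the 31142/1724/1724 shapes printed above. A minimal sanity sketch, reusing the frames from the cell, that the three splits are protein-disjoint:

train_ids = set(df_train["protein_id"])
val_ids = set(df_val["protein_id"])
test_ids = set(df_test["protein_id"])
assert not (train_ids & val_ids or train_ids & test_ids or val_ids & test_ids)
print("protein-disjoint splits:", len(train_ids), len(val_ids), len(test_ids))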

[cell 2: code]
# %%
import random
from collections import defaultdict

# --- Normalized PAM1 matrix ---
pam_data = {
    'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],
    'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],
    'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],
    'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],
    'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],
    'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],
    'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],
    'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],
    'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],
    'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],
    'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],
    'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],
    'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],
    'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],
    'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],
    'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],
    'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],
    'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],
    'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],
    'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]
}

pam_raw = pd.DataFrame(pam_data, index=pam_data.keys())
pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)
pam_dict = {aa: pam_matrix.loc[aa].to_dict() for aa in pam_matrix.index}

def pam1_substitution(aa):
    if aa not in pam_dict:
        return aa
    subs = list(pam_dict[aa].keys())
    probs = list(pam_dict[aa].values())
    return np.random.choice(subs, p=probs)

def augment_sequence(seq, sub_prob=0.05):
    return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])

def slice_sequence(seq, win=1024):
    if len(seq) <= win:
        return [seq]
    return [seq[i:i+win] for i in range(0, len(seq), win)]

def format_seq(seq):
    return " ".join(seq)

# --- Load the labels and datasets ---
import joblib
mlb = joblib.load("data/mlb.pkl")
df_train = pd.read_csv("data/mf-training.csv")
df_val = pd.read_csv("data/mf-validation.csv")
df_test = pd.read_csv("data/mf-test.csv")

# --- Slicing + augmentation for training ---
X_train, y_train = [], []

for _, row in df_train.iterrows():
    seq_aug = augment_sequence(row["sequence"], sub_prob=0.05)
    slices = slice_sequence(seq_aug, win=1024)
    label = mlb.transform([row["go_terms"].split(";")])[0]
    for sl in slices:
        X_train.append(format_seq(sl))
        y_train.append(label)

# --- No slicing for val/test ---
X_val = [format_seq(seq) for seq in df_val["sequence"]]
X_test = [format_seq(seq) for seq in df_test["sequence"]]

y_val = mlb.transform(df_val["go_terms"].str.split(";"))
y_test = mlb.transform(df_test["go_terms"].str.split(";"))

np.save("embeddings/y_test.npy", y_test)
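Two properties of the augmentation above are worth keeping in mind: each row of pam_matrix is a probability distribution, and because the PAM1 diagonal entries are ≈0.99, a position selected with sub_prob=0.05 usually resamples to itself, so the effective mutation rate is roughly 0.05 × 0.01 ≈ 0.05% of residues. A minimal sketch, reusing the names defined in the cell above (the test sequence is hypothetical), to confirm both:

assert np.allclose(pam_matrix.sum(axis=1), 1.0)  # rows are distributions

random.seed(0); np.random.seed(0)
seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ" * 100    # hypothetical 3300-residue sequence
aug = augment_sequence(seq, sub_prob=0.05)
changed = sum(a != b for a, b in zip(seq, aug))
print(f"{changed}/{len(seq)} residues changed")   # expect only ~1-2 changes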

[cell 3: code]
# %%
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
import numpy as np
import os

# --- Settings ---
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHUNK_SIZE = 16

# --- Load the model ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
model = AutoModel.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.eval()

def extract_embeddings(texts):
    embeddings = []
    for i in tqdm(range(0, len(texts), CHUNK_SIZE)):
        batch = texts[i:i+CHUNK_SIZE]
        with torch.no_grad():
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=1024)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = model(**inputs).last_hidden_state
            cls_tokens = outputs[:, 0, :]  # CLS token
            embeddings.append(cls_tokens.cpu().numpy())
    return np.vstack(embeddings)

# --- Extract and save the embeddings ---
os.makedirs("embeddings", exist_ok=True)

emb_train = extract_embeddings(X_train)
emb_val = extract_embeddings(X_val)
emb_test = extract_embeddings(X_test)

np.save("embeddings/esm2_train.npy", emb_train)
np.save("embeddings/esm2_val.npy", emb_val)
np.save("embeddings/esm2_test.npy", emb_test)

np.save("embeddings/y_train.npy", np.array(y_train))
np.save("embeddings/y_val.npy", np.array(y_val))

[cell 3: stderr]
C:\Users\Melvin\anaconda3\envs\projeto_proteina2\lib\site-packages\huggingface_hub\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 2189/2189 [1:17:26<00:00, 2.12s/it]
100%|██████████| 108/108 [03:43<00:00, 2.07s/it]
100%|██████████| 108/108 [03:56<00:00, 2.19s/it]
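The "pooler ... newly initialized" warning in the output is expected here: the cell never calls the (untrained) pooler and instead reads last_hidden_state[:, 0, :] directly. For reference only, and not what the deleted notebook did, masked mean pooling over residue states is a common alternative to the first-token embedding; a hedged sketch:

import torch

def mean_pool(last_hidden, attention_mask):
    # average the residue states, ignoring padding positions
    mask = attention_mask.unsqueeze(-1).float()              # (B, L, 1)
    return (last_hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)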

[cell 4: code]
# %%
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import average_precision_score

# --- Load the embeddings and labels ---
X_train = np.load("embeddings/esm2_train.npy")
X_val = np.load("embeddings/esm2_val.npy")
X_test = np.load("embeddings/esm2_test.npy")

y_train = np.load("embeddings/y_train.npy")
y_val = np.load("embeddings/y_val.npy")
y_test = np.load("embeddings/y_test.npy")

# --- Define the model ---
model = Sequential([
    Dense(1024, activation='relu', input_shape=(X_train.shape[1],)),
    Dropout(0.3),
    Dense(512, activation='relu'),
    Dropout(0.3),
    Dense(y_train.shape[1], activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy')

# --- Train ---
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=32,
    callbacks=[early_stop],
    verbose=1
)

# --- Save the model ---
model.save("models/esm2_mlp.keras")
print("Model saved to models/esm2_mlp.keras")

# --- Predict on the test set ---
y_prob = model.predict(X_test)
np.save("predictions/mf-esm2.npy", y_prob)

print(" ESM-2 predictions saved with shape:", y_prob.shape)

[cell 4: stdout]
Epoch 1/100
1095/1095 [==============================] - 4s 2ms/step - loss: 0.0552 - val_loss: 0.0455
Epoch 2/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0445 - val_loss: 0.0424
Epoch 3/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0419 - val_loss: 0.0394
Epoch 4/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0403 - val_loss: 0.0381
Epoch 5/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0392 - val_loss: 0.0373
Epoch 6/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0383 - val_loss: 0.0362
Epoch 7/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0374 - val_loss: 0.0358
Epoch 8/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0368 - val_loss: 0.0351
Epoch 9/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0362 - val_loss: 0.0348
Epoch 10/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0357 - val_loss: 0.0344
Epoch 11/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0353 - val_loss: 0.0340
Epoch 12/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0349 - val_loss: 0.0335
Epoch 13/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0344 - val_loss: 0.0334
Epoch 14/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0342 - val_loss: 0.0330
Epoch 15/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0338 - val_loss: 0.0325
Epoch 16/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0337 - val_loss: 0.0327
Epoch 17/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0333 - val_loss: 0.0325
Epoch 18/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0330 - val_loss: 0.0322
Epoch 19/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0328 - val_loss: 0.0321
Epoch 20/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0326 - val_loss: 0.0322
Epoch 21/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0323 - val_loss: 0.0320
Epoch 22/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0322 - val_loss: 0.0317
Epoch 23/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0320 - val_loss: 0.0318
Epoch 24/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0317 - val_loss: 0.0315
Epoch 25/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0316 - val_loss: 0.0317
Epoch 26/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0314 - val_loss: 0.0313
Epoch 27/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0313 - val_loss: 0.0320
Epoch 28/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0311 - val_loss: 0.0315
Epoch 29/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0310 - val_loss: 0.0313
Epoch 30/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0309 - val_loss: 0.0313
Epoch 31/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0307 - val_loss: 0.0310
Epoch 32/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0306 - val_loss: 0.0310
Epoch 33/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0304 - val_loss: 0.0310
Epoch 34/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0303 - val_loss: 0.0312
Epoch 35/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0302 - val_loss: 0.0309
Epoch 36/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0300 - val_loss: 0.0310
Epoch 37/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0299 - val_loss: 0.0313
Epoch 38/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0298 - val_loss: 0.0312
Epoch 39/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0307
Epoch 40/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0296 - val_loss: 0.0306
Epoch 41/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0295 - val_loss: 0.0310
Epoch 42/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0304
Epoch 43/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0294 - val_loss: 0.0308
Epoch 44/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0293 - val_loss: 0.0306
Epoch 45/100
1095/1095 [==============================] - 3s 3ms/step - loss: 0.0292 - val_loss: 0.0307
Epoch 46/100
1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305
Epoch 47/100
1095/1095 [==============================] - 4s 4ms/step - loss: 0.0290 - val_loss: 0.0305
Model saved to models/esm2_mlp.keras
 ESM-2 predictions saved with shape: (1724, 602)
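The cell above imports average_precision_score but never calls it; a minimal sketch of the micro-averaged AuPRC it could report on the test predictions (y_test and y_prob are the arrays from the cell):

from sklearn.metrics import average_precision_score
print("micro AuPRC:", average_precision_score(y_test, y_prob, average="micro"))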

[cell 5: code]
# %%
import numpy as np
import joblib
import math
from goatools.obo_parser import GODag
from sklearn.metrics import precision_recall_curve, auc

# --- 1. Load data and parameters ---
GO_FILE = "go.obo"
THRESHOLDS = np.arange(0.0, 1.01, 0.01)
ALPHA = 0.5

mlb = joblib.load("data/mlb.pkl")
y_true = np.load("embeddings/y_test.npy")
y_prob = np.load("predictions/mf-esm2.npy")
terms = mlb.classes_
go_dag = GODag(GO_FILE)

print(f"✓ Data loaded: {y_true.shape} proteins × {len(terms)} GO terms")

# --- 2. Fmax ---
def compute_fmax(y_true, y_prob, thresholds):
    fmax, best_thr = 0, 0
    for t in thresholds:
        y_pred = (y_prob >= t).astype(int)
        tp = (y_true * y_pred).sum(axis=1)
        fp = ((1 - y_true) * y_pred).sum(axis=1)
        fn = (y_true * (1 - y_pred)).sum(axis=1)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        avg_f1 = np.mean(f1)
        if avg_f1 > fmax:
            fmax, best_thr = avg_f1, t
    return fmax, best_thr

# --- 3. AuPRC (micro) ---
def compute_auprc(y_true, y_prob):
    precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())
    return auc(recall, precision)

# --- 4. Smin ---
def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):
    y_pred = (y_prob >= threshold).astype(int)

    # Semantic information: IC (information content)
    ic = {}
    total = (y_true + y_pred).sum(axis=0).sum()
    for i, term in enumerate(terms):
        freq = (y_true[:, i] + y_pred[:, i]).sum()
        ic[term] = -np.log((freq + 1e-8) / total)

    # For each protein, compute RU and MI
    s_values = []
    for true_vec, pred_vec in zip(y_true, y_pred):
        true_terms = {terms[i] for i in np.where(true_vec)[0]}
        pred_terms = {terms[i] for i in np.where(pred_vec)[0]}

        anc_true = set()
        for t in true_terms:
            if t in go_dag:
                anc_true |= go_dag[t].get_all_parents()
        anc_pred = set()
        for t in pred_terms:
            if t in go_dag:
                anc_pred |= go_dag[t].get_all_parents()

        ru = pred_terms - true_terms
        mi = true_terms - pred_terms
        dist_ru = sum(ic.get(t, 0) for t in ru)
        dist_mi = sum(ic.get(t, 0) for t in mi)
        s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)
        s_values.append(s)

    return np.mean(s_values)

# --- 5. Evaluation ---
fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)
auprc = compute_auprc(y_true, y_prob)
smin = compute_smin(y_true, y_prob, terms, thr, go_dag)

print(f"\n Final results (ESM-2 + PAM1 + propagation):")
print(f"Fmax = {fmax:.4f}")
print(f"Thr. = {thr:.2f}")
print(f"AuPRC = {auprc:.4f}")
print(f"Smin = {smin:.4f}")

[cell 5: stdout]
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
✓ Data loaded: (1724, 602) proteins × 602 GO terms

 Final results (ESM-2 + PAM1 + propagation):
Fmax = 0.6439
Thr. = 0.34
AuPRC = 0.6948
Smin = 14.1500
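For reference, compute_smin above follows the CAFA-style semantic distance: with T_i and P_i the true and predicted term sets of protein i, IC estimated from the pooled true+predicted annotation frequencies, and α = 0.5,

    ru_i = \sum_{t \in P_i \setminus T_i} IC(t), \qquad
    mi_i = \sum_{t \in T_i \setminus P_i} IC(t), \qquad
    S = \frac{1}{n} \sum_{i=1}^{n} \sqrt{(\alpha\, ru_i)^2 + \big((1-\alpha)\, mi_i\big)^2}

Note that the anc_true/anc_pred ancestor sets are built inside the function but never enter the final distance.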

[cell 6: empty]

[notebook metadata: kernel "Python 3 (ipykernel)", Python 3.8.18, nbformat 4/5]

notebooks/PAM1_protbert.ipynb
DELETED
|
@@ -1,935 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 2,
|
| 6 |
-
"id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [
|
| 9 |
-
{
|
| 10 |
-
"name": "stdout",
|
| 11 |
-
"output_type": "stream",
|
| 12 |
-
"text": [
|
| 13 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 14 |
-
"✓ Ficheiros criados:\n",
|
| 15 |
-
" - data/mf-training.csv : (31142, 3)\n",
|
| 16 |
-
" - data/mf-validation.csv: (1724, 3)\n",
|
| 17 |
-
" - data/mf-test.csv : (1724, 3)\n",
|
| 18 |
-
"GO terms únicos (após propagação e filtro): 602\n"
|
| 19 |
-
]
|
| 20 |
-
}
|
| 21 |
-
],
|
import pandas as pd
from Bio import SeqIO
from collections import Counter
from goatools.obo_parser import GODag
from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np
import os

# --- 1. Load GO annotations ------------------------------------------
annotations = pd.read_csv("uniprot_sprot_exp.txt", sep="\t", names=["protein_id", "go_term", "go_category"])
annotations_f = annotations[annotations["go_category"] == "F"]

# --- 2. Load the DAG and propagate GO terms -------------------------------
# hierarchical propagation
# https://geneontology.org/docs/download-ontology/
go_dag = GODag("go.obo")
mf_terms = {t for t, o in go_dag.items() if o.namespace == "molecular_function"}

def propagate_terms(term_list):
    full = set()
    for t in term_list:
        if t not in go_dag:
            continue
        full.add(t)
        full.update(go_dag[t].get_all_parents())
    return list(full & mf_terms)

# --- 3. Load sequences --------------------------------------------
seqs, ids = [], []
for record in SeqIO.parse("uniprot_sprot_exp.fasta", "fasta"):
    ids.append(record.id)
    seqs.append(str(record.seq))

seq_df = pd.DataFrame({"protein_id": ids, "sequence": seqs})

# --- 4. Join with the GO annotations and propagate -------------------------
grouped = annotations_f.groupby("protein_id")["go_term"].apply(list).reset_index()
data = seq_df.merge(grouped, on="protein_id")
data = data[data["go_term"].apply(len) > 0]
data["go_term"] = data["go_term"].apply(propagate_terms)
data = data[data["go_term"].apply(len) > 0]

# --- 5. Filter out rare GO terms -----------------------------------------
# drop every term associated with fewer than 50 proteins
all_terms = [term for sublist in data["go_term"] for term in sublist]
term_counts = Counter(all_terms)
valid_terms = {term for term, count in term_counts.items() if count >= 50}
data["go_term"] = data["go_term"].apply(lambda terms: [t for t in terms if t in valid_terms])
data = data[data["go_term"].apply(len) > 0]

# --- 6. Prepare the final dataset -----------------------------------------
data["go_terms"] = data["go_term"].apply(lambda x: ';'.join(sorted(set(x))))
data = data[["protein_id", "sequence", "go_terms"]].drop_duplicates()

# --- 7. Binarize the labels and split -------------------------------------
mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(data["go_terms"].str.split(";"))
X = data[["protein_id", "sequence"]].values

mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)
train_idx, temp_idx = next(mskf.split(X, Y))
val_idx, test_idx = np.array_split(temp_idx, 2)

df_train = data.iloc[train_idx].copy()
df_val = data.iloc[val_idx].copy()
df_test = data.iloc[test_idx].copy()

# --- 8. Save to CSV -------------------------------------------------
os.makedirs("data", exist_ok=True)
df_train.to_csv("data/mf-training.csv", index=False)
df_val.to_csv("data/mf-validation.csv", index=False)
df_test.to_csv("data/mf-test.csv", index=False)

# --- 9. Confirm ------------------------------------------------------
print("✓ Files created:")
print(" - data/mf-training.csv :", df_train.shape)
print(" - data/mf-validation.csv:", df_val.shape)
print(" - data/mf-test.csv :", df_test.shape)
print(f"Unique GO terms (after propagation and filtering): {len(mlb.classes_)}")
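The split above keeps a single fold of MultilabelStratifiedKFold as the training set and halves the held-out fold into validation and test. Here is a minimal, self-contained sketch of the same 90/5/5 trick on synthetic labels (none of these numbers come from the notebook), checking that per-label prevalence survives the split:

# Hypothetical sanity check for the 90/5/5 stratified split (synthetic data only).
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

rng = np.random.default_rng(0)
X = np.arange(1000).reshape(-1, 1)            # dummy feature column
Y = (rng.random((1000, 5)) < [0.5, 0.2, 0.1, 0.05, 0.02]).astype(int)

mskf = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=42)
train_idx, temp_idx = next(mskf.split(X, Y))  # one fold: 90% train, 10% held out
val_idx, test_idx = np.array_split(temp_idx, 2)

for name, idx in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    # per-label prevalence should be close across the three partitions
    print(name, Y[idx].mean(axis=0).round(3))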
# ===== In [2]: PAM1 augmentation + ProtBERT fine-tuning =====
# Output (abridged; tqdm/ipywidgets, torch pytree and ktrain class_names
# deprecation warnings omitted, as is the note that the TF weights were
# converted from the PyTorch checkpoint with from_pt=True):
#   preprocessing train... language: de
#   train sequence lengths: mean 423, 95th pct 604, 99th pct 715
#   Is Multi-Label? True
#   preprocessing test... language: de
#   test sequence lengths: mean 408, 95th pct 603, 99th pct 714
#   begin training using triangular learning rate policy with max lr of 1e-05...
#   Epoch  1/10  loss 0.0745  bin_acc 0.9866  val_loss 0.0582  val_bin_acc 0.9859
#   Epoch  2/10  loss 0.0504  bin_acc 0.9873  val_loss 0.0499  val_bin_acc 0.9867
#   Epoch  3/10  loss 0.0450  bin_acc 0.9879  val_loss 0.0449  val_bin_acc 0.9873
#   Epoch  4/10  loss 0.0407  bin_acc 0.9884  val_loss 0.0413  val_bin_acc 0.9878
#   Epoch  5/10  loss 0.0378  bin_acc 0.9888  val_loss 0.0394  val_bin_acc 0.9881
#   Epoch  6/10  loss 0.0359  bin_acc 0.9891  val_loss 0.0383  val_bin_acc 0.9883
#   Epoch  7/10  loss 0.0343  bin_acc 0.9894  val_loss 0.0371  val_bin_acc 0.9885
#   Epoch  8/10  loss 0.0331  bin_acc 0.9896  val_loss 0.0364  val_bin_acc 0.9887
#   Epoch  9/10  loss 0.0320  bin_acc 0.9898  val_loss 0.0360  val_bin_acc 0.9888
#   Epoch 10/10  loss 0.0311  bin_acc 0.9900  val_loss 0.0356  val_bin_acc 0.9890
#   The run then crashed while ModelCheckpoint wrote the epoch-10 weights:
#   OSError: [Errno 28] Can't write data ('No space left on device',
#   file 'mf-fine-tuned-protbert\weights-10-0.04.hdf5'), surfacing on file
#   close as RuntimeError: Can't decrement id ref count (unable to extend
#   file properly).
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import os
import ktrain
from ktrain import text
from sklearn.preprocessing import MultiLabelBinarizer


# PAM1
# PAM matrix model of protein evolution
# DOI:10.1093/oxfordjournals.molbev.a040360
pam_data = {
    'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],
    'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],
    'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],
    'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],
    'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],
    'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],
    'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],
    'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],
    'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],
    'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],
    'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],
    'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],
    'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],
    'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],
    'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],
    'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],
    'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],
    'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],
    'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],
    'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]
}
pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))
pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)  # rows become substitution distributions
list_amino = pam_raw.columns.tolist()
pam_dict = {
    aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}
    for aa in list_amino
}

def pam1_substitution(aa):
    if aa not in pam_dict:
        return aa
    subs = list(pam_dict[aa].keys())
    probs = list(pam_dict[aa].values())
    return np.random.choice(subs, p=probs)

def augment_sequence(seq, sub_prob=0.05):
    return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])

def slice_sequence(seq, win=500, min_overlap=250):
    if len(seq) <= win:
        return [seq]
    slices, start = [], 0
    while start + win <= len(seq):
        slices.append(seq[start:start+win])
        start += win
    leftover = seq[start:]
    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:
        extra = slices[-1][-min_overlap:] + leftover
        slices.append(extra)
    return slices

def generate_data(df, augment=False):
    X, y = [], []
    label_cols = [col for col in df.columns if col.startswith("GO:")]
    for _, row in tqdm(df.iterrows(), total=len(df)):
        seq = row["sequence"]
        if augment:
            seq = augment_sequence(seq)
        seq_slices = slice_sequence(seq)
        X.extend(seq_slices)
        lbl = row[label_cols].values.astype(int)
        y.extend([lbl] * len(seq_slices))
    return X, np.array(y), label_cols

def format_sequence(seq): return " ".join(list(seq))

# Helper to load a split and binarize its labels
def load_and_binarize(csv_path, mlb=None):
    df = pd.read_csv(csv_path)
    df["go_terms"] = df["go_terms"].str.split(";")
    if mlb is None:
        mlb = MultiLabelBinarizer()
        labels = mlb.fit_transform(df["go_terms"])
    else:
        labels = mlb.transform(df["go_terms"])
    labels_df = pd.DataFrame(labels, columns=mlb.classes_)
    df = df.reset_index(drop=True).join(labels_df)
    return df, mlb

# Load the data
df_train, mlb = load_and_binarize("data/mf-training.csv")
df_val, _ = load_and_binarize("data/mf-validation.csv", mlb=mlb)

# Generate the training set with augmentation
X_train, y_train, term_cols = generate_data(df_train, augment=True)
X_val, y_val, _ = generate_data(df_val, augment=False)

# Format the sequences for the tokenizer
X_train_fmt = list(map(format_sequence, X_train))
X_val_fmt = list(map(format_sequence, X_val))

# Fine-tune ProtBERT
# https://huggingface.co/Rostlab/prot_bert
# https://doi.org/10.1093/bioinformatics/btac020
# training data -> UniRef100 (216 million sequences)
MODEL_NAME = "Rostlab/prot_bert"
MAX_LEN = 512
BATCH_SIZE = 1

t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)
trn = t.preprocess_train(X_train_fmt, y_train)
val = t.preprocess_test(X_val_fmt, y_val)

model = t.get_classifier()
learner = ktrain.get_learner(model,
                             train_data=trn,
                             val_data=val,
                             batch_size=BATCH_SIZE)

learner.autofit(lr=1e-5,
                epochs=10,
                early_stopping=1,
                checkpoint_folder="mf-fine-tuned-protbert")
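One property of the augmentation above is easy to miss: because the PAM1 diagonal sits near 0.99, pam1_substitution usually returns the residue unchanged, so sub_prob=0.05 yields an effective mutation rate closer to 0.05% than 5%. A toy sketch with a made-up three-letter alphabet (hypothetical counts, not the real PAM1 numbers) that shows the same normalize-then-sample logic and measures the realized rate:

# Toy PAM-style augmentation; the counts are invented for illustration only.
import random
import numpy as np
import pandas as pd

counts = pd.DataFrame(
    {'A': [9900, 60, 40], 'B': [50, 9910, 40], 'C': [50, 30, 9920]},
    index=['A', 'B', 'C'])
probs = counts.div(counts.sum(axis=1), axis=0)   # row-normalize the counts
assert np.allclose(probs.sum(axis=1), 1.0)       # each row is a probability distribution

def substitute(aa):
    return np.random.choice(probs.columns, p=probs.loc[aa].values)

def augment(seq, sub_prob=0.05):
    return ''.join(substitute(a) if random.random() < sub_prob else a for a in seq)

random.seed(0); np.random.seed(0)
seq = 'ABC' * 7000
mutated = sum(a != b for a, b in zip(seq, augment(seq)))
# expected ~ sub_prob * (1 - diagonal) = 0.05 * 0.01 = 0.0005
print(f"realized mutation rate ≈ {mutated / len(seq):.4f}")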
# ===== In [7]: check the saved checkpoint folder =====
# Output:
#   ✅ Exists: weights/mf-fine-tuned-protbert-epoch10
#   📁 Contents:
#    - config.json
#    - tf_model.h5
import os

path = "weights/mf-fine-tuned-protbert-epoch10"

if os.path.exists(path):
    print(f"✅ Exists: {path}")
    print("📁 Contents:")
    for f in os.listdir(path):
        print(" -", f)
else:
    print(f"❌ Missing: {path}")
# ===== In [19]: extract per-protein ProtBERT embeddings =====
# Output (abridged; tqdm/ipywidgets, torch pytree and resume_download
# deprecation warnings omitted):
#   Some layers from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10
#   were not used when initializing TFBertModel: ['classifier', 'dropout_183']
#   (expected here: the classification head is dropped to expose the encoder)
#   ✓ Base tokenizer and fine-tuned model loaded successfully
#   Processing data/mf-training.csv: 0% | 25/31142 [00:06<2:23:28, 3.61it/s]
#   KeyboardInterrupt: the cell was stopped by hand partway through the
#   training split.
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import joblib
import gc
from transformers import AutoTokenizer, TFAutoModel

# --- 1. Parameters --------------------------------------------------------
MODEL_DIR = "weights/mf-fine-tuned-protbert-epoch10"
BASE_MODEL = "Rostlab/prot_bert"
OUT_DIR = "embeddings"
BATCH_TOK = 16

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, do_lower_case=False)
model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)

print("✓ Base tokenizer and fine-tuned model loaded successfully")

# --- 3. Helper functions ------------------------------------------------
def format_sequence(seq):
    return " ".join(list(seq))

def slice_sequence(seq, win=500, min_overlap=250):
    if len(seq) <= win:
        return [seq]
    slices, start = [], 0
    while start + win <= len(seq):
        slices.append(seq[start:start+win])
        start += win
    leftover = seq[start:]
    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:
        extra = slices[-1][-min_overlap:] + leftover
        slices.append(extra)
    return slices

def get_embeddings(batch, tokenizer, model):
    tokens = tokenizer(batch, return_tensors="tf", padding=True, truncation=True, max_length=512)
    output = model(**tokens)
    return output.last_hidden_state[:, 0, :].numpy()  # one [CLS] vector per slice

def process_split(csv_path, out_path):
    df = pd.read_csv(csv_path)
    # NB: the split CSVs written earlier carry no "GO:" columns, so label_cols
    # can come out empty here; the next cell rebuilds the labels from go_terms.
    label_cols = [col for col in df.columns if col.startswith("GO:")]
    prot_ids, embeds, labels = [], [], []

    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {csv_path}"):
        slices = slice_sequence(row["sequence"])
        slices_fmt = list(map(format_sequence, slices))

        slice_embeds = []
        for i in range(0, len(slices_fmt), BATCH_TOK):
            batch = slices_fmt[i:i+BATCH_TOK]
            slice_embeds.append(get_embeddings(batch, tokenizer, model))
        slice_embeds = np.vstack(slice_embeds)

        prot_embed = slice_embeds.mean(axis=0)  # mean-pool the slice embeddings
        prot_ids.append(row["protein_id"])
        embeds.append(prot_embed.astype(np.float32))
        labels.append(row[label_cols].values.astype(np.int8))
        gc.collect()

    embeds = np.vstack(embeds)
    labels = np.vstack(labels)

    joblib.dump({
        "protein_ids": prot_ids,
        "embeddings": embeds,
        "labels": labels,
        "go_terms": label_cols
    }, out_path, compress=3)

    print(f"✓ Saved {out_path} — {embeds.shape[0]} proteins")

# --- 4. Apply ---------------------------------------------------------
os.makedirs(OUT_DIR, exist_ok=True)

process_split("data/mf-training.csv", os.path.join(OUT_DIR, "train_protbert.pkl"))
process_split("data/mf-validation.csv", os.path.join(OUT_DIR, "val_protbert.pkl"))
process_split("data/mf-test.csv", os.path.join(OUT_DIR, "test_protbert.pkl"))
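slice_sequence is the one helper shared by the fine-tuning and embedding cells, so its edge cases matter. A standalone demo (the function body copied verbatim from above) of what it returns for three representative lengths; note the silent drop of a leftover shorter than min_overlap:

# Standalone demo of the 500-residue window with the 250-residue overlap tail.
def slice_sequence(seq, win=500, min_overlap=250):  # copied from the cell above
    if len(seq) <= win:
        return [seq]
    slices, start = [], 0
    while start + win <= len(seq):
        slices.append(seq[start:start+win])
        start += win
    leftover = seq[start:]
    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:
        extra = slices[-1][-min_overlap:] + leftover
        slices.append(extra)
    return slices

for n in (400, 1200, 1400):
    print(n, "->", [len(s) for s in slice_sequence("M" * n)])
# 400  -> [400]            short sequences pass through whole
# 1200 -> [500, 500]       the 200-residue leftover is below min_overlap and is dropped
# 1400 -> [500, 500, 650]  250-residue overlap plus the 400-residue leftover form a final slice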
# ===== In [27]: re-binarize all labels against the test GO-term set =====
# Output:
#   ✓ Fixed: embeddings/train_protbert.pkl — 31142 examples, 597 GO terms
#   ✓ Fixed: embeddings/val_protbert.pkl — 1724 examples, 597 GO terms
#   ✓ Fixed: embeddings/test_protbert.pkl — 1724 examples, 597 GO terms
import pandas as pd
import joblib
from sklearn.preprocessing import MultiLabelBinarizer

# --- 1. Collect the GO terms present in the test split ----------------------
df_test = pd.read_csv("data/mf-test.csv")
test_terms = sorted(set(term for row in df_test["go_terms"].str.split(";") for term in row))

# --- 2. Patch a .pkl so its labels use the test GO-term set -----------------
def patch_to_common_terms(csv_path, pkl_path, common_terms):
    df = pd.read_csv(csv_path)
    terms_split = df["go_terms"].str.split(";")

    # keep only the terms present in common_terms
    terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])

    mlb = MultiLabelBinarizer(classes=common_terms)
    Y = mlb.fit_transform(terms_filtered)

    data = joblib.load(pkl_path)
    data["labels"] = Y
    data["go_terms"] = mlb.classes_.tolist()

    joblib.dump(data, pkl_path, compress=3)
    print(f"✓ Fixed: {pkl_path} — {Y.shape[0]} examples, {Y.shape[1]} GO terms")

# --- 3. Apply to the three partitions --------------------------------------
patch_to_common_terms("data/mf-training.csv", "embeddings/train_protbert.pkl", test_terms)
patch_to_common_terms("data/mf-validation.csv", "embeddings/val_protbert.pkl", test_terms)
patch_to_common_terms("data/mf-test.csv", "embeddings/test_protbert.pkl", test_terms)
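The fix works because MultiLabelBinarizer(classes=...) pins both the column set and the column order, which is what makes the three .pkl files comparable afterwards. A tiny self-contained illustration (the three GO IDs are real but chosen arbitrarily for the example):

# A fixed class list gives identical column order across independent splits.
from sklearn.preprocessing import MultiLabelBinarizer

common = ["GO:0003674", "GO:0003824", "GO:0005488"]  # illustrative common term set
mlb = MultiLabelBinarizer(classes=common)

split_a = [["GO:0003824"], ["GO:0003674", "GO:0005488"]]
split_b = [["GO:0005488", "GO:0003824"]]

print(mlb.fit_transform(split_a))  # [[0 1 0], [1 0 1]]
print(mlb.fit_transform(split_b))  # [[0 1 1]]  -- same column meaning in both splits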
# ===== In [1]: load the saved embeddings =====
import joblib
train = joblib.load("embeddings/train_protbert.pkl")
val = joblib.load("embeddings/val_protbert.pkl")
test = joblib.load("embeddings/test_protbert.pkl")

X_train, y_train = train["embeddings"], train["labels"]
X_val, y_val = val["embeddings"], val["labels"]
X_test, y_test = test["embeddings"], test["labels"]
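Given how often the label width changed in this pipeline (602 terms at creation, 597 after the patch above), a defensive check right after loading is cheap insurance. A minimal sketch, assuming the variables defined in the cell above:

# Hypothetical sanity check: rows must pair up and label width must agree.
for name, (X, y) in {"train": (X_train, y_train),
                     "val": (X_val, y_val),
                     "test": (X_test, y_test)}.items():
    assert X.shape[0] == y.shape[0], f"{name}: embeddings/labels row mismatch"
    assert y.shape[1] == y_train.shape[1], f"{name}: GO-term count differs"
    print(f"{name}: X{X.shape} y{y.shape}")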
|
| 663 |
-
},
|
| 664 |
-
{
|
| 665 |
-
"cell_type": "code",
|
| 666 |
-
"execution_count": 3,
|
| 667 |
-
"id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
|
| 668 |
-
"metadata": {},
|
| 669 |
-
"outputs": [
|
| 670 |
-
{
|
| 671 |
-
"name": "stdout",
|
| 672 |
-
"output_type": "stream",
|
| 673 |
-
"text": [
|
| 674 |
-
"✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
|
| 675 |
-
"Epoch 1/100\n",
|
| 676 |
-
"974/974 [==============================] - 4s 3ms/step - loss: 0.0358 - binary_accuracy: 0.9893 - val_loss: 0.0336 - val_binary_accuracy: 0.9901\n",
|
| 677 |
-
"Epoch 2/100\n",
|
| 678 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0276 - binary_accuracy: 0.9914 - val_loss: 0.0331 - val_binary_accuracy: 0.9902\n",
|
| 679 |
-
"Epoch 3/100\n",
|
| 680 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0268 - binary_accuracy: 0.9916 - val_loss: 0.0330 - val_binary_accuracy: 0.9902\n",
|
| 681 |
-
"Epoch 4/100\n",
|
| 682 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0264 - binary_accuracy: 0.9917 - val_loss: 0.0320 - val_binary_accuracy: 0.9904\n",
|
| 683 |
-
"Epoch 5/100\n",
|
| 684 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0260 - binary_accuracy: 0.9917 - val_loss: 0.0319 - val_binary_accuracy: 0.9904\n",
|
| 685 |
-
"Epoch 6/100\n",
|
| 686 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0256 - binary_accuracy: 0.9918 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
|
| 687 |
-
"Epoch 7/100\n",
|
| 688 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0255 - binary_accuracy: 0.9918 - val_loss: 0.0317 - val_binary_accuracy: 0.9903\n",
|
| 689 |
-
"Epoch 8/100\n",
|
| 690 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9919 - val_loss: 0.0320 - val_binary_accuracy: 0.9905\n",
|
| 691 |
-
"Epoch 9/100\n",
|
| 692 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0251 - binary_accuracy: 0.9919 - val_loss: 0.0316 - val_binary_accuracy: 0.9904\n",
|
| 693 |
-
"Epoch 10/100\n",
|
| 694 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0250 - binary_accuracy: 0.9920 - val_loss: 0.0314 - val_binary_accuracy: 0.9905\n",
|
| 695 |
-
"Epoch 11/100\n",
|
| 696 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0248 - binary_accuracy: 0.9920 - val_loss: 0.0317 - val_binary_accuracy: 0.9905\n",
|
| 697 |
-
"Epoch 12/100\n",
|
| 698 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0247 - binary_accuracy: 0.9920 - val_loss: 0.0315 - val_binary_accuracy: 0.9905\n",
|
| 699 |
-
"Epoch 13/100\n",
|
| 700 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0246 - binary_accuracy: 0.9920 - val_loss: 0.0322 - val_binary_accuracy: 0.9904\n",
|
| 701 |
-
"Epoch 14/100\n",
|
| 702 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0245 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9905\n",
|
| 703 |
-
"Epoch 15/100\n",
|
| 704 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9920 - val_loss: 0.0319 - val_binary_accuracy: 0.9906\n",
|
| 705 |
-
"Previsões guardadas em mf-protbert-pam1.npy\n",
|
| 706 |
-
"Modelo guardado em models/protbert_mlp.keras\n"
|
| 707 |
-
]
|
| 708 |
-
}
|
| 709 |
-
],
|
| 710 |
-
"source": [
|
| 711 |
-
"import tensorflow as tf\n",
|
| 712 |
-
"import joblib\n",
|
| 713 |
-
"import numpy as np\n",
|
| 714 |
-
"from tensorflow.keras.models import Sequential\n",
|
| 715 |
-
"from tensorflow.keras.layers import Dense, Dropout\n",
|
| 716 |
-
"from tensorflow.keras.callbacks import EarlyStopping\n",
|
| 717 |
-
"\n",
|
| 718 |
-
"# --- 1. Carregar embeddings ----------------------------------------------\n",
|
| 719 |
-
"train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
|
| 720 |
-
"val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
|
| 721 |
-
"test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
|
| 722 |
-
"\n",
|
| 723 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 724 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 725 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
|
| 726 |
-
"\n",
|
| 727 |
-
"print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
|
| 728 |
-
"\n",
|
| 729 |
-
"# --- 2. Garantir consistência de classes ---------------------------------\n",
|
| 730 |
-
"max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
|
| 731 |
-
"\n",
|
| 732 |
-
"def pad_labels(y, target_dim=max_classes):\n",
|
| 733 |
-
" if y.shape[1] < target_dim:\n",
|
| 734 |
-
" padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
|
| 735 |
-
" return np.hstack([y, padding])\n",
|
| 736 |
-
" return y\n",
|
| 737 |
-
"\n",
|
| 738 |
-
"y_val = pad_labels(y_val)\n",
|
| 739 |
-
"y_test = pad_labels(y_test)\n",
|
| 740 |
-
"\n",
|
| 741 |
-
"# --- 3. Modelo MLP ------------------------------------------------------\n",
|
| 742 |
-
"model = Sequential([\n",
|
| 743 |
-
" Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
|
| 744 |
-
" Dropout(0.3),\n",
|
| 745 |
-
" Dense(512, activation=\"relu\"),\n",
|
| 746 |
-
" Dropout(0.3),\n",
|
| 747 |
-
" Dense(max_classes, activation=\"sigmoid\")\n",
|
| 748 |
-
"])\n",
|
| 749 |
-
"\n",
|
| 750 |
-
"model.compile(loss=\"binary_crossentropy\",\n",
|
| 751 |
-
" optimizer=\"adam\",\n",
|
| 752 |
-
" metrics=[\"binary_accuracy\"])\n",
|
| 753 |
-
"\n",
|
| 754 |
-
"# --- 4. Early stopping e treino -----------------------------------------\n",
|
| 755 |
-
"callbacks = [\n",
|
| 756 |
-
" EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
|
| 757 |
-
"]\n",
|
| 758 |
-
"\n",
|
| 759 |
-
"model.fit(X_train, y_train,\n",
|
| 760 |
-
" validation_data=(X_val, y_val),\n",
|
| 761 |
-
" epochs=100,\n",
|
| 762 |
-
" batch_size=32,\n",
|
| 763 |
-
" callbacks=callbacks,\n",
|
| 764 |
-
" verbose=1)\n",
|
| 765 |
-
"\n",
|
| 766 |
-
"# --- 5. Previsões --------------------------------------------------------\n",
|
| 767 |
-
"y_prob = model.predict(X_test)\n",
|
| 768 |
-
"np.save(\"predictions/mf-protbert-pam1.npy\", y_prob)\n",
|
| 769 |
-
"print(\"Previsões guardadas em mf-protbert-pam1.npy\")\n",
|
| 770 |
-
"\n",
|
| 771 |
-
"# --- 6. Modelo ----------------------------------------------------------\n",
|
| 772 |
-
"model.save(\"models/protbert_mlp.keras\")\n",
|
| 773 |
-
"print(\"Modelo guardado em models/protbert_mlp.keras\")"
|
| 774 |
-
]
|
| 775 |
-
},
|
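The saved classifier can be reloaded later for inference without retraining. A minimal sketch, assuming the models/protbert_mlp.keras file written above; the random vector is only a placeholder for a real 1024-d ProtBERT embedding:

import numpy as np
import tensorflow as tf

mlp = tf.keras.models.load_model("models/protbert_mlp.keras")
emb = np.random.rand(1, 1024).astype(np.float32)  # placeholder embedding
probs = mlp.predict(emb)  # shape (1, 597): one sigmoid score per GO term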
| 776 |
-
{
|
| 777 |
-
"cell_type": "code",
|
| 778 |
-
"execution_count": 30,
|
| 779 |
-
"id": "fdb66630-76dc-43a0-bd56-45052175fdba",
|
| 780 |
-
"metadata": {},
|
| 781 |
-
"outputs": [
|
| 782 |
-
{
|
| 783 |
-
"name": "stdout",
|
| 784 |
-
"output_type": "stream",
|
| 785 |
-
"text": [
|
| 786 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 787 |
-
"✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
|
| 788 |
-
"\n",
|
| 789 |
-
"📊 Resultados finais (ProtBERT + PAM1 + propagação):\n",
|
| 790 |
-
"Fmax = 0.6666\n",
|
| 791 |
-
"Thr. = 0.50\n",
|
| 792 |
-
"AuPRC = 0.7028\n",
|
| 793 |
-
"Smin = 13.1745\n"
|
| 794 |
-
]
|
| 795 |
-
}
|
| 796 |
-
],
|
| 797 |
-
"source": [
|
| 798 |
-
"import numpy as np\n",
|
| 799 |
-
"from sklearn.metrics import precision_recall_curve, auc\n",
|
| 800 |
-
"from goatools.obo_parser import GODag\n",
|
| 801 |
-
"import joblib\n",
|
| 802 |
-
"import math\n",
|
| 803 |
-
"\n",
|
| 804 |
-
"# --- 1. Parâmetros -------------------------------------------------------\n",
|
| 805 |
-
"GO_FILE = \"go.obo\"\n",
|
| 806 |
-
"THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
|
| 807 |
-
"ALPHA = 0.5\n",
|
| 808 |
-
"\n",
|
| 809 |
-
"# --- 2. Carregar dados ---------------------------------------------------\n",
|
| 810 |
-
"test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
|
| 811 |
-
"y_true = test[\"labels\"]\n",
|
| 812 |
-
"terms = test[\"go_terms\"]\n",
|
| 813 |
-
"y_prob = np.load(\"predictions/mf-protbert-pam1.npy\")\n",
|
| 814 |
-
"go_dag = GODag(GO_FILE)\n",
|
| 815 |
-
"\n",
|
| 816 |
-
"print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
|
| 817 |
-
"\n",
|
| 818 |
-
"# --- 3. Fmax -------------------------------------------------------------\n",
|
| 819 |
-
"def compute_fmax(y_true, y_prob, thresholds):\n",
|
| 820 |
-
" fmax, best_thr = 0, 0\n",
|
| 821 |
-
" for t in thresholds:\n",
|
| 822 |
-
" y_pred = (y_prob >= t).astype(int)\n",
|
| 823 |
-
" tp = (y_true * y_pred).sum(axis=1)\n",
|
| 824 |
-
" fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
|
| 825 |
-
" fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
|
| 826 |
-
" precision = tp / (tp + fp + 1e-8)\n",
|
| 827 |
-
" recall = tp / (tp + fn + 1e-8)\n",
|
| 828 |
-
" f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
|
| 829 |
-
" avg_f1 = np.mean(f1)\n",
|
| 830 |
-
" if avg_f1 > fmax:\n",
|
| 831 |
-
" fmax, best_thr = avg_f1, t\n",
|
| 832 |
-
" return fmax, best_thr\n",
|
| 833 |
-
"\n",
|
| 834 |
-
"# --- 4. AuPRC micro ------------------------------------------------------\n",
|
| 835 |
-
"def compute_auprc(y_true, y_prob):\n",
|
| 836 |
-
" precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
|
| 837 |
-
" return auc(recall, precision)\n",
|
| 838 |
-
"\n",
|
| 839 |
-
"# --- 5. Smin -------------------------------------------------------------\n",
|
| 840 |
-
"def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
|
| 841 |
-
" y_pred = (y_prob >= threshold).astype(int)\n",
|
| 842 |
-
" ic = {}\n",
|
| 843 |
-
" total = (y_true + y_pred).sum(axis=0).sum()\n",
|
| 844 |
-
" for i, term in enumerate(terms):\n",
|
| 845 |
-
" freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
|
| 846 |
-
" ic[term] = -np.log((freq + 1e-8) / total)\n",
|
| 847 |
-
"\n",
|
| 848 |
-
" s_values = []\n",
|
| 849 |
-
" for true_vec, pred_vec in zip(y_true, y_pred):\n",
|
| 850 |
-
" true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
|
| 851 |
-
" pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
|
| 852 |
-
"\n",
|
| 853 |
-
" anc_true = set()\n",
|
| 854 |
-
" for t in true_terms:\n",
|
| 855 |
-
" if t in go_dag:\n",
|
| 856 |
-
" anc_true |= go_dag[t].get_all_parents()\n",
|
| 857 |
-
" anc_pred = set()\n",
|
| 858 |
-
" for t in pred_terms:\n",
|
| 859 |
-
" if t in go_dag:\n",
|
| 860 |
-
" anc_pred |= go_dag[t].get_all_parents()\n",
|
| 861 |
-
"\n",
|
| 862 |
-
" ru = pred_terms - true_terms\n",
|
| 863 |
-
" mi = true_terms - pred_terms\n",
|
| 864 |
-
" dist_ru = sum(ic.get(t, 0) for t in ru)\n",
|
| 865 |
-
" dist_mi = sum(ic.get(t, 0) for t in mi)\n",
|
| 866 |
-
" s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
|
| 867 |
-
" s_values.append(s)\n",
|
| 868 |
-
"\n",
|
| 869 |
-
" return np.mean(s_values)\n",
|
| 870 |
-
"\n",
|
| 871 |
-
"# --- 6. Avaliar ----------------------------------------------------------\n",
|
| 872 |
-
"fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
|
| 873 |
-
"auprc = compute_auprc(y_true, y_prob)\n",
|
| 874 |
-
"smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
|
| 875 |
-
"\n",
|
| 876 |
-
"print(f\"\\n📊 Resultados finais (ProtBERT + PAM1 + propagação):\")\n",
|
| 877 |
-
"print(f\"Fmax = {fmax:.4f}\")\n",
|
| 878 |
-
"print(f\"Thr. = {thr:.2f}\")\n",
|
| 879 |
-
"print(f\"AuPRC = {auprc:.4f}\")\n",
|
| 880 |
-
"print(f\"Smin = {smin:.4f}\")\n"
|
| 881 |
-
]
|
| 882 |
-
},
|
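compute_fmax averages F1 per protein before scanning thresholds (protein-centric Fmax), so every protein weighs equally no matter how many terms it carries. A toy worked example, reusing the function defined above:

import numpy as np

y_t = np.array([[1, 0, 1],
                [0, 1, 0]])
y_p = np.array([[0.9, 0.2, 0.4],
                [0.1, 0.8, 0.7]])
# at t = 0.21 protein 1 is predicted exactly (F1 = 1) while protein 2 gets
# one false positive (F1 = 2/3), so the mean F1 peaks at 5/6
print(compute_fmax(y_t, y_p, np.arange(0.0, 1.01, 0.01)))  # ≈ (0.8333, 0.21)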
| 883 |
-
{
|
| 884 |
-
"cell_type": "code",
|
| 885 |
-
"execution_count": 3,
|
| 886 |
-
"id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
|
| 887 |
-
"metadata": {},
|
| 888 |
-
"outputs": [
|
| 889 |
-
{
|
| 890 |
-
"data": {
|
| 891 |
-
"text/plain": [
|
| 892 |
-
"['data/mlb_protbert.pkl']"
|
| 893 |
-
]
|
| 894 |
-
},
|
| 895 |
-
"execution_count": 3,
|
| 896 |
-
"metadata": {},
|
| 897 |
-
"output_type": "execute_result"
|
| 898 |
-
}
|
| 899 |
-
],
|
| 900 |
-
"source": [
|
| 901 |
-
"import joblib, pickle\n",
|
| 902 |
-
"joblib.dump(mlb, \"data/mlb_protbert.pkl\")"
|
| 903 |
-
]
|
| 904 |
-
},
|
| 905 |
-
{
|
| 906 |
-
"cell_type": "code",
|
| 907 |
-
"execution_count": null,
|
| 908 |
-
"id": "9f89c3bc-6b78-4a4c-8ddd-b69c7d3d0e65",
|
| 909 |
-
"metadata": {},
|
| 910 |
-
"outputs": [],
|
| 911 |
-
"source": []
|
| 912 |
-
}
|
| 913 |
-
],
|
| 914 |
-
"metadata": {
|
| 915 |
-
"kernelspec": {
|
| 916 |
-
"display_name": "Python 3 (ipykernel)",
|
| 917 |
-
"language": "python",
|
| 918 |
-
"name": "python3"
|
| 919 |
-
},
|
| 920 |
-
"language_info": {
|
| 921 |
-
"codemirror_mode": {
|
| 922 |
-
"name": "ipython",
|
| 923 |
-
"version": 3
|
| 924 |
-
},
|
| 925 |
-
"file_extension": ".py",
|
| 926 |
-
"mimetype": "text/x-python",
|
| 927 |
-
"name": "python",
|
| 928 |
-
"nbconvert_exporter": "python",
|
| 929 |
-
"pygments_lexer": "ipython3",
|
| 930 |
-
"version": "3.8.18"
|
| 931 |
-
}
|
| 932 |
-
},
|
| 933 |
-
"nbformat": 4,
|
| 934 |
-
"nbformat_minor": 5
|
| 935 |
-
}
|
notebooks/PAM1_protbertBFD.ipynb
DELETED
|
@@ -1,871 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 1,
|
| 6 |
-
"id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [
|
| 9 |
-
{
|
| 10 |
-
"name": "stdout",
|
| 11 |
-
"output_type": "stream",
|
| 12 |
-
"text": [
|
| 13 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 14 |
-
"✓ Ficheiros criados:\n",
|
| 15 |
-
" - data/mf-training.csv : (31142, 3)\n",
|
| 16 |
-
" - data/mf-validation.csv: (1724, 3)\n",
|
| 17 |
-
" - data/mf-test.csv : (1724, 3)\n",
|
| 18 |
-
"GO terms únicos (após propagação e filtro): 602\n"
|
| 19 |
-
]
|
| 20 |
-
}
|
| 21 |
-
],
|
| 22 |
-
"source": [
|
| 23 |
-
"import pandas as pd\n",
|
| 24 |
-
"from Bio import SeqIO\n",
|
| 25 |
-
"from collections import Counter\n",
|
| 26 |
-
"from goatools.obo_parser import GODag\n",
|
| 27 |
-
"from sklearn.model_selection import train_test_split\n",
|
| 28 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 29 |
-
"from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
|
| 30 |
-
"import numpy as np\n",
|
| 31 |
-
"import os\n",
|
| 32 |
-
"\n",
|
| 33 |
-
"# --- 1. Carregar GO anotações ------------------------------------------\n",
|
| 34 |
-
"annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
|
| 35 |
-
"annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
|
| 36 |
-
"\n",
|
| 37 |
-
"# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
|
| 38 |
-
"# propagação hierárquica\n",
|
| 39 |
-
"# https://geneontology.org/docs/download-ontology/\n",
|
| 40 |
-
"go_dag = GODag(\"go.obo\")\n",
|
| 41 |
-
"mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
|
| 42 |
-
"\n",
|
| 43 |
-
"def propagate_terms(term_list):\n",
|
| 44 |
-
" full = set()\n",
|
| 45 |
-
" for t in term_list:\n",
|
| 46 |
-
" if t not in go_dag:\n",
|
| 47 |
-
" continue\n",
|
| 48 |
-
" full.add(t)\n",
|
| 49 |
-
" full.update(go_dag[t].get_all_parents())\n",
|
| 50 |
-
" return list(full & mf_terms)\n",
|
| 51 |
-
"\n",
|
| 52 |
-
"# --- 3. Carregar sequências --------------------------------------------\n",
|
| 53 |
-
"seqs, ids = [], []\n",
|
| 54 |
-
"for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
|
| 55 |
-
" ids.append(record.id)\n",
|
| 56 |
-
" seqs.append(str(record.seq))\n",
|
| 57 |
-
"\n",
|
| 58 |
-
"seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
|
| 61 |
-
"grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
|
| 62 |
-
"data = seq_df.merge(grouped, on=\"protein_id\")\n",
|
| 63 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 64 |
-
"data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
|
| 65 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 66 |
-
"\n",
|
| 67 |
-
"# --- 5. Filtrar GO terms raros -----------------------------------------\n",
|
| 68 |
-
"# todos os terms com menos de 50 proteinas associadas\n",
|
| 69 |
-
"all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
|
| 70 |
-
"term_counts = Counter(all_terms)\n",
|
| 71 |
-
"valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
|
| 72 |
-
"data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
|
| 73 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 74 |
-
"\n",
|
| 75 |
-
"# --- 6. Preparar dataset final -----------------------------------------\n",
|
| 76 |
-
"data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
|
| 77 |
-
"data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
|
| 78 |
-
"\n",
|
| 79 |
-
"# --- 7. Binarizar labels e dividir -------------------------------------\n",
|
| 80 |
-
"mlb = MultiLabelBinarizer()\n",
|
| 81 |
-
"Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
|
| 82 |
-
"X = data[[\"protein_id\", \"sequence\"]].values\n",
|
| 83 |
-
"\n",
|
| 84 |
-
"mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
|
| 85 |
-
"train_idx, temp_idx = next(mskf.split(X, Y))\n",
|
| 86 |
-
"val_idx, test_idx = np.array_split(temp_idx, 2)\n",
|
| 87 |
-
"\n",
|
| 88 |
-
"df_train = data.iloc[train_idx].copy()\n",
|
| 89 |
-
"df_val = data.iloc[val_idx].copy()\n",
|
| 90 |
-
"df_test = data.iloc[test_idx].copy()\n",
|
| 91 |
-
"\n",
|
| 92 |
-
"# --- 8. Guardar em CSV -------------------------------------------------\n",
|
| 93 |
-
"os.makedirs(\"data\", exist_ok=True)\n",
|
| 94 |
-
"df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
|
| 95 |
-
"df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
|
| 96 |
-
"df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
|
| 97 |
-
"\n",
|
| 98 |
-
"# --- 9. Confirmar ------------------------------------------------------\n",
|
| 99 |
-
"print(\"✓ Ficheiros criados:\")\n",
|
| 100 |
-
"print(\" - data/mf-training.csv :\", df_train.shape)\n",
|
| 101 |
-
"print(\" - data/mf-validation.csv:\", df_val.shape)\n",
|
| 102 |
-
"print(\" - data/mf-test.csv :\", df_test.shape)\n",
|
| 103 |
-
"print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
|
| 104 |
-
]
|
| 105 |
-
},
|
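propagate_terms makes each annotation imply all of its molecular_function ancestors, which is why the label set grows before the frequency filter shrinks it again. A small usage sketch, assuming go.obo is loaded as above; GO:0003700 (DNA-binding transcription factor activity) is only an illustrative input:

terms = propagate_terms(["GO:0003700"])
# the result contains the term itself plus its ancestors up to the
# molecular_function root GO:0003674, and nothing outside that namespace
print(len(terms), "GO:0003674" in terms)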
| 106 |
-
{
|
| 107 |
-
"cell_type": "code",
|
| 108 |
-
"execution_count": 2,
|
| 109 |
-
"id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
|
| 110 |
-
"metadata": {},
|
| 111 |
-
"outputs": [
|
| 112 |
-
{
|
| 113 |
-
"name": "stderr",
|
| 114 |
-
"output_type": "stream",
|
| 115 |
-
"text": [
|
| 116 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 117 |
-
" from .autonotebook import tqdm as notebook_tqdm\n",
|
| 118 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
|
| 119 |
-
" _torch_pytree._register_pytree_node(\n",
|
| 120 |
-
"100%|██████████| 31142/31142 [00:26<00:00, 1192.86it/s]\n",
|
| 121 |
-
"100%|██████████| 1724/1724 [00:00<00:00, 2570.68it/s]\n",
|
| 122 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
|
| 123 |
-
" warnings.warn(\n"
|
| 124 |
-
]
|
| 125 |
-
},
|
| 126 |
-
{
|
| 127 |
-
"name": "stdout",
|
| 128 |
-
"output_type": "stream",
|
| 129 |
-
"text": [
|
| 130 |
-
"preprocessing train...\n",
|
| 131 |
-
"language: en\n",
|
| 132 |
-
"train sequence lengths:\n",
|
| 133 |
-
"\tmean : 423\n",
|
| 134 |
-
"\t95percentile : 604\n",
|
| 135 |
-
"\t99percentile : 715\n"
|
| 136 |
-
]
|
| 137 |
-
},
|
| 138 |
-
{
|
| 139 |
-
"data": {
|
| 140 |
-
"text/html": [
|
| 141 |
-
"\n",
|
| 142 |
-
"<style>\n",
|
| 143 |
-
" /* Turns off some styling */\n",
|
| 144 |
-
" progress {\n",
|
| 145 |
-
" /* gets rid of default border in Firefox and Opera. */\n",
|
| 146 |
-
" border: none;\n",
|
| 147 |
-
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
| 148 |
-
" background-size: auto;\n",
|
| 149 |
-
" }\n",
|
| 150 |
-
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
|
| 151 |
-
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
|
| 152 |
-
" }\n",
|
| 153 |
-
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
| 154 |
-
" background: #F44336;\n",
|
| 155 |
-
" }\n",
|
| 156 |
-
"</style>\n"
|
| 157 |
-
],
|
| 158 |
-
"text/plain": [
|
| 159 |
-
"<IPython.core.display.HTML object>"
|
| 160 |
-
]
|
| 161 |
-
},
|
| 162 |
-
"metadata": {},
|
| 163 |
-
"output_type": "display_data"
|
| 164 |
-
},
|
| 165 |
-
{
|
| 166 |
-
"data": {
|
| 167 |
-
"text/html": [],
|
| 168 |
-
"text/plain": [
|
| 169 |
-
"<IPython.core.display.HTML object>"
|
| 170 |
-
]
|
| 171 |
-
},
|
| 172 |
-
"metadata": {},
|
| 173 |
-
"output_type": "display_data"
|
| 174 |
-
},
|
| 175 |
-
{
|
| 176 |
-
"name": "stdout",
|
| 177 |
-
"output_type": "stream",
|
| 178 |
-
"text": [
|
| 179 |
-
"Is Multi-Label? True\n",
|
| 180 |
-
"preprocessing test...\n",
|
| 181 |
-
"language: en\n",
|
| 182 |
-
"test sequence lengths:\n",
|
| 183 |
-
"\tmean : 408\n",
|
| 184 |
-
"\t95percentile : 603\n",
|
| 185 |
-
"\t99percentile : 714\n"
|
| 186 |
-
]
|
| 187 |
-
},
|
| 188 |
-
{
|
| 189 |
-
"data": {
|
| 190 |
-
"text/html": [
|
| 191 |
-
"\n",
|
| 192 |
-
"<style>\n",
|
| 193 |
-
" /* Turns off some styling */\n",
|
| 194 |
-
" progress {\n",
|
| 195 |
-
" /* gets rid of default border in Firefox and Opera. */\n",
|
| 196 |
-
" border: none;\n",
|
| 197 |
-
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
| 198 |
-
" background-size: auto;\n",
|
| 199 |
-
" }\n",
|
| 200 |
-
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
|
| 201 |
-
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
|
| 202 |
-
" }\n",
|
| 203 |
-
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
| 204 |
-
" background: #F44336;\n",
|
| 205 |
-
" }\n",
|
| 206 |
-
"</style>\n"
|
| 207 |
-
],
|
| 208 |
-
"text/plain": [
|
| 209 |
-
"<IPython.core.display.HTML object>"
|
| 210 |
-
]
|
| 211 |
-
},
|
| 212 |
-
"metadata": {},
|
| 213 |
-
"output_type": "display_data"
|
| 214 |
-
},
|
| 215 |
-
{
|
| 216 |
-
"data": {
|
| 217 |
-
"text/html": [],
|
| 218 |
-
"text/plain": [
|
| 219 |
-
"<IPython.core.display.HTML object>"
|
| 220 |
-
]
|
| 221 |
-
},
|
| 222 |
-
"metadata": {},
|
| 223 |
-
"output_type": "display_data"
|
| 224 |
-
},
|
| 225 |
-
{
|
| 226 |
-
"name": "stdout",
|
| 227 |
-
"output_type": "stream",
|
| 228 |
-
"text": [
|
| 229 |
-
"\n",
|
| 230 |
-
"\n",
|
| 231 |
-
"begin training using triangular learning rate policy with max lr of 1e-05...\n",
|
| 232 |
-
"Epoch 1/10\n",
|
| 233 |
-
"40995/40995 [==============================] - 9020s 219ms/step - loss: 0.0740 - binary_accuracy: 0.9869 - val_loss: 0.0526 - val_binary_accuracy: 0.9866\n",
|
| 234 |
-
"Epoch 2/10\n",
|
| 235 |
-
"40995/40995 [==============================] - 8939s 218ms/step - loss: 0.0464 - binary_accuracy: 0.9877 - val_loss: 0.0457 - val_binary_accuracy: 0.9871\n",
|
| 236 |
-
"Epoch 3/10\n",
|
| 237 |
-
"40995/40995 [==============================] - 8881s 217ms/step - loss: 0.0413 - binary_accuracy: 0.9883 - val_loss: 0.0418 - val_binary_accuracy: 0.9877\n",
|
| 238 |
-
"Epoch 4/10\n",
|
| 239 |
-
"40995/40995 [==============================] - 10277s 251ms/step - loss: 0.0380 - binary_accuracy: 0.9888 - val_loss: 0.0396 - val_binary_accuracy: 0.9881\n",
|
| 240 |
-
"Epoch 5/10\n",
|
| 241 |
-
"40995/40995 [==============================] - 10565s 258ms/step - loss: 0.0357 - binary_accuracy: 0.9892 - val_loss: 0.0380 - val_binary_accuracy: 0.9883\n",
|
| 242 |
-
"Epoch 6/10\n",
|
| 243 |
-
"40995/40995 [==============================] - 10693s 261ms/step - loss: 0.0338 - binary_accuracy: 0.9895 - val_loss: 0.0369 - val_binary_accuracy: 0.9885\n",
|
| 244 |
-
"Epoch 7/10\n",
|
| 245 |
-
"40995/40995 [==============================] - 12055s 294ms/step - loss: 0.0323 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
|
| 246 |
-
"Epoch 8/10\n",
|
| 247 |
-
"40995/40995 [==============================] - 10225s 249ms/step - loss: 0.0309 - binary_accuracy: 0.9901 - val_loss: 0.0353 - val_binary_accuracy: 0.9890\n",
|
| 248 |
-
"Epoch 9/10\n",
|
| 249 |
-
"40995/40995 [==============================] - 10308s 251ms/step - loss: 0.0297 - binary_accuracy: 0.9904 - val_loss: 0.0347 - val_binary_accuracy: 0.9891\n",
|
| 250 |
-
"Epoch 10/10\n",
|
| 251 |
-
"40995/40995 [==============================] - 10275s 251ms/step - loss: 0.0286 - binary_accuracy: 0.9907 - val_loss: 0.0346 - val_binary_accuracy: 0.9893\n",
|
| 252 |
-
"Weights from best epoch have been loaded into model.\n"
|
| 253 |
-
]
|
| 254 |
-
},
|
| 255 |
-
{
|
| 256 |
-
"data": {
|
| 257 |
-
"text/plain": [
|
| 258 |
-
"<keras.callbacks.History at 0x2b644b84fd0>"
|
| 259 |
-
]
|
| 260 |
-
},
|
| 261 |
-
"execution_count": 2,
|
| 262 |
-
"metadata": {},
|
| 263 |
-
"output_type": "execute_result"
|
| 264 |
-
}
|
| 265 |
-
],
|
| 266 |
-
"source": [
|
| 267 |
-
"import pandas as pd\n",
|
| 268 |
-
"import numpy as np\n",
|
| 269 |
-
"from tqdm import tqdm\n",
|
| 270 |
-
"import random\n",
|
| 271 |
-
"import os\n",
|
| 272 |
-
"import ktrain\n",
|
| 273 |
-
"from ktrain import text\n",
|
| 274 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 275 |
-
"\n",
|
| 276 |
-
"\n",
|
| 277 |
-
"# PAM1\n",
|
| 278 |
-
"# PAM matrix model of protein evolution\n",
|
| 279 |
-
"# DOI:10.1093/oxfordjournals.molbev.a040360\n",
|
| 280 |
-
"pam_data = {\n",
|
| 281 |
-
" 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
|
| 282 |
-
" 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
|
| 283 |
-
" 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
|
| 284 |
-
" 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
|
| 285 |
-
" 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
|
| 286 |
-
" 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
|
| 287 |
-
" 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
|
| 288 |
-
" 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
|
| 289 |
-
" 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
|
| 290 |
-
" 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
|
| 291 |
-
" 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
|
| 292 |
-
" 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
|
| 293 |
-
" 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
|
| 294 |
-
" 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
|
| 295 |
-
" 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
|
| 296 |
-
" 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
|
| 297 |
-
" 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
|
| 298 |
-
" 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
|
| 299 |
-
" 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
|
| 300 |
-
" 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
|
| 301 |
-
"}\n",
|
| 302 |
-
"pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
|
| 303 |
-
"pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
|
| 304 |
-
"list_amino = pam_raw.columns.tolist()\n",
|
| 305 |
-
"pam_dict = {\n",
|
| 306 |
-
" aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
|
| 307 |
-
" for aa in list_amino\n",
|
| 308 |
-
"}\n",
|
| 309 |
-
"\n",
|
| 310 |
-
"def pam1_substitution(aa):\n",
|
| 311 |
-
" if aa not in pam_dict:\n",
|
| 312 |
-
" return aa\n",
|
| 313 |
-
" subs = list(pam_dict[aa].keys())\n",
|
| 314 |
-
" probs = list(pam_dict[aa].values())\n",
|
| 315 |
-
" return np.random.choice(subs, p=probs)\n",
|
| 316 |
-
"\n",
|
| 317 |
-
"def augment_sequence(seq, sub_prob=0.05):\n",
|
| 318 |
-
" return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
|
| 319 |
-
"\n",
|
| 320 |
-
"def slice_sequence(seq, win=500, min_overlap=250):\n",
|
| 321 |
-
" if len(seq) <= win:\n",
|
| 322 |
-
" return [seq]\n",
|
| 323 |
-
" slices, start = [], 0\n",
|
| 324 |
-
" while start + win <= len(seq):\n",
|
| 325 |
-
" slices.append(seq[start:start+win])\n",
|
| 326 |
-
" start += win\n",
|
| 327 |
-
" leftover = seq[start:]\n",
|
| 328 |
-
" if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
|
| 329 |
-
" extra = slices[-1][-min_overlap:] + leftover\n",
|
| 330 |
-
" slices.append(extra)\n",
|
| 331 |
-
" return slices\n",
|
| 332 |
-
"\n",
|
| 333 |
-
"def generate_data(df, augment=False):\n",
|
| 334 |
-
" X, y = [], []\n",
|
| 335 |
-
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
|
| 336 |
-
" for _, row in tqdm(df.iterrows(), total=len(df)):\n",
|
| 337 |
-
" seq = row[\"sequence\"]\n",
|
| 338 |
-
" if augment:\n",
|
| 339 |
-
" seq = augment_sequence(seq)\n",
|
| 340 |
-
" seq_slices = slice_sequence(seq)\n",
|
| 341 |
-
" X.extend(seq_slices)\n",
|
| 342 |
-
" lbl = row[label_cols].values.astype(int)\n",
|
| 343 |
-
" y.extend([lbl] * len(seq_slices))\n",
|
| 344 |
-
" return X, np.array(y), label_cols\n",
|
| 345 |
-
"\n",
|
| 346 |
-
"def format_sequence(seq): return \" \".join(list(seq))\n",
|
| 347 |
-
"\n",
|
| 348 |
-
"# Função para carregar e binarizar\n",
|
| 349 |
-
"def load_and_binarize(csv_path, mlb=None):\n",
|
| 350 |
-
" df = pd.read_csv(csv_path)\n",
|
| 351 |
-
" df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
|
| 352 |
-
" if mlb is None:\n",
|
| 353 |
-
" mlb = MultiLabelBinarizer()\n",
|
| 354 |
-
" labels = mlb.fit_transform(df[\"go_terms\"])\n",
|
| 355 |
-
" else:\n",
|
| 356 |
-
" labels = mlb.transform(df[\"go_terms\"])\n",
|
| 357 |
-
" labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
|
| 358 |
-
" df = df.reset_index(drop=True).join(labels_df)\n",
|
| 359 |
-
" return df, mlb\n",
|
| 360 |
-
"\n",
|
| 361 |
-
"# Carregar os dados\n",
|
| 362 |
-
"df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
|
| 363 |
-
"df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
|
| 364 |
-
"\n",
|
| 365 |
-
"# Gerar com augmentation no treino\n",
|
| 366 |
-
"X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
|
| 367 |
-
"X_val, y_val, _ = generate_data(df_val, augment=False)\n",
|
| 368 |
-
"\n",
|
| 369 |
-
"# Preparar texto para tokenizer\n",
|
| 370 |
-
"X_train_fmt = list(map(format_sequence, X_train))\n",
|
| 371 |
-
"X_val_fmt = list(map(format_sequence, X_val))\n",
|
| 372 |
-
"\n",
|
| 373 |
-
"# Fine-tune ProtBERT\n",
|
| 374 |
-
"# https://huggingface.co/Rostlab/prot_bert\n",
|
| 375 |
-
"# https://doi.org/10.1093/bioinformatics/btac020\n",
|
| 376 |
-
"# dados de treino-> UniRef100 (216 milhões de sequências)\n",
|
| 377 |
-
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
|
| 378 |
-
"MAX_LEN = 512\n",
|
| 379 |
-
"BATCH_SIZE = 1\n",
|
| 380 |
-
"\n",
|
| 381 |
-
"t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
|
| 382 |
-
"trn = t.preprocess_train(X_train_fmt, y_train)\n",
|
| 383 |
-
"val = t.preprocess_test(X_val_fmt, y_val)\n",
|
| 384 |
-
"\n",
|
| 385 |
-
"model = t.get_classifier()\n",
|
| 386 |
-
"learner = ktrain.get_learner(model,\n",
|
| 387 |
-
" train_data=trn,\n",
|
| 388 |
-
" val_data=val,\n",
|
| 389 |
-
" batch_size=BATCH_SIZE)\n",
|
| 390 |
-
"\n",
|
| 391 |
-
"learner.autofit(lr=1e-5,\n",
|
| 392 |
-
" epochs=10,\n",
|
| 393 |
-
" early_stopping=1,\n",
|
| 394 |
-
" checkpoint_folder=\"mf-fine-tuned-protbertbfd\")\n"
|
| 395 |
-
]
|
| 396 |
-
},
|
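Because pam_matrix is row-normalized, each pam_dict[aa] is a proper probability distribution whose diagonal entry dominates, so a single PAM1 draw usually returns the original residue; the effective per-position mutation rate of augment_sequence is therefore sub_prob * (1 - pam_dict[aa][aa]), well below sub_prob itself. A quick sanity check, assuming the objects defined above:

import numpy as np

assert np.allclose(pam_matrix.sum(axis=1), 1.0)  # every row sums to 1
print(pam_dict["A"]["A"])  # probability that 'A' survives one PAM1 draw unchanged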
| 397 |
-
{
|
| 398 |
-
"cell_type": "code",
|
| 399 |
-
"execution_count": 6,
|
| 400 |
-
"id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
|
| 401 |
-
"metadata": {},
|
| 402 |
-
"outputs": [
|
| 403 |
-
{
|
| 404 |
-
"name": "stdout",
|
| 405 |
-
"output_type": "stream",
|
| 406 |
-
"text": [
|
| 407 |
-
"✅ Existe: weights/mf-fine-tuned-protbertbfd\n",
|
| 408 |
-
"📁 Conteúdo:\n",
|
| 409 |
-
" - config.json\n",
|
| 410 |
-
" - tf_model.h5\n"
|
| 411 |
-
]
|
| 412 |
-
}
|
| 413 |
-
],
|
| 414 |
-
"source": [
|
| 415 |
-
"import os\n",
|
| 416 |
-
"learner.save_model('weights/mf-fine-tuned-protbertbfd')\n",
|
| 417 |
-
"path = \"weights/mf-fine-tuned-protbertbfd\"\n",
|
| 418 |
-
"\n",
|
| 419 |
-
"if os.path.exists(path):\n",
|
| 420 |
-
" print(f\"✅ Existe: {path}\")\n",
|
| 421 |
-
" print(\"📁 Conteúdo:\")\n",
|
| 422 |
-
" for f in os.listdir(path):\n",
|
| 423 |
-
" print(\" -\", f)\n",
|
| 424 |
-
"else:\n",
|
| 425 |
-
" print(f\"❌ Não existe: {path}\")\n",
|
| 426 |
-
"\n"
|
| 427 |
-
]
|
| 428 |
-
},
|
| 429 |
-
{
|
| 430 |
-
"cell_type": "code",
|
| 431 |
-
"execution_count": 8,
|
| 432 |
-
"id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
|
| 433 |
-
"metadata": {},
|
| 434 |
-
"outputs": [
|
| 435 |
-
{
|
| 436 |
-
"name": "stdout",
|
| 437 |
-
"output_type": "stream",
|
| 438 |
-
"text": [
|
| 439 |
-
"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
|
| 440 |
-
]
|
| 441 |
-
},
|
| 442 |
-
{
|
| 443 |
-
"name": "stderr",
|
| 444 |
-
"output_type": "stream",
|
| 445 |
-
"text": [
|
| 446 |
-
"Processando data/mf-training.csv: 100%|██████████| 31142/31142 [5:17:56<00:00, 1.63it/s] \n"
|
| 447 |
-
]
|
| 448 |
-
},
|
| 449 |
-
{
|
| 450 |
-
"name": "stdout",
|
| 451 |
-
"output_type": "stream",
|
| 452 |
-
"text": [
|
| 453 |
-
"✓ Guardado embeddings\\train_protbertbfd.pkl — 31142 proteínas\n"
|
| 454 |
-
]
|
| 455 |
-
},
|
| 456 |
-
{
|
| 457 |
-
"name": "stderr",
|
| 458 |
-
"output_type": "stream",
|
| 459 |
-
"text": [
|
| 460 |
-
"Processando data/mf-validation.csv: 100%|██████████| 1724/1724 [19:15<00:00, 1.49it/s]\n"
|
| 461 |
-
]
|
| 462 |
-
},
|
| 463 |
-
{
|
| 464 |
-
"name": "stdout",
|
| 465 |
-
"output_type": "stream",
|
| 466 |
-
"text": [
|
| 467 |
-
"✓ Guardado embeddings\\val_protbertbfd.pkl — 1724 proteínas\n"
|
| 468 |
-
]
|
| 469 |
-
},
|
| 470 |
-
{
|
| 471 |
-
"name": "stderr",
|
| 472 |
-
"output_type": "stream",
|
| 473 |
-
"text": [
|
| 474 |
-
"Processando data/mf-test.csv: 100%|██████████| 1724/1724 [17:15<00:00, 1.66it/s]\n"
|
| 475 |
-
]
|
| 476 |
-
},
|
| 477 |
-
{
|
| 478 |
-
"name": "stdout",
|
| 479 |
-
"output_type": "stream",
|
| 480 |
-
"text": [
|
| 481 |
-
"✓ Guardado embeddings\\test_protbertbfd.pkl — 1724 proteínas\n"
|
| 482 |
-
]
|
| 483 |
-
}
|
| 484 |
-
],
|
| 485 |
-
"source": [
|
| 486 |
-
"import os\n",
|
| 487 |
-
"import pandas as pd\n",
|
| 488 |
-
"import numpy as np\n",
|
| 489 |
-
"from tqdm import tqdm\n",
|
| 490 |
-
"import joblib\n",
|
| 491 |
-
"import gc\n",
|
| 492 |
-
"from transformers import AutoTokenizer, TFAutoModel\n",
|
| 493 |
-
"\n",
|
| 494 |
-
"# --- 1. Parâmetros --------------------------------------------------------\n",
|
| 495 |
-
"MODEL_DIR = \"weights/mf-fine-tuned-protbertbfd\"\n",
|
| 496 |
-
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
|
| 497 |
-
"OUT_DIR = \"embeddings\"\n",
|
| 498 |
-
"BATCH_TOK = 16\n",
|
| 499 |
-
"\n",
|
| 500 |
-
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
|
| 501 |
-
"model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
|
| 502 |
-
"\n",
|
| 503 |
-
"print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
|
| 504 |
-
"\n",
|
| 505 |
-
"# --- 3. Funções auxiliares ------------------------------------------------\n",
|
| 506 |
-
"def format_sequence(seq):\n",
|
| 507 |
-
" return \" \".join(list(seq))\n",
|
| 508 |
-
"\n",
|
| 509 |
-
"def slice_sequence(seq, win=500, min_overlap=250):\n",
|
| 510 |
-
" if len(seq) <= win:\n",
|
| 511 |
-
" return [seq]\n",
|
| 512 |
-
" slices, start = [], 0\n",
|
| 513 |
-
" while start + win <= len(seq):\n",
|
| 514 |
-
" slices.append(seq[start:start+win])\n",
|
| 515 |
-
" start += win\n",
|
| 516 |
-
" leftover = seq[start:]\n",
|
| 517 |
-
" if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
|
| 518 |
-
" extra = slices[-1][-min_overlap:] + leftover\n",
|
| 519 |
-
" slices.append(extra)\n",
|
| 520 |
-
" return slices\n",
|
| 521 |
-
"\n",
|
| 522 |
-
"def get_embeddings(batch, tokenizer, model):\n",
|
| 523 |
-
" tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
|
| 524 |
-
" output = model(**tokens)\n",
|
| 525 |
-
" return output.last_hidden_state[:, 0, :].numpy()\n",
|
| 526 |
-
"\n",
|
| 527 |
-
"def process_split(csv_path, out_path):\n",
|
| 528 |
-
" df = pd.read_csv(csv_path)\n",
|
| 529 |
-
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
|
| 530 |
-
" prot_ids, embeds, labels = [], [], []\n",
|
| 531 |
-
"\n",
|
| 532 |
-
" for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
|
| 533 |
-
" slices = slice_sequence(row[\"sequence\"])\n",
|
| 534 |
-
" slices_fmt = list(map(format_sequence, slices))\n",
|
| 535 |
-
"\n",
|
| 536 |
-
" slice_embeds = []\n",
|
| 537 |
-
" for i in range(0, len(slices_fmt), BATCH_TOK):\n",
|
| 538 |
-
" batch = slices_fmt[i:i+BATCH_TOK]\n",
|
| 539 |
-
" slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
|
| 540 |
-
" slice_embeds = np.vstack(slice_embeds)\n",
|
| 541 |
-
"\n",
|
| 542 |
-
" prot_embed = slice_embeds.mean(axis=0)\n",
|
| 543 |
-
" prot_ids.append(row[\"protein_id\"])\n",
|
| 544 |
-
" embeds.append(prot_embed.astype(np.float32))\n",
|
| 545 |
-
" labels.append(row[label_cols].values.astype(np.int8))\n",
|
| 546 |
-
" gc.collect()\n",
|
| 547 |
-
"\n",
|
| 548 |
-
" embeds = np.vstack(embeds)\n",
|
| 549 |
-
" labels = np.vstack(labels)\n",
|
| 550 |
-
"\n",
|
| 551 |
-
" joblib.dump({\n",
|
| 552 |
-
" \"protein_ids\": prot_ids,\n",
|
| 553 |
-
" \"embeddings\": embeds,\n",
|
| 554 |
-
" \"labels\": labels,\n",
|
| 555 |
-
" \"go_terms\": label_cols\n",
|
| 556 |
-
" }, out_path, compress=3)\n",
|
| 557 |
-
"\n",
|
| 558 |
-
" print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
|
| 559 |
-
"\n",
|
| 560 |
-
"# --- 4. Aplicar -----------------------------------------------------------\n",
|
| 561 |
-
"os.makedirs(OUT_DIR, exist_ok=True)\n",
|
| 562 |
-
"\n",
|
| 563 |
-
"process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbertbfd.pkl\"))\n",
|
| 564 |
-
"process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbertbfd.pkl\"))\n",
|
| 565 |
-
"process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbertbfd.pkl\"))\n"
|
| 566 |
-
]
|
| 567 |
-
},
|
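slice_sequence discards any tail shorter than min_overlap and otherwise re-attaches it behind a 250-residue overlap taken from the last full window; the protein embedding is then the mean of the per-slice CLS vectors. A quick illustration, assuming the function defined above; the poly-M string is only a stand-in sequence:

seq = "M" * 1300  # hypothetical 1300-residue sequence
parts = slice_sequence(seq)  # win=500, min_overlap=250
print([len(p) for p in parts])  # [500, 500, 550]: the 300-residue tail keeps a 250-residue overlap prefix
# a 1230-residue input would instead yield [500, 500], dropping the 230-residue tail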
| 568 |
-
{
|
| 569 |
-
"cell_type": "code",
|
| 570 |
-
"execution_count": 9,
|
| 571 |
-
"id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
|
| 572 |
-
"metadata": {},
|
| 573 |
-
"outputs": [
|
| 574 |
-
{
|
| 575 |
-
"name": "stdout",
|
| 576 |
-
"output_type": "stream",
|
| 577 |
-
"text": [
|
| 578 |
-
"✓ Corrigido: embeddings/train_protbertbfd.pkl — 31142 exemplos, 597 GO terms\n",
|
| 579 |
-
"✓ Corrigido: embeddings/val_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n",
|
| 580 |
-
"✓ Corrigido: embeddings/test_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n"
|
| 581 |
-
]
|
| 582 |
-
}
|
| 583 |
-
],
|
| 584 |
-
"source": [
|
| 585 |
-
"import pandas as pd\n",
|
| 586 |
-
"import joblib\n",
|
| 587 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
|
| 590 |
-
"df_test = pd.read_csv(\"data/mf-test.csv\")\n",
|
| 591 |
-
"test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
|
| 592 |
-
"\n",
|
| 593 |
-
"# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
|
| 594 |
-
"def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
|
| 595 |
-
" df = pd.read_csv(csv_path)\n",
|
| 596 |
-
" terms_split = df[\"go_terms\"].str.split(\";\")\n",
|
| 597 |
-
" \n",
|
| 598 |
-
" # Apenas termos presentes nos common_terms\n",
|
| 599 |
-
" terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
|
| 600 |
-
" \n",
|
| 601 |
-
" mlb = MultiLabelBinarizer(classes=common_terms)\n",
|
| 602 |
-
" Y = mlb.fit_transform(terms_filtered)\n",
|
| 603 |
-
"\n",
|
| 604 |
-
" data = joblib.load(pkl_path)\n",
|
| 605 |
-
" data[\"labels\"] = Y\n",
|
| 606 |
-
" data[\"go_terms\"] = mlb.classes_.tolist()\n",
|
| 607 |
-
" \n",
|
| 608 |
-
" joblib.dump(data, pkl_path, compress=3)\n",
|
| 609 |
-
" print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
|
| 610 |
-
"\n",
|
| 611 |
-
"# --- 3. Aplicar às 3 partições --------------------------------------------\n",
|
| 612 |
-
"patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbertbfd.pkl\", test_terms)\n",
|
| 613 |
-
"patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbertbfd.pkl\", test_terms)\n",
|
| 614 |
-
"patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbertbfd.pkl\", test_terms)\n"
|
| 615 |
-
]
|
| 616 |
-
},
|
| 617 |
-
{
|
| 618 |
-
"cell_type": "code",
|
| 619 |
-
"execution_count": 2,
|
| 620 |
-
"id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
|
| 621 |
-
"metadata": {},
|
| 622 |
-
"outputs": [],
|
| 623 |
-
"source": [
|
| 624 |
-
"import joblib\n",
|
| 625 |
-
"train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
|
| 626 |
-
"val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
|
| 627 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 628 |
-
"\n",
|
| 629 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 630 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 631 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
|
| 632 |
-
]
|
| 633 |
-
},
|
| 634 |
-
{
|
| 635 |
-
"cell_type": "code",
|
| 636 |
-
"execution_count": 5,
|
| 637 |
-
"id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
|
| 638 |
-
"metadata": {},
|
| 639 |
-
"outputs": [
|
| 640 |
-
{
|
| 641 |
-
"name": "stdout",
|
| 642 |
-
"output_type": "stream",
|
| 643 |
-
"text": [
|
| 644 |
-
"✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
|
| 645 |
-
"Epoch 1/100\n",
|
| 646 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0337 - binary_accuracy: 0.9901 - val_loss: 0.0331 - val_binary_accuracy: 0.9905\n",
|
| 647 |
-
"Epoch 2/100\n",
|
| 648 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0252 - binary_accuracy: 0.9921 - val_loss: 0.0326 - val_binary_accuracy: 0.9905\n",
|
| 649 |
-
"Epoch 3/100\n",
|
| 650 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0244 - binary_accuracy: 0.9924 - val_loss: 0.0330 - val_binary_accuracy: 0.9905\n",
|
| 651 |
-
"Epoch 4/100\n",
|
| 652 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0240 - binary_accuracy: 0.9925 - val_loss: 0.0322 - val_binary_accuracy: 0.9907\n",
|
| 653 |
-
"Epoch 5/100\n",
|
| 654 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0236 - binary_accuracy: 0.9925 - val_loss: 0.0328 - val_binary_accuracy: 0.9907\n",
|
| 655 |
-
"Epoch 6/100\n",
|
| 656 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0232 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9908\n",
|
| 657 |
-
"Epoch 7/100\n",
|
| 658 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0231 - binary_accuracy: 0.9926 - val_loss: 0.0325 - val_binary_accuracy: 0.9907\n",
|
| 659 |
-
"Epoch 8/100\n",
|
| 660 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0228 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
|
| 661 |
-
"Epoch 9/100\n",
|
| 662 |
-
"974/974 [==============================] - 3s 3ms/step - loss: 0.0226 - binary_accuracy: 0.9927 - val_loss: 0.0326 - val_binary_accuracy: 0.9908\n",
|
| 663 |
-
"Previsões guardadas em mf-protbertbfd-pam1.npy\n",
|
| 664 |
-
"Modelo guardado em models/protbertbfd_mlp.keras\n"
|
| 665 |
-
]
|
| 666 |
-
}
|
| 667 |
-
],
|
| 668 |
-
"source": [
|
| 669 |
-
"import tensorflow as tf\n",
|
| 670 |
-
"import joblib\n",
|
| 671 |
-
"import numpy as np\n",
|
| 672 |
-
"from tensorflow.keras.models import Sequential\n",
|
| 673 |
-
"from tensorflow.keras.layers import Dense, Dropout\n",
|
| 674 |
-
"from tensorflow.keras.callbacks import EarlyStopping\n",
|
| 675 |
-
"\n",
|
| 676 |
-
"# --- 1. Carregar embeddings ----------------------------------------------\n",
|
| 677 |
-
"train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
|
| 678 |
-
"val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
|
| 679 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 680 |
-
"\n",
|
| 681 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 682 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 683 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
|
| 684 |
-
"\n",
|
| 685 |
-
"print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
|
| 686 |
-
"\n",
|
| 687 |
-
"# --- 2. Garantir consistência de classes ---------------------------------\n",
|
| 688 |
-
"max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
|
| 689 |
-
"\n",
|
| 690 |
-
"def pad_labels(y, target_dim=max_classes):\n",
|
| 691 |
-
" if y.shape[1] < target_dim:\n",
|
| 692 |
-
" padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
|
| 693 |
-
" return np.hstack([y, padding])\n",
|
| 694 |
-
" return y\n",
|
| 695 |
-
"\n",
|
| 696 |
-
"y_val = pad_labels(y_val)\n",
|
| 697 |
-
"y_test = pad_labels(y_test)\n",
|
| 698 |
-
"\n",
|
| 699 |
-
"# --- 3. Modelo MLP ------------------------------------------------------\n",
|
| 700 |
-
"model = Sequential([\n",
|
| 701 |
-
" Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
|
| 702 |
-
" Dropout(0.3),\n",
|
| 703 |
-
" Dense(512, activation=\"relu\"),\n",
|
| 704 |
-
" Dropout(0.3),\n",
|
| 705 |
-
" Dense(max_classes, activation=\"sigmoid\")\n",
|
| 706 |
-
"])\n",
|
| 707 |
-
"\n",
|
| 708 |
-
"model.compile(loss=\"binary_crossentropy\",\n",
|
| 709 |
-
" optimizer=\"adam\",\n",
|
| 710 |
-
" metrics=[\"binary_accuracy\"])\n",
|
| 711 |
-
"\n",
|
| 712 |
-
"# --- 4. Early stopping e treino -----------------------------------------\n",
|
| 713 |
-
"callbacks = [\n",
|
| 714 |
-
" EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
|
| 715 |
-
"]\n",
|
| 716 |
-
"\n",
|
| 717 |
-
"model.fit(X_train, y_train,\n",
|
| 718 |
-
" validation_data=(X_val, y_val),\n",
|
| 719 |
-
" epochs=100,\n",
|
| 720 |
-
" batch_size=32,\n",
|
| 721 |
-
" callbacks=callbacks,\n",
|
| 722 |
-
" verbose=1)\n",
|
| 723 |
-
"\n",
|
| 724 |
-
"# --- 5. Previsões --------------------------------------------------------\n",
|
| 725 |
-
"y_prob = model.predict(X_test)\n",
|
| 726 |
-
"np.save(\"predictions/mf-protbertbfd-pam1.npy\", y_prob)\n",
|
| 727 |
-
"print(\"Previsões guardadas em mf-protbertbfd-pam1.npy\")\n",
|
| 728 |
-
"\n",
|
| 729 |
-
"# --- 6. Modelo ----------------------------------------------------------\n",
|
| 730 |
-
"model.save(\"models/protbertbfd_mlp.keras\")\n",
|
| 731 |
-
"print(\"Modelo guardado em models/protbertbfd_mlp.keras\")"
|
| 732 |
-
]
|
| 733 |
-
},
|
| 734 |
-
{
|
| 735 |
-
"cell_type": "code",
|
| 736 |
-
"execution_count": 12,
|
| 737 |
-
"id": "fdb66630-76dc-43a0-bd56-45052175fdba",
|
| 738 |
-
"metadata": {},
|
| 739 |
-
"outputs": [
|
| 740 |
-
{
|
| 741 |
-
"name": "stdout",
|
| 742 |
-
"output_type": "stream",
|
| 743 |
-
"text": [
|
| 744 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 745 |
-
"✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
|
| 746 |
-
"\n",
|
| 747 |
-
"📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\n",
|
| 748 |
-
"Fmax = 0.6570\n",
|
| 749 |
-
"Thr. = 0.41\n",
|
| 750 |
-
"AuPRC = 0.6929\n",
|
| 751 |
-
"Smin = 13.8114\n"
|
| 752 |
-
]
|
| 753 |
-
}
|
| 754 |
-
],
|
| 755 |
-
"source": [
|
| 756 |
-
"import numpy as np\n",
|
| 757 |
-
"from sklearn.metrics import precision_recall_curve, auc\n",
|
| 758 |
-
"from goatools.obo_parser import GODag\n",
|
| 759 |
-
"import joblib\n",
|
| 760 |
-
"import math\n",
|
| 761 |
-
"\n",
|
| 762 |
-
"# --- 1. Parâmetros -------------------------------------------------------\n",
|
| 763 |
-
"GO_FILE = \"go.obo\"\n",
|
| 764 |
-
"THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
|
| 765 |
-
"ALPHA = 0.5\n",
|
| 766 |
-
"\n",
|
| 767 |
-
"# --- 2. Carregar dados ---------------------------------------------------\n",
|
| 768 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 769 |
-
"y_true = test[\"labels\"]\n",
|
| 770 |
-
"terms = test[\"go_terms\"]\n",
|
| 771 |
-
"y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
|
| 772 |
-
"go_dag = GODag(GO_FILE)\n",
|
| 773 |
-
"\n",
|
| 774 |
-
"print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
|
| 775 |
-
"\n",
|
| 776 |
-
"# --- 3. Fmax -------------------------------------------------------------\n",
|
| 777 |
-
"def compute_fmax(y_true, y_prob, thresholds):\n",
|
| 778 |
-
" fmax, best_thr = 0, 0\n",
|
| 779 |
-
" for t in thresholds:\n",
|
| 780 |
-
" y_pred = (y_prob >= t).astype(int)\n",
|
| 781 |
-
" tp = (y_true * y_pred).sum(axis=1)\n",
|
| 782 |
-
" fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
|
| 783 |
-
" fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
|
| 784 |
-
" precision = tp / (tp + fp + 1e-8)\n",
|
| 785 |
-
" recall = tp / (tp + fn + 1e-8)\n",
|
| 786 |
-
" f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
|
| 787 |
-
" avg_f1 = np.mean(f1)\n",
|
| 788 |
-
" if avg_f1 > fmax:\n",
|
| 789 |
-
" fmax, best_thr = avg_f1, t\n",
|
| 790 |
-
" return fmax, best_thr\n",
|
| 791 |
-
"\n",
|
| 792 |
-
"# --- 4. AuPRC micro ------------------------------------------------------\n",
|
| 793 |
-
"def compute_auprc(y_true, y_prob):\n",
|
| 794 |
-
" precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
|
| 795 |
-
" return auc(recall, precision)\n",
|
| 796 |
-
"\n",
|
| 797 |
-
"# --- 5. Smin -------------------------------------------------------------\n",
|
| 798 |
-
"def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
|
| 799 |
-
" y_pred = (y_prob >= threshold).astype(int)\n",
|
| 800 |
-
" ic = {}\n",
|
| 801 |
-
" total = (y_true + y_pred).sum(axis=0).sum()\n",
|
| 802 |
-
" for i, term in enumerate(terms):\n",
|
| 803 |
-
" freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
|
| 804 |
-
" ic[term] = -np.log((freq + 1e-8) / total)\n",
|
| 805 |
-
"\n",
|
| 806 |
-
" s_values = []\n",
|
| 807 |
-
" for true_vec, pred_vec in zip(y_true, y_pred):\n",
|
| 808 |
-
" true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
|
| 809 |
-
" pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
|
| 810 |
-
"\n",
|
| 811 |
-
" anc_true = set()\n",
|
| 812 |
-
" for t in true_terms:\n",
|
| 813 |
-
" if t in go_dag:\n",
|
| 814 |
-
" anc_true |= go_dag[t].get_all_parents()\n",
|
| 815 |
-
" anc_pred = set()\n",
|
| 816 |
-
" for t in pred_terms:\n",
|
| 817 |
-
" if t in go_dag:\n",
|
| 818 |
-
" anc_pred |= go_dag[t].get_all_parents()\n",
|
| 819 |
-
"\n",
|
| 820 |
-
" ru = pred_terms - true_terms\n",
|
| 821 |
-
" mi = true_terms - pred_terms\n",
|
| 822 |
-
" dist_ru = sum(ic.get(t, 0) for t in ru)\n",
|
| 823 |
-
" dist_mi = sum(ic.get(t, 0) for t in mi)\n",
|
| 824 |
-
" s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
|
| 825 |
-
" s_values.append(s)\n",
|
| 826 |
-
"\n",
|
| 827 |
-
" return np.mean(s_values)\n",
|
| 828 |
-
"\n",
|
| 829 |
-
"# --- 6. Avaliar ----------------------------------------------------------\n",
|
| 830 |
-
"fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
|
| 831 |
-
"auprc = compute_auprc(y_true, y_prob)\n",
|
| 832 |
-
"smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
|
| 833 |
-
"\n",
|
| 834 |
-
"print(f\"\\n📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\")\n",
|
| 835 |
-
"print(f\"Fmax = {fmax:.4f}\")\n",
|
| 836 |
-
"print(f\"Thr. = {thr:.2f}\")\n",
|
| 837 |
-
"print(f\"AuPRC = {auprc:.4f}\")\n",
|
| 838 |
-
"print(f\"Smin = {smin:.4f}\")\n"
|
| 839 |
-
]
|
| 840 |
-
},
|
| 841 |
-
{
|
| 842 |
-
"cell_type": "code",
|
| 843 |
-
"execution_count": null,
|
| 844 |
-
"id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
|
| 845 |
-
"metadata": {},
|
| 846 |
-
"outputs": [],
|
| 847 |
-
"source": []
|
| 848 |
-
}
|
| 849 |
-
],
|
| 850 |
-
"metadata": {
|
| 851 |
-
"kernelspec": {
|
| 852 |
-
"display_name": "Python 3 (ipykernel)",
|
| 853 |
-
"language": "python",
|
| 854 |
-
"name": "python3"
|
| 855 |
-
},
|
| 856 |
-
"language_info": {
|
| 857 |
-
"codemirror_mode": {
|
| 858 |
-
"name": "ipython",
|
| 859 |
-
"version": 3
|
| 860 |
-
},
|
| 861 |
-
"file_extension": ".py",
|
| 862 |
-
"mimetype": "text/x-python",
|
| 863 |
-
"name": "python",
|
| 864 |
-
"nbconvert_exporter": "python",
|
| 865 |
-
"pygments_lexer": "ipython3",
|
| 866 |
-
"version": "3.8.18"
|
| 867 |
-
}
|
| 868 |
-
},
|
| 869 |
-
"nbformat": 4,
|
| 870 |
-
"nbformat_minor": 5
|
| 871 |
-
}
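This deleted cell was the only place the ProtBERT-BFD evaluation metrics (Fmax, micro AuPRC, Smin) were defined, so a minimal, self-contained sketch of the Fmax sweep is kept here for reference. The toy arrays and the RNG seed below are illustrative, not project data; shapes follow the notebook's convention of (n_proteins, n_go_terms).

```python
import numpy as np

def compute_fmax(y_true, y_prob, thresholds):
    # Protein-centric Fmax: best mean per-protein F1 over a threshold sweep.
    fmax, best_thr = 0.0, 0.0
    for t in thresholds:
        y_pred = (y_prob >= t).astype(int)
        tp = (y_true * y_pred).sum(axis=1)
        fp = ((1 - y_true) * y_pred).sum(axis=1)
        fn = (y_true * (1 - y_pred)).sum(axis=1)
        f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)  # algebraically equal to the P/R form
        if f1.mean() > fmax:
            fmax, best_thr = f1.mean(), t
    return fmax, best_thr

# Toy data: 8 proteins × 5 GO terms (illustrative only).
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(8, 5))
y_prob = np.clip(y_true * 0.6 + rng.random((8, 5)) * 0.4, 0.0, 1.0)

fmax, thr = compute_fmax(y_true, y_prob, np.arange(0.0, 1.01, 0.01))
print(f"Fmax = {fmax:.4f} at threshold {thr:.2f}")
```

The returned threshold is the one the deleted cell then reused for Smin, which combines remaining uncertainty and misinformation as s = sqrt((α·ru)² + ((1−α)·mi)²) per protein.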
notebooks/keras_models_fix.ipynb
DELETED
@@ -1,94 +0,0 @@
-{
-"cells": [
-{
-"cell_type": "code",
-"execution_count": 3,
-"id": "39935741-50be-4766-873c-99f3c3f14e55",
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Converting .keras models to .h5...\n",
-"\n",
-"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
-"Saved successfully\n",
-"\n",
-"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
-"Saved successfully\n",
-"\n",
-"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
-"Saved successfully\n",
-"\n",
-"WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
-"Saved successfully\n",
-"\n",
-"Conversion complete.\n"
-]
-}
-],
-"source": [
-"import os\n",
-"from tensorflow.keras.models import load_model\n",
-"\n",
-"MODELS_DIR = \"models\"\n",
-"\n",
-"# Models to convert: (original .keras name → new .h5 name)\n",
-"modelos = [\n",
-"    (\"protbert_mlp.keras\", \"mlp_protbert.h5\"),\n",
-"    (\"protbertbfd_mlp.keras\", \"mlp_protbertbfd.h5\"),\n",
-"    (\"esm2_mlp.keras\", \"mlp_esm2.h5\"),\n",
-"    (\"modelo_ensemble_stacking.keras\", \"modelo_ensemble_stack.h5\"),\n",
-"]\n",
-"\n",
-"print(\"Converting .keras models to .h5...\\n\")\n",
-"\n",
-"for origem, destino in modelos:\n",
-"    origem_path = os.path.join(MODELS_DIR, origem)\n",
-"    destino_path = os.path.join(MODELS_DIR, destino)\n",
-"\n",
-"    if not os.path.exists(origem_path):\n",
-"        print(f\"File not found: {origem_path}\")\n",
-"        continue\n",
-"\n",
-"    model = load_model(origem_path, compile=False)\n",
-"    model.save(destino_path)\n",
-"\n",
-"    print(\"Saved successfully\\n\")\n",
-"\n",
-"print(\"Conversion complete.\")\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "c0f2a4b7-d3eb-48a7-b97e-213b58b2b2ca",
-"metadata": {},
-"outputs": [],
-"source": []
-}
-],
-"metadata": {
-"kernelspec": {
-"display_name": "Python 3 (ipykernel)",
-"language": "python",
-"name": "python3"
-},
-"language_info": {
-"codemirror_mode": {
-"name": "ipython",
-"version": 3
-},
-"file_extension": ".py",
-"mimetype": "text/x-python",
-"name": "python",
-"nbconvert_exporter": "python",
-"pygments_lexer": "ipython3",
-"version": "3.8.18"
-},
-"nbformat": 4,
-"nbformat_minor": 5
-}
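For reference, the conversion pattern this deleted notebook used can be reproduced and sanity-checked in a few lines. This is a sketch, assuming TensorFlow 2.x Keras and a single-input model; `models/example.keras` and `models/example.h5` are hypothetical paths, not files in this repo.

```python
import numpy as np
from tensorflow.keras.models import load_model

SRC = "models/example.keras"  # hypothetical input path
DST = "models/example.h5"     # the .h5 suffix selects the legacy HDF5 format

model = load_model(SRC, compile=False)  # architecture + weights only, no optimizer state
model.save(DST)

# Round-trip check: both formats should give identical predictions.
reloaded = load_model(DST, compile=False)
x = np.random.rand(1, *model.input_shape[1:]).astype("float32")  # assumes one input tensor
assert np.allclose(model.predict(x), reloaded.predict(x), atol=1e-6)
print("Round-trip OK")
```

Loading with `compile=False` is what suppresses the optimizer state and explains the repeated `compile_metrics` warnings recorded in the cell output above.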