ynuozhang committed on
Commit · 04c2975 · 1 Parent(s): d2421cb
major update

Browse files:
- basic_models.txt +9 -9
- best_models.txt +9 -9
- inference.py +634 -547
- training_classifiers/binding_training.py +8 -5
- training_classifiers/long_aggregated.csv +2 -2
- training_classifiers/ml_uncertainty.py +135 -0
- training_classifiers/ml_uncertainty_reg.py +182 -0
- training_classifiers/refit_binding_affinity_seed.py +270 -0
- training_classifiers/refit_ml_walltime.py +209 -0
- training_classifiers/refit_nn_seed.py +315 -0
- training_classifiers/refit_regression_seed.py +318 -0
- training_classifiers/src_bash/binding_refit.bash +56 -0
- training_classifiers/src_bash/ml_uncertainty.bash +192 -0
- training_classifiers/src_bash/nn_uncertainty.bash +51 -0
- training_classifiers/train_ml.py +0 -3
- training_data_cleaned/binding_affinity/binding_affinity_smiles_meta_with_split.csv +2 -2
- training_data_cleaned/binding_affinity/binding_affinity_wt_meta_with_split.csv +2 -2
- training_data_cleaned/binding_affinity_split.py +459 -709
- training_data_cleaned/embed_smiles.py +319 -0
- training_data_cleaned/permeability_penetrance/permeability_smiles_meta_with_split.csv +3 -0
- training_data_cleaned/smiles_data_split.py +2 -148
basic_models.txt CHANGED
@@ -1,10 +1,10 @@
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
- Hemolysis, XGB,
- Non-Fouling,
- Solubility, CNN,
- Permeability (Penetrance), XGB,
- Toxicity, -,
- Binding_affinity,
- Permeability_PAMPA, -, CNN, Regression, -, -,
- Permeability_CACO2, -, SVR, Regression, -, -,
- Halflife, Transformer, XGB, Regression, -, -,
+ Hemolysis, XGB, CNN (chemberta), Classifier, 0.2801, 0.564,
+ Non-Fouling, Transformer, XGB (peptideclm), Classifier, 0.57, 0.3892,
+ Solubility, CNN, Transformer (peptideclm), Classifier, 0.377, 0.329,
+ Permeability (Penetrance), XGB, XGB (chemberta), Classifier, 0.4301, 0.5028,
+ Toxicity, -, CNN (chemberta), Classifier, -, 0.49,
+ Binding_affinity, wt_wt_pooled, chemberta_smiles_pooled, Regression, -, -,
+ Permeability_PAMPA, -, CNN (chemberta), Regression, -, -,
+ Permeability_CACO2, -, SVR (chemberta), Regression, -, -,
+ Halflife, Transformer, XGB (peptideclm), Regression, -, -,
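Each model cell now carries the embedding backbone in parentheses, e.g. `CNN (chemberta)` or `XGB (peptideclm)`. A minimal standalone sketch of how that cell format splits into a canonical model name and an embedding tag, mirroring the `_parse_model_and_emb` helper this commit adds to inference.py (the alias table is the commit's own MODEL_ALIAS, trimmed):

```python
import re
from typing import Optional, Tuple

# Mirrors the _parse_model_and_emb helper added to inference.py below:
# "CNN (chemberta)" -> ("cnn", "chemberta"); a bare "XGB" -> ("xgb", None).
MODEL_ALIAS = {"SVM": "svm_gpu", "SVR": "svr", "ENET": "enet_gpu", "CNN": "cnn",
               "MLP": "mlp", "TRANSFORMER": "transformer", "XGB": "xgb"}

def parse_model_and_emb(raw: str) -> Optional[Tuple[str, Optional[str]]]:
    raw = raw.strip()
    if not raw or raw in {"-", "NA", "N/A"}:
        return None  # "-" cells mean no model for that modality
    m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", raw)
    model_raw, emb_tag = (m.group(1).strip(), m.group(2).strip().lower()) if m else (raw, None)
    return MODEL_ALIAS.get(model_raw.upper(), model_raw.lower()), emb_tag

print(parse_model_and_emb("CNN (chemberta)"))  # ('cnn', 'chemberta')
print(parse_model_and_emb("XGB"))              # ('xgb', None)
```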
best_models.txt CHANGED
@@ -1,10 +1,10 @@
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
- Hemolysis, SVM,
- Non-Fouling,
- Solubility, CNN,
- Permeability (Penetrance), SVM,
- Toxicity, -,
- Binding_affinity,
- Permeability_PAMPA, -, CNN, Regression, -, -,
- Permeability_CACO2, -, SVR, Regression, -, -,
- Halflife, Transformer, XGB, Regression, -, -,
+ Hemolysis, SVM, CNN (chemberta), Classifier, 0.2521, 0.564,
+ Non-Fouling, Transformer, ENET (peptideclm), Classifier, 0.57, 0.6969,
+ Solubility, CNN, Transformer (peptideclm), Classifier, 0.377, 0.329,
+ Permeability (Penetrance), SVM, SVM (chemberta), Classifier, 0.5493, 0.573,
+ Toxicity, -, CNN (chemberta), Classifier, -, 0.49,
+ Binding_affinity, wt_wt_pooled, chemberta_smiles_pooled, Regression, -, -,
+ Permeability_PAMPA, -, CNN (chemberta), Regression, -, -,
+ Permeability_CACO2, -, SVR (chemberta), Regression, -, -,
+ Halflife, Transformer, XGB (peptideclm), Regression, -, -,
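The Threshold_WT and Threshold_SMILES columns are per-modality decision cutoffs: inference.py labels a peptide positive when the classifier score reaches the row's threshold, and regression rows carry "-" and skip the step. A minimal sketch of that rule:

```python
# Hedged sketch: how inference.py turns a classifier score into a label
# using the per-row thresholds above.
def apply_threshold(score: float, threshold: float) -> dict:
    return {"score": score, "threshold": threshold, "label": int(score >= threshold)}

print(apply_threshold(0.61, 0.564))  # hemolysis SMILES head -> label 1
```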
inference.py CHANGED
@@ -1,16 +1,13 @@
  from __future__ import annotations
- 
  import csv, re, json
  from dataclasses import dataclass
  from pathlib import Path
  from typing import Dict, Optional, Tuple, Any, List
- 
  import numpy as np
  import torch
  import torch.nn as nn
  import joblib
  import xgboost as xgb
- 
  from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
  from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
  from lightning.pytorch import seed_everything
@@ -19,13 +16,31 @@ seed_everything(1986)
  # -----------------------------
  # Manifest
  # -----------------------------
+ 
+ EMB_TAG_TO_FOLDER_SUFFIX = {
+     "wt": "wt",
+     "peptideclm": "smiles",
+     "chemberta": "chemberta",
+ }
+ 
+ EMB_TAG_TO_RUNTIME_MODE = {
+     "wt": "wt",
+     "peptideclm": "smiles",
+     "chemberta": "chemberta",
+ }
+ 
+ MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"}
+ DNN_ARCHS = {"mlp", "cnn", "transformer"}
+ XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"}
+ 
+ 
  @dataclass(frozen=True)
  class BestRow:
      property_key: str
-     best_wt: Optional[str]
-     best_smiles: Optional[str]
-     task_type: str
-     thr_wt: Optional[float]
+     best_wt: Optional[Tuple[str, Optional[str]]]
+     best_smiles: Optional[Tuple[str, Optional[str]]]
+     task_type: str
+     thr_wt: Optional[float]
      thr_smiles: Optional[float]
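Assuming the dataclass above, the hemolysis row of best_models.txt parses into a frozen record like this (values taken from the manifest diff earlier; the model/tag tuples come from the alias parsing shown before):

```python
# Hedged sketch: the BestRow that read_best_manifest_csv() would build for
# the hemolysis row of best_models.txt.
row = BestRow(
    property_key="hemolysis",
    best_wt=("svm_gpu", None),         # "SVM" canonicalized, no embedding tag
    best_smiles=("cnn", "chemberta"),  # "CNN (chemberta)"
    task_type="Classifier",
    thr_wt=0.2521,
    thr_smiles=0.564,
)
print(row.best_smiles)  # ('cnn', 'chemberta')
```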
@@ -34,21 +49,16 @@ def _clean(s: str) -> str:
  
  def _none_if_dash(s: str) -> Optional[str]:
      s = _clean(s)
-     if s in {"", "-", "
-         return None
-     return s
+     return None if s in {"", "-", "–", "NA", "N/A"} else s
  
  def _float_or_none(s: str) -> Optional[float]:
      s = _clean(s)
-     if s in {"", "-", "
-         return None
-     return float(s)
+     return None if s in {"", "-", "–", "NA", "N/A"} else float(s)
  
  def normalize_property_key(name: str) -> str:
      n = name.strip().lower()
      n = re.sub(r"\s*\(.*?\)\s*", "", n)
      n = n.replace("-", "_").replace(" ", "_")
- 
      if "permeability" in n and "pampa" not in n and "caco" not in n:
          return "permeability_penetrance"
      if n == "binding_affinity":
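The normalizer maps manifest display names onto on-disk property keys. A quick standalone check of the function as shown:

```python
import re

# Copied from the diff above: display name -> on-disk property key.
def normalize_property_key(name: str) -> str:
    n = name.strip().lower()
    n = re.sub(r"\s*\(.*?\)\s*", "", n)        # drop "(Penetrance)"-style suffixes
    n = n.replace("-", "_").replace(" ", "_")
    if "permeability" in n and "pampa" not in n and "caco" not in n:
        return "permeability_penetrance"
    return n

assert normalize_property_key("Permeability (Penetrance)") == "permeability_penetrance"
assert normalize_property_key("Non-Fouling") == "non_fouling"
```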
@@ -60,11 +70,40 @@ def normalize_property_key(name: str) -> str:
      return n
  
  
+ MODEL_ALIAS = {
+     "SVM": "svm_gpu",
+     "SVR": "svr",
+     "ENET": "enet_gpu",
+     "CNN": "cnn",
+     "MLP": "mlp",
+     "TRANSFORMER": "transformer",
+     "XGB": "xgb",
+     "XGB_REG": "xgb_reg",
+     "POOLED": "pooled",
+     "UNPOOLED": "unpooled",
+     "TRANSFORMER_WT_LOG": "transformer_wt_log",
+ }
+ 
+ def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
+     if raw is None:
+         return None
+     raw = _clean(raw)
+     if not raw or raw in {"-", "–", "NA", "N/A"}:
+         return None
+ 
+     m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", raw)
+     if m:
+         model_raw = m.group(1).strip()
+         emb_tag = m.group(2).strip().lower()
+     else:
+         model_raw = raw
+         emb_tag = None
+ 
+     canon = MODEL_ALIAS.get(model_raw.upper(), model_raw.lower())
+     return canon, emb_tag
+ 
+ 
  def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
-     """
-     Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
-     Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,
-     """
      p = Path(path)
      out: Dict[str, BestRow] = {}
  
@@ -90,10 +129,13 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
              continue
          prop_key = normalize_property_key(prop_raw)
  
+         best_wt = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", "")))
+         best_smiles = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", "")))
+ 
          row = BestRow(
              property_key=prop_key,
-             best_wt=
-             best_smiles=
+             best_wt=best_wt,
+             best_smiles=best_smiles,
              task_type=_clean(rec.get("Type", "Classifier")),
              thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
              thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
  
@@ -103,53 +145,32 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
      return out
  
  
- MODEL_ALIAS = {
-     "SVM": "svm_gpu",
-     "SVR": "svr",
-     "ENET": "enet_gpu",
-     "CNN": "cnn",
-     "MLP": "mlp",
-     "TRANSFORMER": "transformer",
-     "XGB": "xgb",
-     "XGB_REG": "xgb_reg",
-     "POOLED": "pooled",
-     "UNPOOLED": "unpooled",
-     "TRANSFORMER_WT_LOG": "transformer_wt_log",
- }
- 
- def canon_model(label: Optional[str]) -> Optional[str]:
-     if label is None:
-         return None
-     k = label.strip().upper()
-     return MODEL_ALIAS.get(k, label.strip().lower())
- 
- 
  # -----------------------------
  # Generic artifact loading
  # -----------------------------
  def find_best_artifact(model_dir: Path) -> Path:
-     for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]:
+     for pat in ["best_model.json", "best_model.pt", "best_model*.joblib",
+                 "model.json", "model.ubj", "final_model.json"]:
          hits = sorted(model_dir.glob(pat))
          if hits:
              return hits[0]
+     seed_pt = model_dir / "seed_1986" / "model.pt"
+     if seed_pt.exists():
+         return seed_pt
      raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
  
  def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
      art = find_best_artifact(model_dir)
- 
      if art.suffix == ".json":
          booster = xgb.Booster()
-         #print(str(art))
          booster.load_model(str(art))
          return "xgb", booster, art
- 
      if art.suffix == ".joblib":
          obj = joblib.load(art)
          return "joblib", obj, art
- 
      if art.suffix == ".pt":
          ckpt = torch.load(art, map_location=device, weights_only=False)
          return "torch_ckpt", ckpt, art
- 
      raise ValueError(f"Unknown artifact type: {art}")
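`load_artifact` dispatches purely on file suffix: `.json` loads an XGBoost Booster, `.joblib` a pickled sklearn estimator, and `.pt` a torch checkpoint dict. A hedged usage sketch; the directory path below is hypothetical and stands in for the real training_classifiers/ tree:

```python
from pathlib import Path
import torch

# Hypothetical model directory; load_artifact/find_best_artifact are the
# functions defined in the hunk above.
model_dir = Path("training_classifiers/hemolysis/xgb_wt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

kind, obj, art = load_artifact(model_dir, device)
if kind == "xgb":
    print(f"XGBoost booster loaded from {art}")
elif kind == "joblib":
    print(f"sklearn estimator: {type(obj).__name__}")
else:  # "torch_ckpt": a dict holding "best_params" and "state_dict"
    print(f"torch checkpoint keys: {sorted(obj)}")
```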
@@ -157,7 +178,7 @@ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
  # NN architectures
  # -----------------------------
  class MaskedMeanPool(nn.Module):
-     def forward(self, X, M):
+     def forward(self, X, M):
          Mf = M.unsqueeze(-1).float()
          denom = Mf.sum(dim=1).clamp(min=1.0)
          return (X * Mf).sum(dim=1) / denom
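All of the heads below share the masked-mean idiom from `MaskedMeanPool`: zero out padded positions and divide by the count of real tokens only. A standalone numeric check:

```python
import torch

# Standalone check of the masked-mean pooling used by every head below:
# padded positions are zeroed and excluded from the denominator.
X = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (B=1, L=3, H=2)
M = torch.tensor([[True, True, False]])                    # last token is padding

Mf = M.unsqueeze(-1).float()
pooled = (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
print(pooled)  # tensor([[2., 3.]]) -- mean over the two real tokens only
```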
@@ -167,34 +188,25 @@ class MLPHead(nn.Module):
          super().__init__()
          self.pool = MaskedMeanPool()
          self.net = nn.Sequential(
-             nn.Linear(in_dim, hidden),
-             nn.GELU(),
-             nn.Dropout(dropout),
+             nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout),
              nn.Linear(hidden, 1),
          )
      def forward(self, X, M):
-         z = self.pool(X, M)
-         return self.net(z).squeeze(-1)
+         return self.net(self.pool(X, M)).squeeze(-1)
  
  class CNNHead(nn.Module):
      def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
          super().__init__()
-         blocks = []
-         ch = in_ch
+         blocks, ch = [], in_ch
          for _ in range(layers):
-             blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
-                        nn.GELU(),
-                        nn.Dropout(dropout)]
+             blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
              ch = c
          self.conv = nn.Sequential(*blocks)
          self.head = nn.Linear(c, 1)
- 
      def forward(self, X, M):
-         Xc = X.transpose(1, 2)             # (B,C,L)
-         Y = self.conv(Xc).transpose(1, 2)  # (B,L,C)
+         Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
          Mf = M.unsqueeze(-1).float()
-         denom = Mf.sum(dim=1).clamp(min=1.0)
-         pooled = (Y * Mf).sum(dim=1) / denom
+         pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
          return self.head(pooled).squeeze(-1)
  
  class TransformerHead(nn.Module):
  
@@ -207,55 +219,36 @@ class TransformerHead(nn.Module):
          )
          self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
          self.head = nn.Linear(d_model, 1)
- 
      def forward(self, X, M):
-         pad_mask = ~M
-         Z = self.proj(X)
-         Z = self.enc(Z, src_key_padding_mask=pad_mask)
+         Z = self.enc(self.proj(X), src_key_padding_mask=~M)
          Mf = M.unsqueeze(-1).float()
-         denom = Mf.sum(dim=1).clamp(min=1.0)
-         pooled = (Z * Mf).sum(dim=1) / denom
+         pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
          return self.head(pooled).squeeze(-1)
  
  def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
-     if model_name == "mlp":
-         return int(sd["net.0.weight"].shape[1])
-     if model_name == "cnn":
-         return int(sd["conv.0.weight"].shape[1])
-     if model_name == "transformer":
-         return int(sd["proj.weight"].shape[1])
+     if model_name == "mlp": return int(sd["net.0.weight"].shape[1])
+     if model_name == "cnn": return int(sd["conv.0.weight"].shape[1])
+     if model_name == "transformer": return int(sd["proj.weight"].shape[1])
      raise ValueError(model_name)
  
  def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
-     # enc.layers.0.*, enc.layers.1.*, ...
      idxs = set()
      for k in sd.keys():
          if k.startswith(prefix):
-             rest = k[len(prefix):]
-             m = re.match(r"(\d+)\.", rest)
+             m = re.match(r"(\d+)\.", k[len(prefix):])
              if m:
                  idxs.add(int(m.group(1)))
      return (max(idxs) + 1) if idxs else 1
  
  def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
-     """
-     Returns (d_model, layers, ff) inferred from weights.
-     - d_model from proj.weight (shape: [d_model, in_dim])
-     - layers from count of enc.layers.*
-     - ff from enc.layers.0.linear1.weight (shape: [ff, d_model])
-     """
      if "proj.weight" not in sd:
          raise KeyError("Missing proj.weight in state_dict")
      d_model = int(sd["proj.weight"].shape[0])
-     layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
-     if "enc.layers.0.linear1.weight" in sd:
-         ff = int(sd["enc.layers.0.linear1.weight"].shape[0])
-     else:
-         ff = 4 * d_model
+     layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
+     ff = int(sd["enc.layers.0.linear1.weight"].shape[0]) if "enc.layers.0.linear1.weight" in sd else 4 * d_model
      return d_model, layers, ff
  
  def _pick_nhead(d_model: int) -> int:
-     # prefer common head counts; must divide d_model
      for h in (8, 6, 4, 3, 2, 1):
          if d_model % h == 0:
              return h
@@ -263,7 +256,7 @@ def _pick_nhead(d_model: int) -> int:
  
  def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
      params = ckpt["best_params"]
-     sd
+     sd = ckpt["state_dict"]
      in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
      dropout = float(params.get("dropout", 0.1))
  
@@ -273,44 +266,127 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
          model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                          layers=int(params["layers"]), dropout=dropout)
      elif model_name == "transformer":
-         # if transfer-learning ckpt omits arch params, infer from state_dict. special case for transformer_wt_log
          d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
- 
          if d_model is None:
              d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
              nhead_i = _pick_nhead(d_model_i)
              model = TransformerHead(
-                 in_dim=in_dim,
-                 d_model=int(d_model_i),
-                 nhead=int(params.get("nhead", nhead_i)),
-                 layers=int(params.get("layers", layers_i)),
-                 ff=int(params.get("ff", ff_i)),
+                 in_dim=in_dim, d_model=int(d_model_i), nhead=int(params.get("nhead", nhead_i)),
+                 layers=int(params.get("layers", layers_i)), ff=int(params.get("ff", ff_i)),
                  dropout=float(params.get("dropout", dropout)),
              )
          else:
              d_model = int(d_model)
              model = TransformerHead(
-                 in_dim=in_dim,
-                 d_model=d_model,
+                 in_dim=in_dim, d_model=d_model,
                  nhead=int(params.get("nhead", _pick_nhead(d_model))),
                  layers=int(params.get("layers", 2)),
                  ff=int(params.get("ff", 4 * d_model)),
-                 dropout=dropout
+                 dropout=dropout,
              )
      else:
          raise ValueError(f"Unknown NN model_name={model_name}")
  
      model.load_state_dict(sd)
-     model.to(device)
-     model.eval()
+     model.to(device).eval()
      return model
  
  
  # -----------------------------
- #
+ # Wrappers
  # -----------------------------
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
+ 
+ class PassthroughRegressor(BaseEstimator, RegressorMixin):
+     def __init__(self, preds: np.ndarray):
+         self.preds = preds
+     def fit(self, X, y): return self
+     def predict(self, X): return self.preds[:len(X)]
+ 
+ class PassthroughClassifier(BaseEstimator, ClassifierMixin):
+     def __init__(self, preds: np.ndarray):
+         self.preds = preds
+         self.classes_ = np.array([0, 1])
+     def fit(self, X, y): return self
+     def predict(self, X): return (self.preds[:len(X)] >= 0.5).astype(int)
+     def predict_proba(self, X):
+         p = self.preds[:len(X)]
+         return np.stack([1 - p, p], axis=1)
+ 
+ 
+ # -----------------------------
+ # Uncertainty helpers
+ # -----------------------------
+ SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"]
+ 
+ def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
+     ensemble = []
+     for sd_name in SEED_DIRS:
+         pt = model_dir / sd_name / "model.pt"
+         if not pt.exists():
+             continue
+         ckpt = torch.load(pt, map_location=device, weights_only=False)
+         ensemble.append(build_torch_model_from_ckpt(arch, ckpt, device))
+     return ensemble
+ 
+ def _binary_entropy(p: float) -> float:
+     p = float(np.clip(p, 1e-9, 1 - 1e-9))
+     return float(-p * np.log(p) - (1 - p) * np.log(1 - p))
+ 
+ def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
+     probs = []
+     with torch.no_grad():
+         for m in ensemble:
+             logit = m(X, M).squeeze().float().cpu().item()
+             probs.append(1.0 / (1.0 + np.exp(-logit)))
+     return _binary_entropy(float(np.mean(probs)))
+ 
+ def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
+     preds = []
+     with torch.no_grad():
+         for m in ensemble:
+             preds.append(m(X, M).squeeze().float().cpu().item())
+     return float(np.std(preds))
+ 
+ def _mapie_uncertainty(mapie_bundle: dict, score: float,
+                        embedding: Optional[np.ndarray] = None) -> Tuple[float, float]:
+     """
+     Returns (ci_low, ci_high) from a conformal bundle.
+     - adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True}
+       Input-dependent: interval = score +/- q * sigma(embedding)
+     - plain_quantile: {"quantile": q, "alpha": ...}
+       Fixed-width: interval = score +/- q
+     """
+     # Adaptive format is input-dependent interval
+     if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
+         q = float(mapie_bundle["quantile"])
+         if embedding is not None:
+             sigma_model = mapie_bundle["sigma_model"]
+             sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
+             sigma = max(sigma, 1e-6)
+         else:
+             # No embedding available - fall back to fixed interval with sigma=1
+             sigma = 1.0
+         return float(score - q * sigma), float(score + q * sigma)
+ 
+     # Plain quantile format
+     if "quantile" in mapie_bundle:
+         q = float(mapie_bundle["quantile"])
+         return float(score - q), float(score + q)
+ 
+     X_dummy = np.zeros((1, 1))
+     result = mapie.predict(X_dummy)
+     if isinstance(result, tuple):
+         intervals = np.asarray(result[1])
+         if intervals.ndim == 3:
+             return float(intervals[0, 0, 0]), float(intervals[0, 1, 0])
+         return float(intervals[0, 0]), float(intervals[0, 1])
+     raise RuntimeError(
+         f"Cannot extract intervals: unknown MAPIE bundle format. "
+         f"Bundle keys: {list(mapie_bundle.keys())}."
+     )
+ 
  def affinity_to_class(y: float) -> int:
-     # 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)
      if y >= 9.0: return 0
      if y < 7.0: return 2
      return 1
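The new uncertainty helpers quantify disagreement across per-seed checkpoints: classifiers report the entropy of the mean sigmoid probability, regressors the standard deviation of predictions. A standalone numeric sketch of both reductions, using fake per-seed outputs so no checkpoints are needed:

```python
import numpy as np

# One classifier logit per seed -> entropy of the mean probability.
logits = np.array([0.8, 1.1, 0.6, 0.9, 1.0])
probs = 1.0 / (1.0 + np.exp(-logits))
p = float(np.clip(probs.mean(), 1e-9, 1 - 1e-9))
clf_uncertainty = -p * np.log(p) - (1 - p) * np.log(1 - p)

# One regression output per seed -> std measures seed disagreement.
preds = np.array([7.9, 8.3, 8.0, 8.1, 7.8])
reg_uncertainty = preds.std()

print(round(clf_uncertainty, 4), round(float(reg_uncertainty), 4))
```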
@@ -320,38 +396,31 @@ class CrossAttnPooled(nn.Module):
          super().__init__()
          self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
          self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
- 
          self.layers = nn.ModuleList([])
          for _ in range(n_layers):
              self.layers.append(nn.ModuleDict({
                  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
-                 "n1t": nn.LayerNorm(hidden),
-                 "n2t": nn.LayerNorm(hidden),
-                 "n1b": nn.LayerNorm(hidden),
-                 "n2b": nn.LayerNorm(hidden),
+                 "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
+                 "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
              }))
- 
          self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
          self.reg = nn.Linear(hidden, 1)
          self.cls = nn.Linear(hidden, 3)
  
      def forward(self, t_vec, b_vec):
-         t = self.t_proj(t_vec).unsqueeze(0)
-         b = self.b_proj(b_vec).unsqueeze(0)
+         t = self.t_proj(t_vec).unsqueeze(0)
+         b = self.b_proj(b_vec).unsqueeze(0)
          for L in self.layers:
              t_attn, _ = L["attn_tb"](t, b, b)
              t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
              t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
- 
              b_attn, _ = L["attn_bt"](b, t, t)
              b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
              b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
- 
-         z = torch.cat([t[0], b[0]], dim=-1)
-         h = self.shared(z)
+         h = self.shared(torch.cat([t[0], b[0]], dim=-1))
          return self.reg(h).squeeze(-1), self.cls(h)
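A quick shape check of the pooled cross-attention head, assuming the class as reconstructed above. The hidden sizes here are illustrative stand-ins; the real models take the ESM-2 and PeptideCLM widths for Ht and Hb:

```python
import torch

# Shape check for CrossAttnPooled: two pooled vectors in, a regression
# scalar and a 3-way class logit vector out.
model = CrossAttnPooled(Ht=32, Hb=16, hidden=24, n_heads=4, n_layers=2, dropout=0.1).eval()
t_vec = torch.randn(8, 32)   # (B, Ht) pooled target embeddings
b_vec = torch.randn(8, 16)   # (B, Hb) pooled binder embeddings
reg, logits = model(t_vec, b_vec)
print(reg.shape, logits.shape)  # torch.Size([8]) torch.Size([8, 3])
```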
@@ -359,334 +428,247 @@ class CrossAttnUnpooled(nn.Module):
          super().__init__()
          self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
          self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
- 
          self.layers = nn.ModuleList([])
          for _ in range(n_layers):
              self.layers.append(nn.ModuleDict({
                  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
-                 "n1t": nn.LayerNorm(hidden),
-                 "n2t": nn.LayerNorm(hidden),
-                 "n1b": nn.LayerNorm(hidden),
-                 "n2b": nn.LayerNorm(hidden),
+                 "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
+                 "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
              }))
- 
          self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
          self.reg = nn.Linear(hidden, 1)
          self.cls = nn.Linear(hidden, 3)
  
      def _masked_mean(self, X, M):
          Mf = M.unsqueeze(-1).float()
-         denom = Mf.sum(dim=1).clamp(min=1.0)
-         return (X * Mf).sum(dim=1) / denom
+         return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
  
      def forward(self, T, Mt, B, Mb):
-         T = self.t_proj(T)
-         Bx = self.b_proj(B)
-         kp_t = ~Mt
-         kp_b = ~Mb
- 
+         T = self.t_proj(T); Bx = self.b_proj(B)
+         kp_t, kp_b = ~Mt, ~Mb
          for L in self.layers:
              T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
-             T = L["n1t"](T + T_attn)
-             T = L["n2t"](T + L["fft"](T))
- 
+             T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T))
              B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
-             Bx = L["n1b"](Bx + B_attn)
-             Bx = L["n2b"](Bx + L["ffb"](Bx))
- 
-         t_pool = self._masked_mean(T, Mt)
-         b_pool = self._masked_mean(Bx, Mb)
-         z = torch.cat([t_pool, b_pool], dim=-1)
-         h = self.shared(z)
+             Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx))
+         h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1))
          return self.reg(h).squeeze(-1), self.cls(h)
  
  def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
      ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
      params = ckpt["best_params"]
-     sd
- 
-     # infer Ht/Hb from projection weights
+     sd = ckpt["state_dict"]
      Ht = int(sd["t_proj.0.weight"].shape[1])
      Hb = int(sd["b_proj.0.weight"].shape[1])
-     common = dict(
-         n_layers=int(params["n_layers"]),
-         dropout=float(params["dropout"]),
-     )
- 
-     if pooled_or_unpooled == "pooled":
-         model = CrossAttnPooled(**common)
-     elif pooled_or_unpooled == "unpooled":
-         model = CrossAttnUnpooled(**common)
-     else:
-         raise ValueError(pooled_or_unpooled)
- 
+     common = dict(Ht=Ht, Hb=Hb, hidden=int(params["hidden_dim"]),
+                   n_heads=int(params["n_heads"]), n_layers=int(params["n_layers"]),
+                   dropout=float(params["dropout"]))
+     cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled
+     model = cls(**common)
      model.load_state_dict(sd)
-     model.to(device).eval()
-     return model
+     return model.to(device).eval()
  
  
  # -----------------------------
  # Embedding generation
  # -----------------------------
  def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
-     """
-     Pytorch patch
-     """
      if hasattr(torch, "isin"):
          return torch.isin(ids, test_ids)
-     # Fallback: compare against each special id
-     # (B,L,1) == (1,1,K) -> (B,L,K)
      return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
  
  class SMILESEmbedder:
-     """
-     - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
-     - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
-                   plus a 1-mask of length Li (since already filtered).
-     """
-     def __init__(
-         self,
-         device: torch.device,
-         vocab_path: str,
-         splits_path: str,
-         clm_name: str = "aaronfeller/PeptideCLM-23M-all",
-         max_len: int = 512,
-         use_cache: bool = True,
-     ):
+     def __init__(self, device, vocab_path, splits_path,
+                  clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
          self.device = device
          self.max_len = max_len
          self.use_cache = use_cache
- 
          self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
          self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
- 
          self.special_ids = self._get_special_ids(self.tokenizer)
          self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
-                               if self.special_ids else None)
- 
+                               if self.special_ids else None)
          self._cache_pooled: Dict[str, torch.Tensor] = {}
          self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
  
      @staticmethod
      def _get_special_ids(tokenizer) -> List[int]:
-         cand = [
-             getattr(tokenizer, "cls_token_id", None),
-             getattr(tokenizer, "sep_token_id", None),
-             getattr(tokenizer, "bos_token_id", None),
-             getattr(tokenizer, "eos_token_id", None),
-             getattr(tokenizer, "mask_token_id", None),
-         ]
+         cand = [getattr(tokenizer, f"{x}_token_id", None)
+                 for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
          return sorted({int(x) for x in cand if x is not None})
  
-     def _tokenize(self, smiles_list):
-         tok = self.tokenizer(
-             padding=True,
-             truncation=True,
-             max_length=self.max_len,
-         )
-         for k in tok:
-             tok[k] = tok[k].to(self.device)
+     def _tokenize(self, smiles_list):
+         tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
+                              truncation=True, max_length=self.max_len)
+         for k in tok: tok[k] = tok[k].to(self.device)
          if "attention_mask" not in tok:
              tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
          return tok
  
+     def _valid_mask(self, ids, attn):
+         valid = attn.bool()
+         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
+             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+         return valid
+ 
      @torch.no_grad()
      def pooled(self, smiles: str) -> torch.Tensor:
          s = smiles.strip()
-         if self.use_cache and s in self._cache_pooled:
+         if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
          tok = self._tokenize([s])
-         valid = attn.bool()
-         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
-             valid = valid & (~_safe_isin(ids, self.special_ids_t))
+         h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
+         valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
          vf = valid.unsqueeze(-1).float()
-         pooled = summed / denom  # (1,H)
- 
-         if self.use_cache:
-             self._cache_pooled[s] = pooled
+         pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
+         if self.use_cache: self._cache_pooled[s] = pooled
          return pooled
  
      @torch.no_grad()
      def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Returns:
-             X: (1, Li, H) float32 on device
-             M: (1, Li) bool on device
-         where Li excludes padding + special tokens.
-         """
          s = smiles.strip()
-         if self.use_cache and s in self._cache_unpooled:
-             return self._cache_unpooled[s]
- 
          tok = self._tokenize([s])
-         out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
-         h = out.last_hidden_state  # (1,L,H)
- 
-         valid = attn
-         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
-             valid = valid & (~_safe_isin(ids, self.special_ids_t))
- 
-         # filter valid tokens
-         keep = valid[0]    # (L,)
-         X = h[:, keep, :]  # (1,Li,H)
          M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
- 
-         if self.use_cache:
-             self._cache_unpooled[s] = (X, M)
          return X, M
  
  
  class WTEmbedder:
-     """
-     ESM2 embeddings for AA sequences.
-     - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
-     - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
-                   plus a 1-mask of length Li (since already filtered).
-     """
-     def __init__(
-         self,
-         device: torch.device,
-         esm_name: str = "facebook/esm2_t33_650M_UR50D",
-         max_len: int = 1022,
-         use_cache: bool = True,
-     ):
          self.device = device
          self.max_len = max_len
          self.use_cache = use_cache
- 
          self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
          self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
- 
          self.special_ids = self._get_special_ids(self.tokenizer)
          self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
-                               if self.special_ids else None)
- 
          self._cache_pooled: Dict[str, torch.Tensor] = {}
          self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
  
      @staticmethod
      def _get_special_ids(tokenizer) -> List[int]:
-         cand = [
-             getattr(tokenizer, "cls_token_id", None),
-             getattr(tokenizer, "sep_token_id", None),
-             getattr(tokenizer, "bos_token_id", None),
-             getattr(tokenizer, "eos_token_id", None),
-             getattr(tokenizer, "mask_token_id", None),
-         ]
          return sorted({int(x) for x in cand if x is not None})
  
-     def _tokenize(self, seq_list):
-         tok = self.tokenizer(
-             seq_list,
-             return_tensors="pt",
-             padding=True,
-             truncation=True,
-             max_length=self.max_len,
-         )
          tok = {k: v.to(self.device) for k, v in tok.items()}
          if "attention_mask" not in tok:
              tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
          return tok
  
      @torch.no_grad()
      def pooled(self, seq: str) -> torch.Tensor:
          s = seq.strip()
-         if self.use_cache and s in self._cache_pooled:
-             return self._cache_pooled[s]
- 
          tok = self._tokenize([s])
-         out = self.model(**tok)
-         h = out.last_hidden_state  # (1,L,H)
- 
-         valid = attn
-         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
-             valid = valid & (~_safe_isin(ids, self.special_ids_t))
- 
          vf = valid.unsqueeze(-1).float()
-         pooled = summed / denom  # (1,H)
- 
-         if self.use_cache:
-             self._cache_pooled[s] = pooled
          return pooled
  
      @torch.no_grad()
      def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Returns:
-             X: (1, Li, H) float32 on device
-             M: (1, Li) bool on device
-         where Li excludes padding + special tokens.
-         """
          s = seq.strip()
-         if self.use_cache and s in self._cache_unpooled:
-             return self._cache_unpooled[s]
- 
          tok = self._tokenize([s])
-         out = self.model(**tok)
-         h = out.last_hidden_state  # (1,L,H)
- 
-         valid = attn
-         if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
-             valid = valid & (~_safe_isin(ids, self.special_ids_t))
- 
-         keep = valid[0]    # (L,)
-         X = h[:, keep, :]  # (1,Li,H)
          M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
- 
-         if self.use_cache:
-             self._cache_unpooled[s] = (X, M)
          return X, M
  
  
- 
  # -----------------------------
  # Predictor
  # -----------------------------
  class PeptiVersePredictor:
-     """
-     - loads best models from training_classifiers/
-     - computes embeddings as needed (pooled/unpooled)
-     - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
-     """
      def __init__(
          self,
          manifest_path: str | Path,
         classifier_weight_root: str | Path,
          esm_name="facebook/esm2_t33_650M_UR50D",
          clm_name="aaronfeller/PeptideCLM-23M-all",
          smiles_vocab="tokenizer/new_vocab.txt",
          smiles_splits="tokenizer/new_splits.txt",
          device: Optional[str] = None,
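Both embedder classes rewritten in the hunk above expose the same two views: `pooled()` returns a (1, H) masked-mean vector for the classical models, and `unpooled()` returns a (1, Li, H) token matrix plus mask for the NN heads, with padding and special tokens stripped in both cases. A hedged usage sketch before the predictor wiring continues (downloads the ESM-2 checkpoint on first run; 1280 is the published esm2_t33_650M hidden size):

```python
import torch

# Hedged usage sketch of the embedder interface; assumes WTEmbedder keeps
# its previous constructor defaults (esm2_t33_650M, max_len=1022).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wt = WTEmbedder(device)

vec = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ")     # (1, 1280) masked-mean vector
feats = vec.cpu().numpy()                          # features for xgb/sklearn models
X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ")  # (1, Li, 1280) tokens + (1, Li) mask
print(feats.shape, X.shape, M.shape)
```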
|
@@ -697,293 +679,398 @@ class PeptiVersePredictor:
|
|
| 697 |
|
| 698 |
self.manifest = read_best_manifest_csv(manifest_path)
|
| 699 |
|
| 700 |
-
self.wt_embedder
|
| 701 |
-
self.smiles_embedder
|
| 702 |
-
|
| 703 |
-
|
|
|
|
| 704 |
|
| 705 |
-
self.models:
|
| 706 |
-
self.meta:
|
|
|
|
|
|
|
| 707 |
|
| 708 |
self._load_all_best_models()
|
| 709 |
|
| 710 |
-
def
|
| 711 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 712 |
disk_prop = "half_life" if prop_key == "halflife" else prop_key
|
| 713 |
base = self.training_root / disk_prop
|
| 714 |
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
if
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
if d.exists():
|
| 725 |
-
return d
|
| 726 |
-
|
| 727 |
-
if prop_key == "halflife" and model_name == "xgb":
|
| 728 |
-
d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
|
| 729 |
-
if d.exists():
|
| 730 |
-
return d
|
| 731 |
|
| 732 |
candidates = [
|
| 733 |
-
base / f"{model_name}_{
|
| 734 |
base / model_name,
|
| 735 |
]
|
| 736 |
-
if mode == "wt":
|
| 737 |
-
candidates += [base / f"{model_name}_wt"]
|
| 738 |
-
if mode == "smiles":
|
| 739 |
-
candidates += [base / f"{model_name}_smiles"]
|
| 740 |
-
|
| 741 |
for d in candidates:
|
| 742 |
-
if d.exists():
|
| 743 |
-
return d
|
| 744 |
|
| 745 |
raise FileNotFoundError(
|
| 746 |
-
f"Cannot find model
|
| 747 |
)
|
| 748 |
|
| 749 |
-
|
| 750 |
def _load_all_best_models(self):
|
| 751 |
for prop_key, row in self.manifest.items():
|
| 752 |
-
for
|
| 753 |
-
("wt",
|
| 754 |
-
("smiles", row.best_smiles,
|
| 755 |
]:
|
| 756 |
-
|
| 757 |
-
if m is None:
|
| 758 |
continue
|
|
|
|
| 759 |
|
| 760 |
-
#
|
| 761 |
if prop_key == "binding_affinity":
|
| 762 |
-
|
| 763 |
-
pooled_or_unpooled =
|
| 764 |
-
folder = f"wt_{mode}_{pooled_or_unpooled}" # wt_wt_pooled / wt_smiles_unpooled etc.
|
| 765 |
model_dir = self.training_root / "binding_affinity" / folder
|
| 766 |
art = find_best_artifact(model_dir)
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
"
|
| 773 |
-
"
|
| 774 |
-
"
|
| 775 |
-
"
|
|
|
|
| 776 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
continue
|
| 778 |
|
| 779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
kind, obj, art = load_artifact(model_dir, self.device)
|
| 781 |
|
| 782 |
-
if kind
|
| 783 |
-
self.
|
|
|
|
| 784 |
else:
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
- pooled np array shape (1,H) for xgb/joblib
|
| 809 |
-
- unpooled torch tensors (X,M) for NN
|
| 810 |
-
"""
|
| 811 |
-
model = self.models[(prop_key, mode)]
|
| 812 |
-
meta = self.meta[(prop_key, mode)]
|
| 813 |
-
kind = meta.get("kind", None)
|
| 814 |
-
model_name = meta.get("model_name", "")
|
| 815 |
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
if kind == "torch_ckpt":
|
| 821 |
-
|
| 822 |
-
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
""
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
|
| 851 |
if prop_key == "binding_affinity":
|
| 852 |
raise RuntimeError("Use predict_binding_affinity().")
|
| 853 |
|
| 854 |
-
#
|
| 855 |
if kind == "torch_ckpt":
|
| 856 |
-
X, M = self.
|
| 857 |
with torch.no_grad():
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
and mode == "wt"
|
| 864 |
-
and model_name in {"xgb_wt_log", "transformer_wt_log"}
|
| 865 |
-
):
|
| 866 |
-
y = float(np.expm1(y))
|
| 867 |
if task_type == "classifier":
|
| 868 |
-
|
| 869 |
-
out
|
|
|
|
| 870 |
if thr is not None:
|
| 871 |
-
out["label"] = int(
|
| 872 |
-
out["threshold"] = float(thr)
|
| 873 |
-
return out
|
| 874 |
else:
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
| 881 |
-
|
| 882 |
-
|
| 883 |
-
model_name = meta.get("model_name", "")
|
| 884 |
-
if (
|
| 885 |
-
prop_key == "halflife"
|
| 886 |
-
and mode == "wt"
|
| 887 |
-
and model_name in {"xgb_wt_log", "transformer_wt_log"}
|
| 888 |
-
):
|
| 889 |
pred = float(np.expm1(pred))
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
# joblib
|
| 896 |
-
|
| 897 |
-
feats = self.
|
| 898 |
-
# classifier vs regressor behavior differs by estimator
|
| 899 |
if task_type == "classifier":
|
| 900 |
if hasattr(model, "predict_proba"):
|
| 901 |
pred = float(model.predict_proba(feats)[:, 1][0])
|
|
|
|
|
|
|
| 902 |
else:
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
else:
|
| 907 |
-
pred = float(model.predict(feats)[0])
|
| 908 |
-
out = {"property": prop_key, "mode": mode, "score": pred}
|
| 909 |
if thr is not None:
|
| 910 |
-
out["label"] = int(pred >= float(thr))
|
| 911 |
-
out["threshold"] = float(thr)
|
| 912 |
-
return out
|
| 913 |
else:
|
| 914 |
pred = float(model.predict(feats)[0])
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
|
|
|
| 918 |
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
"""
|
| 924 |
-
prop_key = "binding_affinity"
|
| 925 |
-
if (prop_key, mode) not in self.models:
|
| 926 |
-
raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
|
| 927 |
|
| 928 |
-
|
| 929 |
-
pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] # pooled/unpooled
|
| 930 |
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
with torch.no_grad():
|
| 939 |
reg, logits = model(t_vec, b_vec)
|
| 940 |
-
affinity = float(reg.squeeze().cpu().item())
|
| 941 |
-
cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
|
| 942 |
-
cls_thr = affinity_to_class(affinity)
|
| 943 |
else:
|
| 944 |
T, Mt = self.wt_embedder.unpooled(target_seq)
|
| 945 |
-
if
|
| 946 |
-
|
| 947 |
-
|
| 948 |
-
B, Mb = self.smiles_embedder.unpooled(binder_str)
|
| 949 |
with torch.no_grad():
|
| 950 |
reg, logits = model(T, Mt, B, Mb)
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
names
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
"
|
| 959 |
-
"
|
|
|
|
| 960 |
"class_by_threshold": names[cls_thr],
|
| 961 |
-
"class_by_logits":
|
| 962 |
-
"binding_model":
|
| 963 |
}
|
| 964 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 965 |
|
| 966 |
if __name__ == "__main__":
|
| 967 |
-
|
| 968 |
-
manifest_path="basic_models.txt",
|
| 969 |
-
classifier_weight_root="./"
|
| 970 |
-
)
|
| 971 |
-
print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
|
| 972 |
-
print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
|
| 973 |
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
wt = WTEmbedder(device)
|
| 979 |
-
sm = SMILESEmbedder(device,
|
| 980 |
-
vocab_path="./tokeizner/new_vocab.txt",
|
| 981 |
-
splits_path="./tokenizer/new_splits.txt"
|
| 982 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
|
|
|
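The `__main__` remnants above show the intended entry points. A hedged sketch of driving the predictor, assuming the training_classifiers/ checkpoint tree referenced by basic_models.txt is present:

```python
# Hedged sketch based on the __main__ remnants above; paths and checkpoints
# must exist locally for this to run.
predictor = PeptiVersePredictor(
    manifest_path="basic_models.txt",
    classifier_weight_root="./",
)

# Classifier route: score, plus threshold/label when the manifest row has one.
print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))

# Binding affinity route: regression value plus a 3-way class
# (0=High >=9, 1=Moderate 7-9, 2=Low <7, per affinity_to_class).
print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
```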
| 1 |
from __future__ import annotations
|
|
|
|
| 2 |
import csv, re, json
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Dict, Optional, Tuple, Any, List
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
import joblib
|
| 10 |
import xgboost as xgb
|
|
|
|
| 11 |
from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
|
| 12 |
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 13 |
from lightning.pytorch import seed_everything
|
|
|
|
| 16 |
# -----------------------------
|
| 17 |
# Manifest
|
| 18 |
# -----------------------------
|
| 19 |
+
|
| 20 |
+
EMB_TAG_TO_FOLDER_SUFFIX = {
|
| 21 |
+
"wt": "wt",
|
| 22 |
+
"peptideclm": "smiles",
|
| 23 |
+
"chemberta": "chemberta",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
EMB_TAG_TO_RUNTIME_MODE = {
|
| 27 |
+
"wt": "wt",
|
| 28 |
+
"peptideclm": "smiles",
|
| 29 |
+
"chemberta": "chemberta",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"}
|
| 33 |
+
DNN_ARCHS = {"mlp", "cnn", "transformer"}
|
| 34 |
+
XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
@dataclass(frozen=True)
|
| 38 |
class BestRow:
|
| 39 |
property_key: str
|
| 40 |
+
best_wt: Optional[Tuple[str, Optional[str]]]
|
| 41 |
+
best_smiles: Optional[Tuple[str, Optional[str]]]
|
| 42 |
+
task_type: str
|
| 43 |
+
thr_wt: Optional[float]
|
| 44 |
thr_smiles: Optional[float]
|
| 45 |
|
| 46 |
|
|
|
|
| 49 |
|
| 50 |
def _none_if_dash(s: str) -> Optional[str]:
|
| 51 |
s = _clean(s)
|
| 52 |
+
return None if s in {"", "-", "-", "NA", "N/A"} else s
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def _float_or_none(s: str) -> Optional[float]:
|
| 55 |
s = _clean(s)
|
| 56 |
+
return None if s in {"", "-", "-", "NA", "N/A"} else float(s)
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def normalize_property_key(name: str) -> str:
|
| 59 |
n = name.strip().lower()
|
| 60 |
n = re.sub(r"\s*\(.*?\)\s*", "", n)
|
| 61 |
n = n.replace("-", "_").replace(" ", "_")
|
|
|
|
| 62 |
if "permeability" in n and "pampa" not in n and "caco" not in n:
|
| 63 |
return "permeability_penetrance"
|
| 64 |
if n == "binding_affinity":
|
|
|
|
| 70 |
return n
|
| 71 |
|
| 72 |
|
| 73 |
+
MODEL_ALIAS = {
|
| 74 |
+
"SVM": "svm_gpu",
|
| 75 |
+
"SVR": "svr",
|
| 76 |
+
"ENET": "enet_gpu",
|
| 77 |
+
"CNN": "cnn",
|
| 78 |
+
"MLP": "mlp",
|
| 79 |
+
"TRANSFORMER": "transformer",
|
| 80 |
+
"XGB": "xgb",
|
| 81 |
+
"XGB_REG": "xgb_reg",
|
| 82 |
+
"POOLED": "pooled",
|
| 83 |
+
"UNPOOLED": "unpooled",
|
| 84 |
+
"TRANSFORMER_WT_LOG": "transformer_wt_log",
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
|
| 88 |
+
if raw is None:
|
| 89 |
+
return None
|
| 90 |
+
raw = _clean(raw)
|
| 91 |
+
if not raw or raw in {"-", "-", "NA", "N/A"}:
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", raw)
|
| 95 |
+
if m:
|
| 96 |
+
model_raw = m.group(1).strip()
|
| 97 |
+
emb_tag = m.group(2).strip().lower()
|
| 98 |
+
else:
|
| 99 |
+
model_raw = raw
|
| 100 |
+
emb_tag = None
|
| 101 |
+
|
| 102 |
+
canon = MODEL_ALIAS.get(model_raw.upper(), model_raw.lower())
|
| 103 |
+
return canon, emb_tag
|
| 104 |
+
|
| 105 |
+
|
| 106 |
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
p = Path(path)
|
| 108 |
out: Dict[str, BestRow] = {}
|
| 109 |
|
|
|
|
| 129 |
continue
|
| 130 |
prop_key = normalize_property_key(prop_raw)
|
| 131 |
|
| 132 |
+
best_wt = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", "")))
|
| 133 |
+
best_smiles = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", "")))
|
| 134 |
+
|
| 135 |
row = BestRow(
|
| 136 |
property_key=prop_key,
|
| 137 |
+
best_wt=best_wt,
|
| 138 |
+
best_smiles=best_smiles,
|
| 139 |
task_type=_clean(rec.get("Type", "Classifier")),
|
| 140 |
thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
|
| 141 |
thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
|
|
|
|
| 145 |
return out
|
| 146 |
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
# -----------------------------
|
| 149 |
# Generic artifact loading
|
| 150 |
# -----------------------------
|
| 151 |
def find_best_artifact(model_dir: Path) -> Path:
|
| 152 |
+
for pat in ["best_model.json", "best_model.pt", "best_model*.joblib",
|
| 153 |
+
"model.json", "model.ubj", "final_model.json"]:
|
| 154 |
hits = sorted(model_dir.glob(pat))
|
| 155 |
if hits:
|
| 156 |
return hits[0]
|
| 157 |
+
seed_pt = model_dir / "seed_1986" / "model.pt"
|
| 158 |
+
if seed_pt.exists():
|
| 159 |
+
return seed_pt
|
| 160 |
raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
|
| 161 |
|
| 162 |
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
|
| 163 |
art = find_best_artifact(model_dir)
|
|
|
|
| 164 |
if art.suffix == ".json":
|
| 165 |
booster = xgb.Booster()
|
|
|
|
| 166 |
booster.load_model(str(art))
|
| 167 |
return "xgb", booster, art
|
|
|
|
| 168 |
if art.suffix == ".joblib":
|
| 169 |
obj = joblib.load(art)
|
| 170 |
return "joblib", obj, art
|
|
|
|
| 171 |
if art.suffix == ".pt":
|
| 172 |
ckpt = torch.load(art, map_location=device, weights_only=False)
|
| 173 |
return "torch_ckpt", ckpt, art
|
|
|
|
| 174 |
raise ValueError(f"Unknown artifact type: {art}")
|
| 175 |
|
| 176 |
|
|
|
|
| 178 |
# NN architectures
|
| 179 |
# -----------------------------
|
| 180 |
class MaskedMeanPool(nn.Module):
|
| 181 |
+
def forward(self, X, M):
|
| 182 |
Mf = M.unsqueeze(-1).float()
|
| 183 |
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 184 |
return (X * Mf).sum(dim=1) / denom
|
|
|
|
| 188 |
super().__init__()
|
| 189 |
self.pool = MaskedMeanPool()
|
| 190 |
self.net = nn.Sequential(
|
| 191 |
+
nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout),
|
|
|
|
|
|
|
| 192 |
nn.Linear(hidden, 1),
|
| 193 |
)
|
| 194 |
def forward(self, X, M):
|
| 195 |
+
return self.net(self.pool(X, M)).squeeze(-1)
|
|
|
|
| 196 |
|
| 197 |
class CNNHead(nn.Module):
|
| 198 |
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
|
| 199 |
super().__init__()
|
| 200 |
+
blocks, ch = [], in_ch
|
|
|
|
| 201 |
for _ in range(layers):
|
| 202 |
+
blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
|
|
|
|
|
|
|
| 203 |
ch = c
|
| 204 |
self.conv = nn.Sequential(*blocks)
|
| 205 |
self.head = nn.Linear(c, 1)
|
|
|
|
| 206 |
def forward(self, X, M):
|
| 207 |
+
Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
|
|
|
|
| 208 |
Mf = M.unsqueeze(-1).float()
|
| 209 |
+
pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
|
|
|
|
| 210 |
return self.head(pooled).squeeze(-1)
|
| 211 |
|
| 212 |
class TransformerHead(nn.Module):
|
|
|
|
| 219 |
)
|
| 220 |
self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
|
| 221 |
self.head = nn.Linear(d_model, 1)
|
|
|
|
| 222 |
def forward(self, X, M):
|
| 223 |
+
Z = self.enc(self.proj(X), src_key_padding_mask=~M)
|
|
|
|
|
|
|
| 224 |
Mf = M.unsqueeze(-1).float()
|
| 225 |
+
pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
|
|
|
|
| 226 |
return self.head(pooled).squeeze(-1)
|
| 227 |
|
| 228 |
def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
|
| 229 |
+
if model_name == "mlp": return int(sd["net.0.weight"].shape[1])
|
| 230 |
+
if model_name == "cnn": return int(sd["conv.0.weight"].shape[1])
|
| 231 |
+
if model_name == "transformer": return int(sd["proj.weight"].shape[1])
|
|
|
|
|
|
|
|
|
|
| 232 |
raise ValueError(model_name)
|
| 233 |
|
| 234 |
def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
|
|
|
|
| 235 |
idxs = set()
|
| 236 |
for k in sd.keys():
|
| 237 |
if k.startswith(prefix):
|
| 238 |
+
m = re.match(r"(\d+)\.", k[len(prefix):])
|
|
|
|
| 239 |
if m:
|
| 240 |
idxs.add(int(m.group(1)))
|
| 241 |
return (max(idxs) + 1) if idxs else 1
|
| 242 |
|
| 243 |
def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
if "proj.weight" not in sd:
|
| 245 |
+
raise KeyError("Missing proj.weight in state_dict")
|
| 246 |
d_model = int(sd["proj.weight"].shape[0])
|
| 247 |
+
layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
|
| 248 |
+
ff = int(sd["enc.layers.0.linear1.weight"].shape[0]) if "enc.layers.0.linear1.weight" in sd else 4 * d_model
|
|
|
|
|
|
|
|
|
|
| 249 |
return d_model, layers, ff
|
| 250 |
|
| 251 |
def _pick_nhead(d_model: int) -> int:
|
|
|
|
| 252 |
for h in (8, 6, 4, 3, 2, 1):
|
| 253 |
if d_model % h == 0:
|
| 254 |
return h
|
|
|
|
| 256 |
|
| 257 |
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
|
| 258 |
params = ckpt["best_params"]
|
| 259 |
+
sd = ckpt["state_dict"]
|
| 260 |
in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
|
| 261 |
dropout = float(params.get("dropout", 0.1))
|
| 262 |
|
|
|
|
| 266 |
model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
|
| 267 |
layers=int(params["layers"]), dropout=dropout)
|
| 268 |
elif model_name == "transformer":
|
|
|
|
| 269 |
d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
|
|
|
|
| 270 |
if d_model is None:
|
| 271 |
d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
|
| 272 |
nhead_i = _pick_nhead(d_model_i)
|
| 273 |
model = TransformerHead(
|
| 274 |
+
in_dim=in_dim, d_model=int(d_model_i), nhead=int(params.get("nhead", nhead_i)),
|
| 275 |
+
layers=int(params.get("layers", layers_i)), ff=int(params.get("ff", ff_i)),
|
|
|
|
|
|
|
|
|
|
| 276 |
dropout=float(params.get("dropout", dropout)),
|
| 277 |
)
|
| 278 |
else:
|
| 279 |
d_model = int(d_model)
|
| 280 |
model = TransformerHead(
|
| 281 |
+
in_dim=in_dim, d_model=d_model,
|
|
|
|
| 282 |
nhead=int(params.get("nhead", _pick_nhead(d_model))),
|
| 283 |
layers=int(params.get("layers", 2)),
|
| 284 |
ff=int(params.get("ff", 4 * d_model)),
|
| 285 |
+
dropout=dropout,
|
| 286 |
)
|
| 287 |
else:
|
| 288 |
raise ValueError(f"Unknown NN model_name={model_name}")
|
| 289 |
|
| 290 |
model.load_state_dict(sd)
|
| 291 |
+
model.to(device).eval()
|
|
|
|
| 292 |
return model
|
| 293 |
|
| 294 |
|
| 295 |
# -----------------------------
|
| 296 |
+
# Wrappers
|
| 297 |
+
# -----------------------------
|
| 298 |
+
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
|
| 299 |
+
|
| 300 |
+
class PassthroughRegressor(BaseEstimator, RegressorMixin):
|
| 301 |
+
def __init__(self, preds: np.ndarray):
|
| 302 |
+
self.preds = preds
|
| 303 |
+
def fit(self, X, y): return self
|
| 304 |
+
def predict(self, X): return self.preds[:len(X)]
|
| 305 |
+
|
| 306 |
+
class PassthroughClassifier(BaseEstimator, ClassifierMixin):
|
| 307 |
+
def __init__(self, preds: np.ndarray):
|
| 308 |
+
self.preds = preds
|
| 309 |
+
self.classes_ = np.array([0, 1])
|
| 310 |
+
def fit(self, X, y): return self
|
| 311 |
+
def predict(self, X): return (self.preds[:len(X)] >= 0.5).astype(int)
|
| 312 |
+
def predict_proba(self, X):
|
| 313 |
+
p = self.preds[:len(X)]
|
| 314 |
+
return np.stack([1 - p, p], axis=1)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
# -----------------------------
|
| 318 |
+
# Uncertainty helpers
|
| 319 |
+
# -----------------------------
|
| 320 |
+
SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"]
|
| 321 |
+
|
| 322 |
+
def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
|
| 323 |
+
ensemble = []
|
| 324 |
+
for sd_name in SEED_DIRS:
|
| 325 |
+
pt = model_dir / sd_name / "model.pt"
|
| 326 |
+
if not pt.exists():
|
| 327 |
+
continue
|
| 328 |
+
ckpt = torch.load(pt, map_location=device, weights_only=False)
|
| 329 |
+
ensemble.append(build_torch_model_from_ckpt(arch, ckpt, device))
|
| 330 |
+
return ensemble
|
| 331 |
+
|
| 332 |
+
def _binary_entropy(p: float) -> float:
|
| 333 |
+
p = float(np.clip(p, 1e-9, 1 - 1e-9))
|
| 334 |
+
return float(-p * np.log(p) - (1 - p) * np.log(1 - p))
|
| 335 |
+
|
| 336 |
+
def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
|
| 337 |
+
probs = []
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
for m in ensemble:
|
| 340 |
+
logit = m(X, M).squeeze().float().cpu().item()
|
| 341 |
+
probs.append(1.0 / (1.0 + np.exp(-logit)))
|
| 342 |
+
return _binary_entropy(float(np.mean(probs)))
|
| 343 |
+
|
| 344 |
+
def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
|
| 345 |
+
preds = []
|
| 346 |
+
with torch.no_grad():
|
| 347 |
+
for m in ensemble:
|
| 348 |
+
preds.append(m(X, M).squeeze().float().cpu().item())
|
| 349 |
+
return float(np.std(preds))
|
| 350 |
+
|
| 351 |
+
def _mapie_uncertainty(mapie_bundle: dict, score: float,
|
| 352 |
+
embedding: Optional[np.ndarray] = None) -> Tuple[float, float]:
|
| 353 |
+
"""
|
| 354 |
+
Returns (ci_low, ci_high) from a conformal bundle.
|
| 355 |
+
- adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True}
|
| 356 |
+
Input-dependent: interval = score +/- q * sigma(embedding)
|
| 357 |
+
- plain_quantile: {"quantile": q, "alpha": ...}
|
| 358 |
+
Fixed-width: interval = score +/- q
|
| 359 |
+
"""
|
| 360 |
+
# Adaptive format is input-dependent interval
|
| 361 |
+
if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
|
| 362 |
+
q = float(mapie_bundle["quantile"])
|
| 363 |
+
if embedding is not None:
|
| 364 |
+
sigma_model = mapie_bundle["sigma_model"]
|
| 365 |
+
sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
|
| 366 |
+
sigma = max(sigma, 1e-6)
|
| 367 |
+
else:
|
| 368 |
+
# No embedding available - fall back to fixed interval with sigma=1
|
| 369 |
+
sigma = 1.0
|
| 370 |
+
return float(score - q * sigma), float(score + q * sigma)
|
| 371 |
+
|
| 372 |
+
# Plain quantile format
|
| 373 |
+
if "quantile" in mapie_bundle:
|
| 374 |
+
q = float(mapie_bundle["quantile"])
|
| 375 |
+
return float(score - q), float(score + q)
|
| 376 |
+
|
| 377 |
+
X_dummy = np.zeros((1, 1))
|
| 378 |
+
result = mapie.predict(X_dummy)
|
| 379 |
+
if isinstance(result, tuple):
|
| 380 |
+
intervals = np.asarray(result[1])
|
| 381 |
+
if intervals.ndim == 3:
|
| 382 |
+
return float(intervals[0, 0, 0]), float(intervals[0, 1, 0])
|
| 383 |
+
return float(intervals[0, 0]), float(intervals[0, 1])
|
| 384 |
+
raise RuntimeError(
|
| 385 |
+
f"Cannot extract intervals: unknown MAPIE bundle format. "
|
| 386 |
+
f"Bundle keys: {list(mapie_bundle.keys())}."
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
def affinity_to_class(y: float) -> int:
|
|
|
|
| 390 |
if y >= 9.0: return 0
|
| 391 |
if y < 7.0: return 2
|
| 392 |
return 1
|
|
|
|
| 396 |
super().__init__()
|
| 397 |
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 398 |
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
|
|
|
| 399 |
self.layers = nn.ModuleList([])
|
| 400 |
for _ in range(n_layers):
|
| 401 |
self.layers.append(nn.ModuleDict({
|
| 402 |
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 403 |
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 404 |
+
"n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
|
| 405 |
+
"n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
|
|
|
|
|
|
|
| 406 |
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 407 |
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 408 |
}))
|
|
|
|
| 409 |
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 410 |
self.reg = nn.Linear(hidden, 1)
|
| 411 |
self.cls = nn.Linear(hidden, 3)
|
| 412 |
|
| 413 |
def forward(self, t_vec, b_vec):
|
| 414 |
+
t = self.t_proj(t_vec).unsqueeze(0)
|
| 415 |
+
b = self.b_proj(b_vec).unsqueeze(0)
|
| 416 |
for L in self.layers:
|
| 417 |
t_attn, _ = L["attn_tb"](t, b, b)
|
| 418 |
t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
|
| 419 |
t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
|
|
|
|
| 420 |
b_attn, _ = L["attn_bt"](b, t, t)
|
| 421 |
b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
|
| 422 |
b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
|
| 423 |
+
h = self.shared(torch.cat([t[0], b[0]], dim=-1))
|
|
|
|
|
|
|
| 424 |
return self.reg(h).squeeze(-1), self.cls(h)
|
| 425 |
|
| 426 |
class CrossAttnUnpooled(nn.Module):
|
|
|
|
| 428 |
super().__init__()
|
| 429 |
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 430 |
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
|
|
|
| 431 |
self.layers = nn.ModuleList([])
|
| 432 |
for _ in range(n_layers):
|
| 433 |
self.layers.append(nn.ModuleDict({
|
| 434 |
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 435 |
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 436 |
+
"n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
|
| 437 |
+
"n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
|
|
|
|
|
|
|
| 438 |
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 439 |
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 440 |
}))
|
|
|
|
| 441 |
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 442 |
self.reg = nn.Linear(hidden, 1)
|
| 443 |
self.cls = nn.Linear(hidden, 3)
|
| 444 |
|
| 445 |
def _masked_mean(self, X, M):
|
| 446 |
Mf = M.unsqueeze(-1).float()
|
| 447 |
+
return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
|
|
|
|
| 448 |
|
| 449 |
def forward(self, T, Mt, B, Mb):
|
| 450 |
+
T = self.t_proj(T); Bx = self.b_proj(B)
|
| 451 |
+
kp_t, kp_b = ~Mt, ~Mb
|
|
|
|
|
|
|
|
|
|
| 452 |
for L in self.layers:
|
| 453 |
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
|
| 454 |
+
T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T))
|
|
|
|
|
|
|
| 455 |
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
|
| 456 |
+
Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx))
|
| 457 |
+
h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
return self.reg(h).squeeze(-1), self.cls(h)

def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    Ht = int(sd["t_proj.0.weight"].shape[1])
    Hb = int(sd["b_proj.0.weight"].shape[1])
    common = dict(Ht=Ht, Hb=Hb, hidden=int(params["hidden_dim"]),
                  n_heads=int(params["n_heads"]), n_layers=int(params["n_layers"]),
                  dropout=float(params["dropout"]))
    cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled
    model = cls(**common)
    model.load_state_dict(sd)
    return model.to(device).eval()
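
# A minimal usage sketch for load_binding_model (the checkpoint path is hypothetical;
# it assumes an artifact written by the binding training/refit scripts, and the
# 1280-dim inputs assume a wt/wt checkpoint built on pooled ESM-2 650M embeddings):
#
#   model = load_binding_model(Path("binding_affinity/wt_wt_pooled/best_model.pt"),
#                              "pooled", torch.device("cpu"))
#   t_vec, b_vec = torch.randn(1, 1280), torch.randn(1, 1280)
#   reg, logits = model(t_vec, b_vec)   # affinity estimate + 3-class logits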

# -----------------------------
# Embedding generation
# -----------------------------
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
    if hasattr(torch, "isin"):
        return torch.isin(ids, test_ids)
    return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)

class SMILESEmbedder:
    def __init__(self, device, vocab_path, splits_path,
                 clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, smiles_list):
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for k in tok: tok[k] = tok[k].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M
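
    # pooled() above is a masked mean over non-special tokens. The same arithmetic
    # on a toy tensor (illustration only, not repo code):
    #
    #   h = torch.arange(8, dtype=torch.float32).view(1, 4, 2)   # (batch, tokens, hidden)
    #   valid = torch.tensor([[False, True, True, False]])       # e.g. CLS/SEP masked out
    #   vf = valid.unsqueeze(-1).float()
    #   pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
    #   # pooled averages token rows 1 and 2 only -> tensor([[3., 4.]])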

class ChemBERTaEmbedder:
    def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM",
                 max_len=512, use_cache=True):
        from transformers import AutoTokenizer, AutoModel
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, smiles_list):
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for k in tok: tok[k] = tok[k].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M

class WTEmbedder:
    def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, seq_list):
        tok = self.tokenizer(seq_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        tok = {k: v.to(self.device) for k, v in tok.items()}
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        s = seq.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        s = seq.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M

# -----------------------------
# Predictor
# -----------------------------

class PeptiVersePredictor:
    def __init__(
        self,
        manifest_path: str | Path,
        classifier_weight_root: str | Path,
        esm_name="facebook/esm2_t33_650M_UR50D",
        clm_name="aaronfeller/PeptideCLM-23M-all",
        chemberta_name="DeepChem/ChemBERTa-77M-MLM",
        smiles_vocab="tokenizer/new_vocab.txt",
        smiles_splits="tokenizer/new_splits.txt",
        device: Optional[str] = None,
        # ... a few unchanged lines are collapsed in the diff here: the signature
        # closes, and self.device / self.root / self.training_root are presumably
        # initialized from `device` and `classifier_weight_root`.

        self.manifest = read_best_manifest_csv(manifest_path)

        self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
        self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
                                              vocab_path=str(self.root / smiles_vocab),
                                              splits_path=str(self.root / smiles_splits))
        self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name)

        self.models: Dict[Tuple[str, str], Any] = {}
        self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
        self.mapie: Dict[Tuple[str, str], dict] = {}
        self.ensembles: Dict[Tuple[str, str], List] = {}

        self._load_all_best_models()

    def _get_embedder(self, emb_tag: str):
        if emb_tag == "wt": return self.wt_embedder
        if emb_tag == "peptideclm": return self.smiles_embedder
        if emb_tag == "chemberta": return self.chemberta_embedder
        raise ValueError(f"Unknown emb_tag={emb_tag!r}")

    def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray:
        v = self._get_embedder(emb_tag).pooled(input_str)
        feats = v.detach().cpu().numpy().astype(np.float32)
        feats = np.nan_to_num(feats, nan=0.0)
        return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)

    def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._get_embedder(emb_tag).unpooled(input_str)

    def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path:
        disk_prop = "half_life" if prop_key == "halflife" else prop_key
        base = self.training_root / disk_prop

        folder_suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag)

        if prop_key == "halflife" and emb_tag == "wt":
            if model_name == "transformer":
                for d in [base / "transformer_wt_log", base / "transformer_wt"]:
                    if d.exists(): return d
            if model_name in {"xgb", "xgb_reg"}:
                d = base / "xgb_wt_log"
                if d.exists(): return d

        candidates = [
            base / f"{model_name}_{folder_suffix}",
            base / model_name,
        ]
        for d in candidates:
            if d.exists(): return d

        raise FileNotFoundError(
            f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}"
        )

    def _load_all_best_models(self):
        for prop_key, row in self.manifest.items():
            for col, parsed, thr in [
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ]:
                if parsed is None:
                    continue
                model_name, emb_tag = parsed

                # binding affinity
                if prop_key == "binding_affinity":
                    folder = model_name
                    pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled"
                    model_dir = self.training_root / "binding_affinity" / folder
                    art = find_best_artifact(model_dir)
                    model = load_binding_model(art, pooled_or_unpooled, self.device)
                    self.models[(prop_key, col)] = model
                    self.meta[(prop_key, col)] = {
                        "task_type": "Regression",
                        "threshold": None,
                        "artifact": str(art),
                        "model_name": pooled_or_unpooled,
                        "emb_tag": emb_tag,
                        "folder": folder,
                        "kind": "binding",
                    }
                    print(f" [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}")
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f" MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f" No MAPIE bundle found (uncertainty will be unavailable)")
                    continue

                # infer emb_tag
                if emb_tag is None:
                    emb_tag = col

                model_dir = self._resolve_dir(prop_key, model_name, emb_tag)
                kind, obj, art = load_artifact(model_dir, self.device)

                if kind == "torch_ckpt":
                    arch = self._base_arch(model_name)
                    model = build_torch_model_from_ckpt(arch, obj, self.device)
                else:
                    model = obj

                self.models[(prop_key, col)] = model
                self.meta[(prop_key, col)] = {
                    "task_type": row.task_type,
                    "threshold": thr,
                    "artifact": str(art),
                    "model_name": model_name,
                    "emb_tag": emb_tag,
                    "kind": kind,
                }

                print(f" [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}")

                # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt
                is_regression = row.task_type.lower() == "regression"
                wants_mapie = (
                    (model_name in MAPIE_REGRESSION_MODELS and is_regression)
                    or (kind == "xgb" and is_regression)
                    or (kind == "torch_ckpt" and is_regression)
                )
                if wants_mapie:
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f" MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f" No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)")

                # Seed ensembles: DNN only, used when MAPIE not available
                if kind == "torch_ckpt":
                    arch = self._base_arch(model_name)
                    ens = load_seed_ensemble(model_dir, arch, self.device)
                    if ens:
                        self.ensembles[(prop_key, col)] = ens
                        if (prop_key, col) in self.mapie:
                            print(f" Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)")
                        else:
                            unc_type = "ensemble_predictive_entropy" if row.task_type.lower() == "classifier" else "ensemble_std"
                            print(f" Seed ensemble: {len(ens)} seeds loaded; uncertainty method: {unc_type}")
                    else:
                        if (prop_key, col) in self.mapie:
                            print(f" No seed ensemble (MAPIE covers uncertainty)")
                        else:
                            print(f" No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable")

                # XGBoost/SVM classifiers: binary entropy
                if kind in ("xgb", "joblib") and row.task_type.lower() == "classifier":
                    print(f" Uncertainty method: binary_predictive_entropy (computed at inference)")

    @staticmethod
    def _base_arch(model_name: str) -> str:
        if model_name.startswith("transformer"): return "transformer"
        if model_name.startswith("mlp"): return "mlp"
        if model_name.startswith("cnn"): return "cnn"
        return model_name

    # Feature extraction
    def _get_features(self, prop_key: str, col: str, input_str: str):
        meta = self.meta[(prop_key, col)]
        emb_tag = meta["emb_tag"]
        kind = meta["kind"]
        if kind == "torch_ckpt":
            return self._embed_unpooled(emb_tag, input_str)
        return self._embed_pooled(emb_tag, input_str)

    # Uncertainty
    def _compute_uncertainty(self, prop_key: str, col: str, input_str: str,
                             score: float) -> Tuple[Any, str]:
        meta = self.meta[(prop_key, col)]
        kind = meta["kind"]
        model_name = meta["model_name"]
        task_type = meta["task_type"].lower()
        emb_tag = meta["emb_tag"]

        # Pooled embedding for adaptive MAPIE sigma model
        def get_pooled_emb():
            return self._embed_pooled(emb_tag, input_str) if emb_tag else None

        # DNN
        if kind == "torch_ckpt":
            # Regression: prefer MAPIE if available
            if task_type == "regression":
                mapie_bundle = self.mapie.get((prop_key, col))
                if mapie_bundle:
                    emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                    lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                    return (lo, hi), "conformal_prediction_interval"
                # Fall back to seed ensemble std
                ens = self.ensembles.get((prop_key, col))
                if ens:
                    X, M = self._embed_unpooled(emb_tag, input_str)
                    return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std"
                return None, "unavailable (no MAPIE bundle and no seed ensemble)"
            # Classifier: ensemble predictive entropy
            ens = self.ensembles.get((prop_key, col))
            if not ens:
                return None, "unavailable (no seed ensemble found)"
            X, M = self._embed_unpooled(emb_tag, input_str)
            return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy"

        # XGBoost
        if kind == "xgb":
            if task_type == "classifier":
                return _binary_entropy(score), "binary_predictive_entropy"
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                return (lo, hi), "conformal_prediction_interval"
            return None, "unavailable (no MAPIE bundle for XGBoost regression)"

        # SVR / ElasticNet regression: MAPIE
        if kind == "joblib" and model_name in MAPIE_REGRESSION_MODELS and task_type == "regression":
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                return (lo, hi), "conformal_prediction_interval"
            return None, "unavailable (MAPIE bundle not found)"

        # joblib classifiers (SVM, ElasticNet used as classifier)
        if kind == "joblib" and task_type == "classifier":
            return _binary_entropy(score), "binary_predictive_entropy_single_model"

        return None, "unavailable"
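
    # `_binary_entropy` is defined earlier in inference.py (outside this excerpt).
    # A minimal sketch of the standard quantity the name suggests (assumed, not
    # copied from the source):
    #
    #   def _binary_entropy(p, eps=1e-12):
    #       p = min(max(float(p), eps), 1.0 - eps)
    #       return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))   # 0 = confident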

    def predict_property(self, prop_key: str, col: str, input_str: str,
                         uncertainty: bool = False) -> Dict[str, Any]:
        if (prop_key, col) not in self.models:
            raise KeyError(f"No model loaded for ({prop_key}, {col}).")

        meta = self.meta[(prop_key, col)]
        model = self.models[(prop_key, col)]
        task_type = meta["task_type"].lower()
        thr = meta.get("threshold")
        kind = meta["kind"]
        model_name = meta["model_name"]

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # DNN
        if kind == "torch_ckpt":
            X, M = self._get_features(prop_key, col, input_str)
            with torch.no_grad():
                raw = model(X, M).squeeze().float().cpu().item()

            if prop_key == "halflife" and col == "wt" and "log" in model_name:
                raw = float(np.expm1(raw))

            if task_type == "classifier":
                score = float(1.0 / (1.0 + np.exp(-raw)))
                out = {"property": prop_key, "col": col, "score": score,
                       "emb_tag": meta["emb_tag"]}
                if thr is not None:
                    out["label"] = int(score >= float(thr)); out["threshold"] = float(thr)
            else:
                out = {"property": prop_key, "col": col, "score": float(raw),
                       "emb_tag": meta["emb_tag"]}

        # XGBoost
        elif kind == "xgb":
            feats = self._get_features(prop_key, col, input_str)
            pred = float(model.predict(xgb.DMatrix(feats))[0])
            if prop_key == "halflife" and col == "wt" and "log" in model_name:
                pred = float(np.expm1(pred))
            out = {"property": prop_key, "col": col, "score": pred,
                   "emb_tag": meta["emb_tag"]}
            if task_type == "classifier" and thr is not None:
                out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)

        # joblib (SVM / ElasticNet / SVR)
        elif kind == "joblib":
            feats = self._get_features(prop_key, col, input_str)
            if task_type == "classifier":
                if hasattr(model, "predict_proba"):
                    pred = float(model.predict_proba(feats)[:, 1][0])
                elif hasattr(model, "decision_function"):
                    pred = float(1.0 / (1.0 + np.exp(-model.decision_function(feats)[0])))
                else:
                    pred = float(model.predict(feats)[0])
                out = {"property": prop_key, "col": col, "score": pred,
                       "emb_tag": meta["emb_tag"]}
                if thr is not None:
                    out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)
            else:
                pred = float(model.predict(feats)[0])
                out = {"property": prop_key, "col": col, "score": pred,
                       "emb_tag": meta["emb_tag"]}
        else:
            raise RuntimeError(f"Unknown kind={kind}")

        if uncertainty:
            u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"])
            out["uncertainty"] = u_val
            out["uncertainty_type"] = u_type

        return out

    def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str,
                                 uncertainty: bool = False) -> Dict[str, Any]:
        prop_key = "binding_affinity"
        if (prop_key, col) not in self.models:
            raise KeyError(f"No binding model loaded for ({prop_key}, {col}).")

        model = self.models[(prop_key, col)]
        meta = self.meta[(prop_key, col)]
        arch = meta["model_name"]
        emb_tag = meta.get("emb_tag")

        if arch == "pooled":
            t_vec = self.wt_embedder.pooled(target_seq)
            b_vec = self._get_embedder(emb_tag or col).pooled(binder_str) if emb_tag else \
                    (self.wt_embedder.pooled(binder_str) if col == "wt" else self.smiles_embedder.pooled(binder_str))
            with torch.no_grad():
                reg, logits = model(t_vec, b_vec)
        else:
            T, Mt = self.wt_embedder.unpooled(target_seq)
            binder_emb = self._get_embedder(emb_tag or col) if emb_tag else \
                         (self.wt_embedder if col == "wt" else self.smiles_embedder)
            B, Mb = binder_emb.unpooled(binder_str)
            with torch.no_grad():
                reg, logits = model(T, Mt, B, Mb)

        affinity = float(reg.squeeze().cpu().item())
        cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
        cls_thr = affinity_to_class(affinity)
        names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}

        out = {
            "property": "binding_affinity",
            "col": col,
            "affinity": affinity,
            "class_by_threshold": names[cls_thr],
            "class_by_logits": names[cls_logit],
            "binding_model": arch,
        }

        if uncertainty:
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
                    # Concatenate target + binder pooled embeddings for sigma model
                    binder_emb_tag = mapie_bundle.get("emb_tag") or col
                    target_emb_tag = mapie_bundle.get("target_emb_tag", "wt")
                    t_vec = self.wt_embedder.pooled(target_seq).cpu().float().numpy()
                    b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy()
                    emb = np.concatenate([t_vec, b_vec], axis=1)
                else:
                    emb = None
                lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb)
                out["uncertainty"] = (lo, hi)
                out["uncertainty_type"] = "conformal_prediction_interval"
            else:
                out["uncertainty"] = None
                out["uncertainty_type"] = "unavailable (no MAPIE bundle found)"

        return out
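
    # `affinity_to_class` also lives earlier in the file. Consistent with the class
    # names used above, a plausible sketch is (assumed, not copied from the source):
    #
    #   def affinity_to_class(a: float) -> int:
    #       return 0 if a >= 9.0 else (1 if a >= 7.0 else 2)   # High / Moderate / Low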

if __name__ == "__main__":
    root = Path(__file__).resolve().parent  # current script folder

    predictor = PeptiVersePredictor(
        manifest_path=root / "best_models.txt",
        classifier_weight_root=root
    )
    print(predictor.training_root)
    print("MAPIE keys:", list(predictor.mapie.keys()))
    print("Ensemble keys:", list(predictor.ensembles.keys()))

    seq = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"

    print(predictor.predict_property("hemolysis", "wt", seq))
    print(predictor.predict_property("hemolysis", "smiles", smiles, uncertainty=True))
    print(predictor.predict_property("nf", "wt", seq, uncertainty=True))
    print(predictor.predict_property("nf", "smiles", smiles, uncertainty=True))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT"))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT", uncertainty=True))
    seq1 = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    seq2 = "ACDEFGHIKLMNPQRSTVWY"

    r1 = predictor.predict_binding_affinity("wt", target_seq=seq2, binder_str="GIGAVLKVLT", uncertainty=True)
    r2 = predictor.predict_property("nf", "wt", seq1, uncertainty=True)
    r3 = predictor.predict_property("nf", "wt", seq2, uncertainty=True)
    print(r1)
    print(r2)
    print(r3)
training_classifiers/binding_training.py
CHANGED

@@ -51,8 +51,9 @@ def load_split_paired(path: str):
 # Collate: pooled paired
 # -----------------------------
 def collate_pair_pooled(batch):
-
-
+    binder_key = "binder_embedding" if "binder_embedding" in batch[0] else "embedding"
+    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
+    Pb = torch.tensor([x[binder_key] for x in batch], dtype=torch.float32)
     y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
     return Pt, Pb, y

@@ -147,7 +148,7 @@ class CrossAttnUnpooled(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(n_layers):
             self.layers.append(nn.ModuleDict({
-                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
+                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),  # (B, L, H) for embeddings now
                 "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                 "n1t": nn.LayerNorm(hidden),
                 "n2t": nn.LayerNorm(hidden),

@@ -272,7 +273,8 @@ def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
     # infer dims from first row
     if mode == "pooled":
         Ht = len(train_ds[0]["target_embedding"])
-
+        binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
+        Hb = len(train_ds[0][binder_key])
         collate = collate_pair_pooled
         model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
         train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)

@@ -349,7 +351,8 @@ def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):

     if mode == "pooled":
         Ht = len(train_ds[0]["target_embedding"])
-
+        binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
+        Hb = len(train_ds[0][binder_key])
         model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
         collate = collate_pair_pooled
         train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
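The "(B, L, H)" comment added in the second hunk refers to PyTorch's nn.MultiheadAttention layout under batch_first=True. A quick, self-contained illustration (not repo code):

    import torch
    import torch.nn as nn

    attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
    x = torch.randn(2, 10, 64)   # (B, L, H): batch of 2 sequences, 10 tokens, hidden 64
    out, _ = attn(x, x, x)
    print(out.shape)             # torch.Size([2, 10, 64]), same (B, L, H) layout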
training_classifiers/long_aggregated.csv
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:513cd88f97ef4b04ef92baaec85f2a5fe255a7dd50664025b2628a4ab6d94a99
+size 45539
training_classifiers/ml_uncertainty.py
ADDED

@@ -0,0 +1,135 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
import xgboost as xgb
from scipy import stats
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
from datasets import load_from_disk, DatasetDict

def best_f1_threshold(y_true, y_prob):
    p, r, thr = precision_recall_curve(y_true, y_prob)
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])


def bootstrap_ci(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Non-parametric bootstrap CI for F1 (at the val-optimal threshold) and AUC.
    Resamples (y_true, y_prob) pairs with replacement.
    """
    rng = np.random.default_rng(seed=seed)
    n = len(y_true)

    # Threshold picked on the full val set
    thr, _ = best_f1_threshold(y_true, y_prob)

    f1_scores, auc_scores = [], []

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_prob[idx]

        # Skip degenerate bootstraps (only one class)
        if len(np.unique(yt)) < 2:
            continue

        f1_scores.append(f1_score(yt, (yp >= thr).astype(int), zero_division=0))
        auc_scores.append(roc_auc_score(yt, yp))

    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    results = {}
    for name, arr in [("f1", f1_scores), ("auc", auc_scores)]:
        arr = np.array(arr)
        results[name] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "ci_low": float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report": f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            "n_bootstrap": len(arr),
        }

    results["threshold_used"] = float(thr)
    results["n_samples"] = int(n)
    return results

def prob_margin_uncertainty(val_preds_df: pd.DataFrame) -> pd.DataFrame:
    """
    Uncertainty = distance from the decision boundary in probability space.

    |prob - 0.5| = 0.0 means maximally uncertain; 0.5 means maximally confident.
    Normalized to [0, 1]: confidence = 2 * |prob - 0.5|.
    This reflects how far the model is from a coin flip on a given sequence.
    """
    df = val_preds_df.copy()
    df["uncertainty"] = 1 - 2 * (df["y_prob"] - 0.5).abs()  # 0=confident, 1=uncertain
    df["confidence"] = 1 - df["uncertainty"]                # 0=uncertain, 1=confident
    return df

def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci.json")
    with open(path, "w") as f:
        json.dump(ci_results, f, indent=2)

    print(f"\n=== Bootstrap 95% CI ({model_name}) ===")
    print(f" F1 : {ci_results['f1']['report']}")
    print(f" AUC : {ci_results['auc']['report']}")
    print(f" (threshold={ci_results['threshold_used']:.4f}, "
          f"n_bootstrap={ci_results['f1']['n_bootstrap']}, "
          f"n_val={ci_results['n_samples']})")
    print(f"Saved to {path}")


def save_uncertainty_csv(df: pd.DataFrame, out_dir: str, fname: str = "val_uncertainty.csv"):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, fname)
    df.to_csv(path, index=False)
    print(f"\n=== Per-molecule uncertainty ===")
    print(f" Mean uncertainty : {df['uncertainty'].mean():.4f}")
    print(f" Mean confidence : {df['confidence'].mean():.4f}")
    print(f" Saved to {path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["ci", "uncertainty_xgb", "uncertainty_prob"],
                        required=True,
                        help=(
                            "ci : bootstrap CI from val_predictions.csv (all models)\n"
                            "uncertainty_prob : margin uncertainty for SVM/ElasticNet/XGB"
                        ))
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--model_path", type=str, help="Path to best_model.json (XGB only)")
    parser.add_argument("--dataset_path", type=str, help="HuggingFace dataset path (XGB uncertainty only)")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="", help="Label for report (e.g. xgb_smiles)")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args()

    if args.mode == "ci":
        assert args.val_preds, "--val_preds required for ci mode"
        df = pd.read_csv(args.val_preds)
        ci = bootstrap_ci(df["y_true"].values, df["y_prob"].values,
                          n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
    elif args.mode == "uncertainty_prob":
        assert args.val_preds, "--val_preds required for uncertainty_prob"
        df_preds = pd.read_csv(args.val_preds)
        # CI
        ci = bootstrap_ci(df_preds["y_true"].values, df_preds["y_prob"].values,
                          n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
        # Uncertainty from margin
        df_unc = prob_margin_uncertainty(df_preds)
        save_uncertainty_csv(df_unc, args.out_dir, "val_uncertainty_prob.csv")
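
A quick synthetic sanity check for bootstrap_ci (assumed usage; not part of the commit):

    import numpy as np
    from ml_uncertainty import bootstrap_ci

    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 2, size=500)                         # fake binary labels
    y_prob = np.clip(0.6 * y_true + 0.4 * rng.random(500), 0, 1)  # informative, noisy scores
    ci = bootstrap_ci(y_true, y_prob, n_bootstrap=200)
    print(ci["f1"]["report"], ci["auc"]["report"])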
training_classifiers/ml_uncertainty_reg.py
ADDED

@@ -0,0 +1,182 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
import xgboost as xgb
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from datasets import load_from_disk, DatasetDict

def safe_spearmanr(y_true, y_pred):
    rho = spearmanr(y_true, y_pred).correlation
    return 0.0 if (rho is None or np.isnan(rho)) else float(rho)

def eval_regression(y_true, y_pred):
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }

# ======================== Bootstrap CI =========================================

def bootstrap_ci_reg(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Percentile bootstrap CI for regression metrics.
    Uses the percentile method (not a t-CI) because:
    - Spearman rho is bounded [-1, 1]; a t-CI can produce impossible values near the extremes
    - RMSE is strictly positive; a symmetric t-CI is inappropriate near 0
    - Percentile bootstrap makes no distributional assumptions

    A Fisher z-transform CI for rho is also computed as a cross-check.
    """
    rng = np.random.default_rng(seed=seed)
    n = len(y_true)
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    boot_metrics = {k: [] for k in ["spearman_rho", "rmse", "mae", "r2"]}

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_pred[idx]
        if len(np.unique(yt)) < 2:
            continue
        m = eval_regression(yt, yp)
        for k in boot_metrics:
            boot_metrics[k].append(m[k])

    results = {}
    for name, arr in boot_metrics.items():
        arr = np.array(arr)
        results[name] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "ci_low": float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report": f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            "n_bootstrap": len(arr),
        }

    # Fisher z-transform CI for Spearman rho (cross-check, more accurate near ±1)
    rho_vals = np.array(boot_metrics["spearman_rho"])
    rho_obs = safe_spearmanr(y_true, y_pred)
    # z-transform: arctanh(rho), SE = 1/sqrt(n-3)
    z = np.arctanh(np.clip(rho_obs, -0.9999, 0.9999))
    se_z = 1.0 / np.sqrt(max(n - 3, 1))
    z_lo = z - 1.96 * se_z
    z_hi = z + 1.96 * se_z
    results["spearman_rho"]["fisher_z_ci"] = {
        "ci_low": float(np.tanh(z_lo)),
        "ci_high": float(np.tanh(z_hi)),
        "report": f"[{np.tanh(z_lo):.4f}, {np.tanh(z_hi):.4f}]",
        "note": "Fisher z-transform CI - more accurate when rho > 0.9",
    }

    results["n_samples"] = int(n)
    return results


def residual_uncertainty(val_preds_df: pd.DataFrame, coverage: float = 0.95) -> tuple[pd.DataFrame, dict]:
    """
    - Assume residuals ~ N(0, sigma) where sigma = std(residuals)
    - 95% prediction interval for molecule i: y_pred_i ± z * sigma
    - Uncertainty score = sigma (constant across all molecules for linear models)
    - Uncertainty is dataset-level (a single sigma), not per-molecule
    """
    df = val_preds_df.copy()

    residuals = df["y_true"] - df["y_pred"]
    sigma = float(residuals.std(ddof=1))
    z = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(coverage, 1.960)
    half_width = z * sigma

    df["pred_interval_low"] = df["y_pred"] - half_width
    df["pred_interval_high"] = df["y_pred"] + half_width
    df["pred_interval_width"] = 2 * half_width  # constant for linear models
    df["abs_error"] = residuals.abs()

    # what fraction of y_true actually falls inside the interval
    empirical_coverage = float(
        ((df["y_true"] >= df["pred_interval_low"]) &
         (df["y_true"] <= df["pred_interval_high"])).mean()
    )

    meta = {
        "residual_std": round(sigma, 6),
        "interval_halfwidth": round(half_width, 6),
        "nominal_coverage": coverage,
        "empirical_coverage": round(empirical_coverage, 4),
        "note": (
            "Prediction interval assumes N(0, sigma) residuals. "
            "Interval width is constant across molecules for linear models."
        ),
    }
    return df, meta

def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci_reg.json")
    with open(path, "w") as f:
        json.dump(ci_results, f, indent=2)

    print(f"\n=== Bootstrap 95% CI - Regression ({model_name}) ===")
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        r = ci_results[metric]
        print(f" {metric:15s}: {r['report']}")
        if metric == "spearman_rho" and "fisher_z_ci" in r:
            fz = r["fisher_z_ci"]
            print(f" Fisher z CI : {fz['report']} ← use this if rho > 0.9")
    print(f" n_val={ci_results['n_samples']}, n_bootstrap={ci_results['spearman_rho']['n_bootstrap']}")
    print(f"Saved to {path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True,
                        choices=["ci", "uncertainty_residual"],
                        help=(
                            "ci : bootstrap CI from val_predictions.csv\n"
                            "uncertainty_residual: residual interval for ElasticNet/SVR"
                        ))
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args()

    if args.mode == "ci":
        assert args.val_preds, "--val_preds required"
        df = pd.read_csv(args.val_preds)
        ci = bootstrap_ci_reg(df["y_true"].values, df["y_pred"].values,
                              n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
    elif args.mode == "uncertainty_residual":
        assert args.val_preds
        df_preds = pd.read_csv(args.val_preds)
        ci = bootstrap_ci_reg(df_preds["y_true"].values, df_preds["y_pred"].values,
                              n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
        df_unc, meta = residual_uncertainty(df_preds)
        path = os.path.join(args.out_dir, "val_uncertainty_residual.csv")
        df_unc.to_csv(path, index=False)
        meta_path = os.path.join(args.out_dir, "residual_interval_meta.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        print(f"\nResidual interval summary:")
        print(f" Residual std : {meta['residual_std']:.4f}")
        print(f" 95% interval ± {meta['interval_halfwidth']:.4f}")
        print(f" Empirical coverage : {meta['empirical_coverage']:.4f} (nominal={meta['nominal_coverage']})")
        print(f" Saved to {path}")
training_classifiers/refit_binding_affinity_seed.py
ADDED

@@ -0,0 +1,270 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from scipy.stats import spearmanr
from scipy import stats as scipy_stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from lightning.pytorch import seed_everything
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from binding_training import (
    CrossAttnPooled,
    CrossAttnUnpooled,
    collate_pair_pooled,
    collate_pair_unpooled,
    eval_spearman_pooled,
    eval_spearman_unpooled,
    train_one_epoch_pooled,
    train_one_epoch_unpooled,
    affinity_to_class_tensor,
    safe_spearmanr,
)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_split_paired(path: str):
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    return dd["train"], dd["val"]


def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }

@torch.no_grad()
def predict_all_pooled(model, loader):
    model.eval()
    ys, ps = [], []
    for t, b, y in loader:
        t = t.to(DEVICE, non_blocking=True)
        b = b.to(DEVICE, non_blocking=True)
        pred, _ = model(t, b)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return np.concatenate(ys), np.concatenate(ps)


@torch.no_grad()
def predict_all_unpooled(model, loader):
    model.eval()
    ys, ps = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        pred, _ = model(T, Mt, B, Mb)
        ys.append(y.numpy())
        ps.append(pred.detach().cpu().numpy())
    return np.concatenate(ys), np.concatenate(ps)


def build_model(mode: str, params: dict, train_ds) -> nn.Module:
    hidden = int(params["hidden_dim"])
    n_heads = int(params["n_heads"])
    n_layers = int(params["n_layers"])
    dropout = float(params["dropout"])

    binder_key = "embedding" if "binder_embedding" not in train_ds.column_names else "binder_embedding"

    if mode == "pooled":
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0][binder_key])
        return CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads,
                               n_layers=n_layers, dropout=dropout).to(DEVICE)
    else:
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        return CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads,
                                 n_layers=n_layers, dropout=dropout).to(DEVICE)


# Refit
def refit_with_seed(dataset_path: str, base_out_dir: str, mode: str,
                    seed: int, patience: int = 20) -> dict:
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"No best_model.pt found at {model_path}. Run Optuna (binding_training.py) first."
        )

    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} mode={mode}")

    batch = int(best_params["batch_size"])
    cls_w = float(best_params["cls_weight"])

    if mode == "pooled":
        collate = collate_pair_pooled
        eval_fn = eval_spearman_pooled
        train_fn = train_one_epoch_pooled
        predict = predict_all_pooled
    else:
        collate = collate_pair_unpooled
        eval_fn = eval_spearman_unpooled
        train_fn = train_one_epoch_unpooled
        predict = predict_all_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False,
                            num_workers=4, pin_memory=True, collate_fn=collate)

    model = build_model(mode, best_params, train_ds)
    opt = torch.optim.AdamW(model.parameters(),
                            lr=float(best_params["lr"]),
                            weight_decay=float(best_params["weight_decay"]))
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    best_rho, bad, best_state = -1e9, 0, None

    for epoch in range(1, 201):
        train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho = rho
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f" Early stopping at epoch {epoch} (best rho={best_rho:.4f})")
                break

    if best_state:
        model.load_state_dict(best_state)

    y_true, y_pred = predict(model, val_loader)
    metrics = eval_regression(y_true, y_pred)

    # Save predictions
    df_val = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": (y_true - y_pred).astype(float),
        "abs_error": np.abs(y_true - y_pred).astype(float),
    })
    for col in ("target_sequence", "sequence", "affinity_class"):
        if col in val_ds.column_names:
            df_val.insert(0, col, np.asarray(val_ds[col]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(),
                "best_params": best_params,
                "mode": mode,
                "seed": seed},
               os.path.join(out_dir, "model.pt"))
| 185 |
+
|
| 186 |
+
summary = {"mode": mode, "seed": seed,
|
| 187 |
+
**{k: round(v, 6) for k, v in metrics.items()}}
|
| 188 |
+
with open(os.path.join(out_dir, "metrics.json"), "w") as f:
|
| 189 |
+
json.dump(summary, f, indent=2)
|
| 190 |
+
|
| 191 |
+
print(f"\n[Seed {seed}] rho={metrics['spearman_rho']:.4f} "
|
| 192 |
+
f"RMSE={metrics['rmse']:.4f} R2={metrics['r2']:.4f}")
|
| 193 |
+
return summary
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# CI aggregation
|
| 197 |
+
|
| 198 |
+
def aggregate_seed_results(base_out_dir: str, seeds: list) -> pd.DataFrame:
|
| 199 |
+
records = []
|
| 200 |
+
for seed in seeds:
|
| 201 |
+
p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
|
| 202 |
+
if os.path.exists(p):
|
| 203 |
+
records.append(json.load(open(p)))
|
| 204 |
+
else:
|
| 205 |
+
print(f"[WARN] Missing seed {seed} at {p}")
|
| 206 |
+
|
| 207 |
+
if not records:
|
| 208 |
+
raise ValueError("No seed results found — did the refit jobs complete?")
|
| 209 |
+
|
| 210 |
+
df = pd.DataFrame(records)
|
| 211 |
+
print("\nPer-seed results:")
|
| 212 |
+
print(df.to_string(index=False))
|
| 213 |
+
|
| 214 |
+
summary_rows = []
|
| 215 |
+
for metric in ["spearman_rho", "rmse", "mae", "r2"]:
|
| 216 |
+
vals = df[metric].values
|
| 217 |
+
n = len(vals)
|
| 218 |
+
mean = vals.mean()
|
| 219 |
+
std = vals.std(ddof=1)
|
| 220 |
+
se = std / np.sqrt(n)
|
| 221 |
+
t_crit = scipy_stats.t.ppf(0.975, df=n - 1)
|
| 222 |
+
ci = t_crit * se
|
| 223 |
+
row = {
|
| 224 |
+
"metric": metric,
|
| 225 |
+
"mean": round(mean, 4),
|
| 226 |
+
"std": round(std, 4),
|
| 227 |
+
"ci_95": round(ci, 4),
|
| 228 |
+
"report": f"{mean:.4f} ± {ci:.4f}",
|
| 229 |
+
"n_seeds": n,
|
| 230 |
+
}
|
| 231 |
+
if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
|
| 232 |
+
row["note"] = "rho near boundary — consider Fisher z-transform CI"
|
| 233 |
+
summary_rows.append(row)
|
| 234 |
+
|
| 235 |
+
summary_df = pd.DataFrame(summary_rows)
|
| 236 |
+
out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
|
| 237 |
+
summary_df.to_csv(out_path, index=False)
|
| 238 |
+
|
| 239 |
+
print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
|
| 240 |
+
for _, row in summary_df.iterrows():
|
| 241 |
+
note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
|
| 242 |
+
print(f" {row['metric']:15s}: {row['report']}{note}")
|
| 243 |
+
print(f"\nSaved → {out_path}")
|
| 244 |
+
return summary_df
|
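
The "Fisher z-transform" note printed above can be made concrete: averaging correlations in arctanh space keeps the back-transformed interval inside (-1, 1), unlike a t-interval computed directly on the rho values. A minimal sketch under that idea (the fisher_z_ci helper is illustrative, not part of this file):

import numpy as np
from scipy import stats

def fisher_z_ci(rhos, conf=0.95):
    # Transform rho -> z, build the t-interval in z-space, transform back.
    z = np.arctanh(np.asarray(rhos, dtype=float))
    n = len(z)
    se = z.std(ddof=1) / np.sqrt(n)
    half = stats.t.ppf(0.5 + conf / 2, df=n - 1) * se
    return np.tanh(z.mean()), (np.tanh(z.mean() - half), np.tanh(z.mean() + half))

# e.g. fisher_z_ci([0.96, 0.97, 0.95, 0.98, 0.96]) yields a CI bounded by (-1, 1)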

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Paired DatasetDict path")
    parser.add_argument("--base_out_dir", type=str, required=True,
                        help="Directory containing best_model.pt from the Optuna run")
    parser.add_argument("--mode", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--aggregate", action="store_true",
                        help="Aggregate across seed runs instead of training")
    parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_with_seed(
            dataset_path=args.dataset_path,
            base_out_dir=args.base_out_dir,
            mode=args.mode,
            seed=args.seed,
            patience=args.patience,
        )
training_classifiers/refit_ml_walltime.py
ADDED
@@ -0,0 +1,209 @@
"""
Loads best params from optimization_summary.txt, refits the model once on the
train split, and appends a wall-time record to a date-stamped wall_clock_ml.jsonl.
"""
import json
import time
import joblib
import argparse
import re
import numpy as np
from pathlib import Path
from datetime import datetime
# Classification trainers
from train_ml import (
    load_split_data as load_split_cls,
    train_cuml_svc,
    train_cuml_elastic_net,
    train_xgb,
    train_svm,
)
# Regression trainers
from train_ml_regression import (
    load_split_data as load_split_reg,
    train_cuml_elasticnet_reg,
    train_svr_reg,
    train_xgb_reg,
)

MODEL_FILE_MAP = [
    ("best_model_cuml_svc.joblib", "svm_gpu", "classification"),
    ("best_model_cuml_enet.joblib", "enet_gpu", "auto"),
    ("best_model_svr.joblib", "svr", "regression"),
    ("best_model.joblib", "svm", "classification"),
    ("best_model.json", "xgb", "auto"),
]

def detect_model_type(model_dir: Path) -> tuple:
    """Returns (model_type, task)."""
    for fname, model_type, task in MODEL_FILE_MAP:
        if (model_dir / fname).exists():
            if task == "auto":
                if (model_dir / "scaler.joblib").exists():
                    task = "regression"
                    if model_type == "xgb":
                        model_type = "xgb_reg"
                else:
                    task = "classification"
            return model_type, task
    raise FileNotFoundError(
        f"No recognised model file in {model_dir}. "
        f"Expected one of: {[f for f, _, _ in MODEL_FILE_MAP]}"
    )
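
One detail worth noting: detect_model_type() scans MODEL_FILE_MAP in order and returns on the first hit, so the generic best_model.joblib entry must stay listed after the more specific best_model_svr.joblib, should a directory ever hold both. A quick self-check sketch (illustrative, not part of the file):

# First-match-wins: the specific SVR filename must precede the generic one.
order = [fname for fname, _, _ in MODEL_FILE_MAP]
assert order.index("best_model_svr.joblib") < order.index("best_model.joblib")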


def parse_best_params(model_dir: Path) -> dict:
    """
    Extracts the JSON block after 'Best params:' in optimization_summary.txt.
    """
    summary_path = model_dir / "optimization_summary.txt"
    if not summary_path.exists():
        raise FileNotFoundError(f"optimization_summary.txt not found in {model_dir}")

    text = summary_path.read_text()
    match = re.search(r"Best params:\s*(\{.*?\})\s*={10,}", text, re.DOTALL)
    if not match:
        raise ValueError(
            f"Could not find 'Best params:' JSON block in {summary_path}.\n"
            f"File contents:\n{text}"
        )
    return json.loads(match.group(1))
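
The regex above assumes the summary file places a JSON dict directly after "Best params:" and closes the section with a rule of ten or more "=" characters; note the lazy \{.*?\} stops at the first "}", so it also assumes the params dict has no nested braces. A small round-trip check under that assumed layout (the file contents here are hypothetical):

import json, re

example = """Best value: 0.8123
Best params: {
  "C": 3.2,
  "kernel": "rbf"
}
====================
"""
m = re.search(r"Best params:\s*(\{.*?\})\s*={10,}", example, re.DOTALL)
assert json.loads(m.group(1))["kernel"] == "rbf"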

def parse_objective_and_wt(model_dir: Path) -> tuple:
    """
    Expects layout: .../training_classifiers/<objective>/<model>_<wt>/
    Example: hemolysis/svm_gpu_smiles -> objective=hemolysis, wt=smiles
    """
    parts = model_dir.parts
    model_folder = parts[-1].lower()
    objective = parts[-2]

    for suffix, wt in [("_chemberta", "chemberta"), ("_smiles", "smiles"), ("_wt", "wt")]:
        if model_folder.endswith(suffix):
            return objective, wt
    return objective, "wt"

def refit_and_time(model_dir: Path, dataset_path: str) -> tuple:
    model_type, task = detect_model_type(model_dir)
    best_params = parse_best_params(model_dir)

    print(f"  Model type : {model_type} ({task})")
    print(f"  Best params: {best_params}")

    # Load scaler if present (regression models)
    scaler_path = model_dir / "scaler.joblib"
    scaler = joblib.load(scaler_path) if scaler_path.exists() else None

    load_fn = load_split_reg if task == "regression" else load_split_cls
    data = load_fn(dataset_path)
    print(f"  Train: {data.X_train.shape}  Val: {data.X_val.shape}")

    # Build params
    if model_type == "xgb":
        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "colsample_bytree": best_params["colsample_bytree"],
            "subsample": best_params["subsample"],
            "learning_rate": best_params["learning_rate"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "gamma": best_params["gamma"],
            "tree_method": "hist",
            "device": "cuda",
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        train_fn = train_xgb

    elif model_type == "xgb_reg":
        params = {
            "objective": "reg:squarederror",
            "eval_metric": "rmse",
            "lambda": best_params["lambda"],
            "alpha": best_params["alpha"],
            "gamma": best_params["gamma"],
            "max_depth": best_params["max_depth"],
            "min_child_weight": best_params["min_child_weight"],
            "subsample": best_params["subsample"],
            "colsample_bytree": best_params["colsample_bytree"],
            "learning_rate": best_params["learning_rate"],
            "tree_method": "hist",
            "device": "cuda",
            "num_boost_round": best_params["num_boost_round"],
            "early_stopping_rounds": best_params["early_stopping_rounds"],
        }
        train_fn = train_xgb_reg

    elif model_type == "svm_gpu":
        params = best_params
        train_fn = train_cuml_svc

    elif model_type == "enet_gpu" and task == "classification":
        params = best_params
        train_fn = train_cuml_elastic_net

    elif model_type == "enet_gpu" and task == "regression":
        params = best_params
        train_fn = train_cuml_elasticnet_reg

    elif model_type == "svm":
        params = best_params
        train_fn = train_svm

    elif model_type == "svr":
        params = best_params
        train_fn = train_svr_reg

    else:
        raise ValueError(f"Unhandled model_type={model_type}, task={task}")

    # Timed block
    t0 = time.perf_counter()

    X_train = data.X_train
    X_val = data.X_val
    if scaler is not None:
        X_train = scaler.transform(X_train).astype(np.float32)
        X_val = scaler.transform(X_val).astype(np.float32)

    train_fn(X_train, data.y_train, X_val, data.y_val, params)

    wall_s = time.perf_counter() - t0
    print(f"  Wall time: {wall_s:.1f}s")
    return wall_s, model_type

def write_wall_time(logs_dir: Path, objective: str, wt: str,
                    model_type: str, wall_s: float):
    logs_dir.mkdir(parents=True, exist_ok=True)
    date_str = datetime.now().strftime("%m_%d")
    jsonl_path = logs_dir / f"{date_str}_wall_clock_ml.jsonl"

    record = {
        "model": model_type,
        "objective": objective,
        "wt": wt,
        "wall_s": round(wall_s),
    }
    with open(jsonl_path, "a") as f:
        f.write(json.dumps(record) + "\n")
    print(f"  Appended to {jsonl_path}: {record}")
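
Since each run appends one JSON object per line, the resulting log loads straight into pandas for cross-model comparison (an illustrative snippet; the file name below is hypothetical and depends on the run date):

import pandas as pd

df = pd.read_json("logs/10_01_wall_clock_ml.jsonl", lines=True)
print(df.groupby(["model", "wt"])["wall_s"].describe())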

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, required=True,
                        help="e.g. .../hemolysis/svm_gpu_smiles")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="HuggingFace dataset path for this objective/embedding")
    parser.add_argument("--logs_dir", type=str, required=True,
                        help="Directory to write *_wall_clock_ml.jsonl")
    args = parser.parse_args()

    model_dir = Path(args.model_dir)
    objective, wt = parse_objective_and_wt(model_dir)
    print(f"\nObjective: {objective}  Embedding: {wt}")

    wall_s, model_type = refit_and_time(model_dir, args.dataset_path)
    write_wall_time(Path(args.logs_dir), objective, wt, model_type, wall_s)
training_classifiers/refit_nn_seed.py
ADDED
@@ -0,0 +1,315 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
import torch.nn as nn
import os
import json
import pandas as pd
import argparse
from typing import Optional
from lightning.pytorch import seed_everything

def infer_in_dim_from_unpooled_ds(ds) -> int:
    ex = ds[0]
    return int(len(ex["embedding"][0]))

def load_split(dataset_path):
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")

def collate_unpooled(batch):
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True

    return X, M, y

# ======================== Models =========================================

class MaskedMeanPool(nn.Module):
    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom
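
As a sanity check of the masked pooling above, a toy batch shows padding positions contributing nothing to the mean (shapes follow collate_unpooled: a (B, L, H) embedding tensor and a (B, L) boolean mask; the values are made up for illustration):

import torch

pool = MaskedMeanPool()
X = torch.tensor([[[1.0], [3.0], [99.0]]])   # (B=1, L=3, H=1); 99.0 is padding
M = torch.tensor([[True, True, False]])      # mask marks the real tokens
print(pool(X, M))                            # tensor([[2.]]) -> padding excluded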

class MLPClassifier(nn.Module):
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )
    def forward(self, X, M):
        return self.net(self.pool(X, M)).squeeze(-1)

class CNNClassifier(nn.Module):
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        Mf = M.unsqueeze(-1).float()
        pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)

class TransformerClassifier(nn.Module):
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        Mf = M.unsqueeze(-1).float()
        pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
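
The ~M inversion in the transformer forward pass is deliberate: nn.TransformerEncoder expects True in src_key_padding_mask to mean "ignore this position", while M in this file uses True for real tokens. A one-line illustration:

import torch

M = torch.tensor([[True, True, False]])  # True = real token (this file's convention)
print(~M)                                # tensor([[False, False,  True]]) = positions the encoder ignores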

# ======================== Training utils =========================================

def best_f1_threshold(y_true, y_prob):
    p, r, thr = precision_recall_curve(y_true, y_prob)
    f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
    i = int(np.nanargmax(f1s))
    return float(thr[i]), float(f1s[i])
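
best_f1_threshold() drops the last precision/recall point because sklearn appends a (precision=1, recall=0) endpoint with no matching threshold. A worked example using the classic four-sample case from the sklearn docs:

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.8])
p, r, thr = precision_recall_curve(y_true, y_prob)
f1s = (2 * p[:-1] * r[:-1]) / (p[:-1] + r[:-1] + 1e-12)
i = int(np.nanargmax(f1s))
print(thr[i], f1s[i])   # thr 0.35 gives the best F1 (~0.8)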

@torch.no_grad()
def eval_probs(model, loader, device):
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        ps.append(torch.sigmoid(model(X, M)).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)

def train_one_epoch(model, loader, optim, criterion, device):
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        criterion(model(X, M), y).backward()
        optim.step()

def build_model(model_name, in_dim, params):
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPClassifier(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        return CNNClassifier(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                             layers=int(params["layers"]), dropout=dropout)
    elif model_name == "transformer":
        return TransformerClassifier(in_dim=in_dim, d_model=int(params["d_model"]),
                                     nhead=int(params["nhead"]), layers=int(params["layers"]),
                                     ff=int(params["ff"]), dropout=dropout)
    raise ValueError(model_name)

# ======================== Main refit =========================================

def refit_with_seed(dataset_path, base_out_dir, model_name, seed, device="cuda:0"):
    """
    Loads best_params from base_out_dir/best_model.pt (saved by original Optuna run),
    retrains with the given seed, saves results to base_out_dir/seed_{seed}/.
    """
    # Load best params from completed Optuna run
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt found at {model_path}. Run Optuna first.")

    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    # Seed
    seed_everything(seed)

    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    # Data import
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos, neg = ytr.sum(), len(ytr) - ytr.sum()
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    # Training loop with early stopping
    best_f1, best_thr, bad, patience = -1.0, 0.5, 0, 12
    best_state = None

    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            best_thr = thr
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final eval
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
    auc_final = roc_auc_score(y_true_val, y_prob_val)

    # Save
    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(int),
        "y_prob": y_prob_val.astype(float),
        "y_pred": (y_prob_val >= best_thr_final).astype(int),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {
        "model": model_name,
        "seed": seed,
        "val_f1": round(best_f1_final, 6),
        "val_auc": round(auc_final, 6),
        "val_thr": round(best_thr_final, 6),
    }
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] F1={best_f1_final:.4f} AUC={auc_final:.4f} thr={best_thr_final:.4f}")
    print(f"Saved to {out_dir}")
    return summary


# ======================== CI aggregation =========================================

def aggregate_seed_results(base_out_dir, seeds):
    """
    Call after all seed runs finish to compute mean ± 95% CI across seeds.
    Saves a summary CSV to base_out_dir/seed_aggregated_metrics.csv
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            records.append(json.load(open(p)))
        else:
            print(f"Warning: missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["val_f1", "val_auc"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        std = vals.std(ddof=1)
        se = std / np.sqrt(n)
        t_crit = stats.t.ppf(0.975, df=n - 1)
        ci = t_crit * se
        summary_rows.append({
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        })

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI) ===")
    for _, row in summary_df.iterrows():
        print(f"  {row['metric']:12s}: {row['report']} (n={row['n_seeds']})")
    print(f"\nSaved to {out_path}")
    return summary_df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--base_out_dir", type=str, required=True,
                        help="Directory containing best_model.pt from Optuna run")
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--seed", type=int, required=True,
                        help="Training seed for this run (1986, 42, 0, 123, 12345)")
    parser.add_argument("--aggregate", action="store_true",
                        help="After all seeds done: aggregate results into CI summary")
    parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345],
                        help="All seeds to aggregate (used with --aggregate)")
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_with_seed(
            dataset_path=args.dataset_path,
            base_out_dir=args.base_out_dir,
            model_name=args.model,
            seed=args.seed,
        )
training_classifiers/refit_regression_seed.py
ADDED
@@ -0,0 +1,318 @@
import os
import json
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from datasets import load_from_disk, DatasetDict
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.stats import spearmanr
from lightning.pytorch import seed_everything
from typing import Dict, Optional

scaler_amp = GradScaler(enabled=torch.cuda.is_available())
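
A compatibility aside: torch.cuda.amp.autocast and GradScaler still run but are deprecated in recent PyTorch (roughly 2.3 onward) in favour of the device-agnostic torch.amp equivalents. A drop-in sketch under that assumption:

import torch

# Equivalent setup with the newer torch.amp namespace (PyTorch >= 2.3):
scaler_amp = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
    pass  # forward pass + loss computation go here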

def load_split(dataset_path):
    ds = load_from_disk(dataset_path)
    if isinstance(ds, DatasetDict):
        return ds["train"], ds["val"]
    raise ValueError("Expected DatasetDict with 'train' and 'val' splits")

def infer_in_dim(ds) -> int:
    return int(len(ds[0]["embedding"][0]))

def collate_unpooled_reg(batch):
    lengths = [int(x["length"]) for x in batch]
    Lmax = max(lengths)
    H = len(batch[0]["embedding"][0])

    X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
    M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        emb = torch.tensor(x["embedding"], dtype=torch.float32)
        L = emb.shape[0]
        X[i, :L] = emb
        if "attention_mask" in x:
            m = torch.tensor(x["attention_mask"], dtype=torch.bool)
            M[i, :L] = m[:L]
        else:
            M[i, :L] = True
    return X, M, y

# ======================== Models =========================================

class MaskedMeanPool(nn.Module):
    def forward(self, X, M):
        Mf = M.unsqueeze(-1).float()
        return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)

class MLPRegressor(nn.Module):
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1)
        )
    def forward(self, X, M):
        return self.net(self.pool(X, M)).squeeze(-1)

class CNNRegressor(nn.Module):
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks, ch = [], in_ch
        for _ in range(layers):
            blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)
    def forward(self, X, M):
        Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
        Mf = M.unsqueeze(-1).float()
        return self.head((Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)).squeeze(-1)

class TransformerRegressor(nn.Module):
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff,
                                       dropout=dropout, batch_first=True, activation="gelu"),
            num_layers=layers
        )
        self.head = nn.Linear(d_model, 1)
    def forward(self, X, M):
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        Mf = M.unsqueeze(-1).float()
        return self.head((Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)).squeeze(-1)

# ======================== utils =========================================

def safe_spearmanr(y_true, y_pred):
    rho = spearmanr(y_true, y_pred).correlation
    return 0.0 if (rho is None or np.isnan(rho)) else float(rho)

def eval_regression(y_true, y_pred) -> Dict[str, float]:
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }

def score_from_metrics(metrics, objective):
    return {"spearman": metrics["spearman_rho"],
            "neg_rmse": -metrics["rmse"],
            "r2": metrics["r2"]}[objective]

@torch.no_grad()
def eval_preds(model, loader, device):
    model.eval()
    ys, ps = [], []
    for X, M, y in loader:
        X, M = X.to(device), M.to(device)
        ps.append(model(X, M).cpu().numpy())
        ys.append(y.numpy())
    return np.concatenate(ys), np.concatenate(ps)

def train_one_epoch(model, loader, optim, criterion, device):
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            loss = criterion(model(X, M), y)
        scaler_amp.scale(loss).backward()
        scaler_amp.step(optim)
        scaler_amp.update()

def build_model(model_name, in_dim, params):
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPRegressor(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        return CNNRegressor(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                            layers=int(params["layers"]), dropout=dropout)
    elif model_name == "transformer":
        return TransformerRegressor(in_dim=in_dim, d_model=int(params["d_model"]),
                                    nhead=int(params["nhead"]), layers=int(params["layers"]),
                                    ff=int(params["ff"]), dropout=dropout)
    raise ValueError(model_name)

# ======================== Refit Loop =========================================

def refit_with_seed(dataset_path, base_out_dir, model_name, seed,
                    objective="spearman", device="cuda:0"):
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt at {model_path}. Run Optuna first.")

    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss
    loss_name = best_params.get("loss", "mse")
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params.get("huber_delta", 1.0)))

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    best_score, bad, patience = -1e18, 0, 15
    best_state, best_metrics = None, {}

    for epoch in range(1, 201):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        if score > best_score + 1e-6:
            best_score = score
            best_metrics = metrics
            bad = 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state:
        model.load_state_dict(best_state)

    y_true_val, y_pred_val = eval_preds(model, val_loader, device)
    final_metrics = eval_regression(y_true_val, y_pred_val)

    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(float),
        "y_pred": y_pred_val.astype(float),
        "residual": (y_true_val - y_pred_val).astype(float),
        "abs_error": np.abs(y_true_val - y_pred_val).astype(float),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"model": model_name, "seed": seed, **{k: round(v, 6) for k, v in final_metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={final_metrics['spearman_rho']:.4f} "
          f"RMSE={final_metrics['rmse']:.4f} R2={final_metrics['r2']:.4f}")
    return summary

# ======================== CI aggregation =========================================

def aggregate_seed_results(base_out_dir, seeds):
    """
    Aggregates across seed runs using:
      - t-distribution 95% CI for Spearman rho, RMSE, R2, MAE
    For rho > 0.9, use Fisher z-transform CI instead.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            records.append(json.load(open(p)))
        else:
            print(f"Warning: missing seed {seed}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        std = vals.std(ddof=1)
        se = std / np.sqrt(n)
        t_crit = stats.t.ppf(0.975, df=n - 1)
        ci = t_crit * se
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        # Flag if rho is high enough that the t-CI boundary might exceed 1.0
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f"  {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved to {out_path}")
    return summary_df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--base_out_dir", type=str, required=True)
    parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--objective", type=str, default="spearman",
                        choices=["spearman", "neg_rmse", "r2"])
    parser.add_argument("--aggregate", action="store_true")
    parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_with_seed(
            dataset_path=args.dataset_path,
            base_out_dir=args.base_out_dir,
            model_name=args.model,
            seed=args.seed,
            objective=args.objective,
        )
training_classifiers/src_bash/binding_refit.bash
ADDED
@@ -0,0 +1,56 @@
#!/bin/bash
#SBATCH --job-name=ba-refit-seed
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=200G
#SBATCH --time=24:00:00
#SBATCH --output=%x_%A_%a.out
#SBATCH --array=0-4          # 5 seeds → indices 0..4

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
ALT_EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned

# ── Configure per submission ──────────────────────────────────────────
BINDER_MODEL='wt'    # chemberta / peptideclm / wt
MODE='pooled'        # pooled / unpooled

# wt-wt
DATA_PATH="${SCRIPT_LOC}/binding_affinity/pair_wt_wt_${MODE}"
BASE_OUT_DIR="${SCRIPT_LOC}/binding_affinity/wt_wt_${MODE}"

# wt-smiles (chemberta or peptideclm)
#DATA_PATH="${ALT_EMB_LOC}/binding_affinity/${BINDER_MODEL}/pair_wt_smiles_${MODE}"
#BASE_OUT_DIR="${SCRIPT_LOC}/binding_affinity/${BINDER_MODEL}_smiles_${MODE}"
# ────────────────────────────────────────────────────────────────────────────

SEEDS=(1986 42 0 123 12345)
SEED=${SEEDS[$SLURM_ARRAY_TASK_ID]}

LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC

echo "Running: binder=${BINDER_MODEL} mode=${MODE} seed=${SEED}"
echo "  data : ${DATA_PATH}"
echo "  out  : ${BASE_OUT_DIR}"

START_TIME=$(date +%s%N)

python -u refit_binding_affinity_seed.py \
    --dataset_path "${DATA_PATH}" \
    --base_out_dir "${BASE_OUT_DIR}" \
    --mode "${MODE}" \
    --seed "${SEED}" \
    > "${LOG_LOC}/${DATE}_ba_refit_${BINDER_MODEL}_${MODE}_seed${SEED}.log" 2>&1

END_TIME=$(date +%s%N)
ELAPSED_S=$(( (END_TIME - START_TIME) / 1000000000 ))
echo "Seed ${SEED} done at $(date) — wall clock: ${ELAPSED_S}s"
echo "{\"binder\": \"${BINDER_MODEL}\", \"mode\": \"${MODE}\", \"seed\": ${SEED}, \"wall_s\": ${ELAPSED_S}}" \
    >> "${LOG_LOC}/${DATE}_wall_clock_ba_refit.jsonl"

conda deactivate
training_classifiers/src_bash/ml_uncertainty.bash
ADDED
@@ -0,0 +1,192 @@
#!/bin/bash
#SBATCH --job-name=ml-walltime
#SBATCH --partition=b200-mig45
#SBATCH --gpus=1
#SBATCH --cpus-per-task=5
#SBATCH --mem=50G
#SBATCH --time=6:00:00
#SBATCH --output=%x_%j.out

# =============================================================================
# Unified Bootstrap CI + Uncertainty + Wall-time Refit
# wt, smiles, chemberta embeddings
# Runs sequentially: bootstrap/uncertainty first, then wall-time refit
# =============================================================================

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
ALT_EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC
# =============================================================================
# Helper functions
# =============================================================================

# Bootstrap CI + uncertainty
# $1=OBJECTIVE $2=WT $3=UNCERTAINTY_SCRIPT $4=MODEL_TYPE $5=UNC_MODE
run_bootstrap() {
    local OBJECTIVE=$1
    local WT=$2
    local SCRIPT=$3
    local MODEL_TYPE=$4
    local UNC_MODE=$5

    local VAL_PREDS="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}/val_predictions.csv"
    local OUT_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
    local LOG_FILE="${LOG_LOC}/${DATE}_ci_${MODEL_TYPE}_${OBJECTIVE}_${WT}.log"

    if [ ! -f "$VAL_PREDS" ]; then
        echo "  [SKIP bootstrap] val_predictions.csv not found: $VAL_PREDS"
        return
    fi

    echo "  [bootstrap ci] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT}"
    python -u "$SCRIPT" \
        --mode ci \
        --val_preds "$VAL_PREDS" \
        --out_dir "$OUT_DIR" \
        --model_name "${MODEL_TYPE}_${WT}" \
        >> "$LOG_FILE" 2>&1

    echo "  [bootstrap unc] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT} (${UNC_MODE})"
    python -u "$SCRIPT" \
        --mode "$UNC_MODE" \
        --val_preds "$VAL_PREDS" \
        --out_dir "$OUT_DIR" \
        --model_name "${MODEL_TYPE}_${WT}" \
        >> "$LOG_FILE" 2>&1

    echo "  ${OUT_DIR}/"
}

# Wall-time refit
# $1=OBJECTIVE $2=WT $3=MODEL_TYPE $4=DATASET_PATH
run_walltime() {
    local OBJECTIVE=$1
    local WT=$2
    local MODEL_TYPE=$3
    local DATASET_PATH=$4

    local MODEL_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
    local LOG_FILE="${LOG_LOC}/${DATE}_walltime_${MODEL_TYPE}_${OBJECTIVE}_${WT}.log"

    if [ ! -d "$MODEL_DIR" ]; then
        echo "  [SKIP walltime] model_dir not found: $MODEL_DIR"
        return
    fi
    if [ ! -d "$DATASET_PATH" ]; then
        echo "  [SKIP walltime] dataset not found: $DATASET_PATH"
        return
    fi

    echo "  [walltime] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT}"
    python -u refit_ml_walltime.py \
        --model_dir "$MODEL_DIR" \
        --dataset_path "$DATASET_PATH" \
        --logs_dir "$LOG_LOC" \
        >> "$LOG_FILE" 2>&1

    echo "  logged to ${LOG_LOC}/${DATE}_wall_clock_ml.jsonl"
}

# =============================================================================
# Dataset path lookup
# $1=OBJECTIVE $2=WT
# =============================================================================
get_dataset_path() {
    local OBJECTIVE=$1
    local WT=$2

    local DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned

    case "${OBJECTIVE}|${WT}" in
        # -- wt embeddings (ESM2 / original) ------------------------------
        "hemolysis|wt")               echo "${DATA_LOC}/hemolysis/hemo_wt_with_embeddings" ;;
        "nf|wt")                      echo "${DATA_LOC}/nf/nf_wt_with_embeddings" ;;
        "solubility|wt")              echo "${DATA_LOC}/solubility/sol_wt_with_embeddings" ;;
        "permeability_penetrance|wt") echo "${DATA_LOC}/permeability_penetrance/perm_wt_with_embeddings_pooled" ;;
        # -- smiles embeddings (PeptideCLM) -------------------------------
        "hemolysis|smiles")           echo "${ALT_EMB_LOC}/hemolysis_peptideclm/hemo_smiles_with_embeddings" ;;
        "nf|smiles")                  echo "${ALT_EMB_LOC}/nf_peptideclm/nf_smiles_with_embeddings" ;;
        "permeability_pampa|smiles")  echo "${ALT_EMB_LOC}/permeability_pampa_peptideclm/pampa_smiles_with_embeddings" ;;
        "permeability_caco2|smiles")  echo "${ALT_EMB_LOC}/permeability_caco2_peptideclm/caco2_smiles_with_embeddings" ;;
        # -- chemberta embeddings -----------------------------------------
        "hemolysis|chemberta")        echo "${ALT_EMB_LOC}/hemolysis_chemberta/hemo_smiles_with_embeddings" ;;
        "nf|chemberta")               echo "${ALT_EMB_LOC}/nf_chemberta/nf_smiles_with_embeddings" ;;
        "permeability_penetrance|chemberta")  echo "${ALT_EMB_LOC}/permeability_chemberta/perm_smiles_with_embeddings" ;;
        "permeability_penetrance|peptideclm") echo "${ALT_EMB_LOC}/permeability_peptideclm/perm_smiles_with_embeddings" ;;
        "permeability_pampa|chemberta")       echo "${ALT_EMB_LOC}/permeability_pampa_chemberta/pampa_smiles_with_embeddings" ;;
        "permeability_caco2|chemberta")       echo "${ALT_EMB_LOC}/permeability_caco2_chemberta/caco2_smiles_with_embeddings" ;;
        *)
            echo ""
            ;;
    esac
}

# =============================================================================
# SECTION 1 - Classification tasks
# =============================================================================
echo ""
echo "============================================================"
echo " SECTION 1: Classification bootstrap + walltime"
echo "============================================================"

CLS_MODEL_TYPES=("svm_gpu" "enet_gpu" "xgb")

# hemolysis, nf - wt + smiles + chemberta
for OBJECTIVE in "hemolysis" "nf"; do
    for WT in "wt" "smiles" "chemberta"; do
        for MODEL_TYPE in "${CLS_MODEL_TYPES[@]}"; do
            echo ""
            echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
            run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty.py" "$MODEL_TYPE" "uncertainty_prob"
            DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
            run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
        done
    done
done

# solubility, permeability_penetrance - wt + chemberta (no smiles embeddings)
for OBJECTIVE in "solubility" "permeability_penetrance"; do
    for WT in "wt" "chemberta"; do
        for MODEL_TYPE in "${CLS_MODEL_TYPES[@]}"; do
            echo ""
            echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
            run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty.py" "$MODEL_TYPE" "uncertainty_prob"
|
| 159 |
+
DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
|
| 160 |
+
run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
|
| 161 |
+
done
|
| 162 |
+
done
|
| 163 |
+
done
|
| 164 |
+
|
| 165 |
+
# =============================================================================
|
| 166 |
+
# SECTION 2 - Regression tasks (PAMPA, Caco-2)
|
| 167 |
+
# =============================================================================
|
| 168 |
+
echo ""
|
| 169 |
+
echo "============================================================"
|
| 170 |
+
echo " SECTION 2: Regression bootstrap + walltime"
|
| 171 |
+
echo "============================================================"
|
| 172 |
+
|
| 173 |
+
REG_MODEL_TYPES=("svr" "enet_gpu" "xgb")
|
| 174 |
+
|
| 175 |
+
for OBJECTIVE in "permeability_pampa" "permeability_caco2"; do
|
| 176 |
+
for WT in "smiles" "chemberta"; do
|
| 177 |
+
for MODEL_TYPE in "${REG_MODEL_TYPES[@]}"; do
|
| 178 |
+
echo ""
|
| 179 |
+
echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
|
| 180 |
+
run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty_reg.py" "$MODEL_TYPE" "uncertainty_residual"
|
| 181 |
+
DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
|
| 182 |
+
run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
|
| 183 |
+
done
|
| 184 |
+
done
|
| 185 |
+
done
|
| 186 |
+
|
| 187 |
+
echo ""
|
| 188 |
+
echo "============================================================"
|
| 189 |
+
echo "All runs completed at $(date)"
|
| 190 |
+
echo "============================================================"
|
| 191 |
+
|
| 192 |
+
conda deactivate
|
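The `--mode ci` pass above hands each model's val_predictions.csv to ml_uncertainty.py for a bootstrap confidence interval. The exact columns and metric used by ml_uncertainty.py are not shown in this commit, so the following is only a minimal sketch of a percentile-bootstrap CI over per-sample predictions, assuming hypothetical "label" and "prob" columns and AUROC as the metric:

```python
# Hypothetical sketch of a percentile-bootstrap CI over val_predictions.csv.
# Column names ("label", "prob") and the AUROC metric are assumptions; the
# actual ml_uncertainty.py may use different columns and metrics.
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score

def bootstrap_ci(val_preds_csv, n_boot=1000, alpha=0.05, seed=1986):
    df = pd.read_csv(val_preds_csv)
    y, p = df["label"].to_numpy(), df["prob"].to_numpy()
    rng = np.random.default_rng(seed)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y), len(y))   # resample rows with replacement
        if len(np.unique(y[idx])) < 2:          # AUROC needs both classes present
            continue
        stats.append(roc_auc_score(y[idx], p[idx]))
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(np.mean(stats)), float(lo), float(hi)
```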
training_classifiers/src_bash/nn_uncertainty.bash
ADDED
|
@@ -0,0 +1,51 @@
#!/bin/bash
#SBATCH --job-name=refit-seed-array
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=100G
#SBATCH --time=12:00:00
#SBATCH --output=%x_%A_%a.out
#SBATCH --array=0-4          # 5 seeds → indices 0..4

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
DATA_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
# ── Configure per submission ──────────────────────────────────────────
OBJECTIVE='permeability_pampa'   # nf / solubility / hemolysis / permeability_penetrance / permeability_pampa / permeability_caco2
WT='chemberta'                   # wt / smiles / chemberta / peptideclm
MODEL_TYPE='mlp'                 # mlp / cnn / transformer
DATA_FILE="hemo_${WT}_with_embeddings_unpooled"   # nf / sol / hemo / perm / pampa / caco2
# Points to the directory where Optuna already saved best_model.pt
BASE_OUT_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
DATASET_PATH="${DATA_LOC}/permeability_${WT}/${DATA_FILE}"
# ────────────────────────────────────────────────────────────────────────────

SEEDS=(1986 42 0 123 12345)
SEED=${SEEDS[$SLURM_ARRAY_TASK_ID]}

LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC

echo "Running seed=$SEED model=$MODEL_TYPE objective=$OBJECTIVE wt=$WT"

START_TIME=$(date +%s%N)

python -u refit_nn_seed.py \
    --dataset_path "${DATASET_PATH}" \
    --base_out_dir "${BASE_OUT_DIR}" \
    --model "${MODEL_TYPE}" \
    --seed "${SEED}" \
    > "${LOG_LOC}/${DATE}_refit_${MODEL_TYPE}_${OBJECTIVE}_${WT}_seed${SEED}.log" 2>&1

END_TIME=$(date +%s%N)
ELAPSED_S=$(( (END_TIME - START_TIME) / 1000000000 ))

echo "Seed $SEED done at $(date) — wall clock: ${ELAPSED_S}s"
echo "{\"model\": \"${MODEL_TYPE}\", \"objective\": \"${OBJECTIVE}\", \"wt\": \"${WT}\", \"seed\": ${SEED}, \"wall_s\": ${ELAPSED_S}}" \
    >> "${LOG_LOC}/${DATE}_wall_clock_refit.jsonl"
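Each array task appends one JSON object per seed to ${DATE}_wall_clock_refit.jsonl, with the schema shown in the final echo above. A small sketch for aggregating those lines into per-configuration wall-clock statistics might look like this (the file path is an example, not a fixed location):

```python
# Sketch: summarise the wall-clock JSONL written by the array job above.
# The keys ("model", "objective", "wt", "seed", "wall_s") mirror the echo line
# in the script; the example path below is hypothetical.
import json
from pathlib import Path

import pandas as pd

def summarise_wall_clock(jsonl_path):
    rows = [json.loads(line)
            for line in Path(jsonl_path).read_text().splitlines()
            if line.strip()]
    df = pd.DataFrame(rows)
    # mean / std of wall-clock seconds across the 5 seeds, per configuration
    return (df.groupby(["model", "objective", "wt"])["wall_s"]
              .agg(["mean", "std", "count"])
              .reset_index())

# Example:
# print(summarise_wall_clock("logs/07_15_wall_clock_refit.jsonl"))
```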
training_classifiers/train_ml.py
CHANGED
|
@@ -55,11 +55,9 @@ def _stack_embeddings(col) -> np.ndarray:
def load_split_data(dataset_path: str) -> SplitData:
    ds = load_from_disk(dataset_path)

-    # Case A: DatasetDict with train/val
    if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
        train_ds, val_ds = ds["train"], ds["val"]
    else:
-        # Case B: Single dataset with "split" column
        if "split" not in ds.column_names:
            raise ValueError(
                "Dataset must be a DatasetDict(train/val) or have a 'split' column."
@@ -201,7 +199,6 @@ def train_svm(X_train, y_train, X_val, y_val, params):
def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
    """
    Fast linear SVM (LinearSVC) + probability calibration.
-    Usually much faster than SVC on large datasets.
    """
    base = LinearSVC(
        C=float(params["C"]),
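As the hunk above shows, load_split_data() accepts either a DatasetDict with explicit train/val splits or a single Dataset carrying a "split" column. A toy sketch of the two on-disk layouts (illustrative only; the real datasets also carry embedding and label columns specific to each task):

```python
# Illustrative only - toy data; real PeptiVerse datasets have task-specific columns.
from datasets import Dataset, DatasetDict

# Case A: a DatasetDict with explicit "train" / "val" splits
dd = DatasetDict({
    "train": Dataset.from_dict({"label": [0, 1], "embedding": [[0.1, 0.2], [0.3, 0.4]]}),
    "val":   Dataset.from_dict({"label": [1],    "embedding": [[0.5, 0.6]]}),
})
dd.save_to_disk("toy_case_a")

# Case B: a single Dataset whose "split" column marks train vs val rows
ds = Dataset.from_dict({
    "label": [0, 1, 1],
    "embedding": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    "split": ["train", "train", "val"],
})
ds.save_to_disk("toy_case_b")
```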
training_data_cleaned/binding_affinity/binding_affinity_smiles_meta_with_split.csv
CHANGED
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3aee738ef2b17343ae69723a75473821b4188a196a55dacd0286ec47d065d531
+size 4436974
training_data_cleaned/binding_affinity/binding_affinity_wt_meta_with_split.csv
CHANGED
|
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b7abc47729fa52a9f0aa68bffc6dd8c6562d0e4621d437a3a939c4ab27f46d80
+size 3704486
training_data_cleaned/binding_affinity_split.py
CHANGED
|
@@ -1,62 +1,77 @@
-import os
import math
-from pathlib import Path
import sys
from contextlib import contextmanager
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
-from
-from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
-from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
-from lightning.pytorch import seed_everything
-seed_everything(1986)
-)
-WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
-WT_MAX_LEN = 1022
-WT_BATCH = 32
-SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
-TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
-TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
-SMI_MAX_LEN = 768
-SMI_BATCH = 128
-#
COL_SMI_IPTM = "smiles_iptm_score"
-#
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def log(msg: str):
-    if LOG_FILE is not None:
-        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
-        with open(LOG_FILE, "a") as f:
-            f.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)
@@ -70,14 +85,22 @@ def section(title: str):
    log(f"=== done: {title} ===")
-#
-#
-#
def has_uaa(seq: str) -> bool:
    return "X" in str(seq).upper()
def affinity_to_class(a: float) -> str:
-    # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
    if a >= 9.0:
        return "High"
    elif a >= 7.0:
@@ -87,10 +110,8 @@ def affinity_to_class(a: float) -> str:
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
    try:
@@ -101,717 +122,446 @@ def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
        strat_col = "aff_bin"
    rng = np.random.RandomState(RANDOM_SEED)
    df["split"] = None
    for _, g in df.groupby(strat_col, observed=True):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
        df.loc[idx[:n_train], "split"] = "train"
-        df.loc[idx[n_train:],
    df["split"] = df["split"].fillna("train")
    return df
-def
-    seq_col: str,
-    iptm_col: str,
-    aff_class_col: str = "affinity_class",
-    aff_bins: int = 30,
-    save_report_prefix: str | None = None,
-    verbose: bool = False,
-):
-    df2 = df2.copy()
-    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
-    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
-    assert split_col in df2.columns, f"Missing split col: {split_col}"
-    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
-    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
-    try:
-        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
-    except Exception:
-        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
-    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
-    va = df2[df2[split_col] == "val"].reset_index(drop=True)
-    tr_aff = _summ(tr[affinity_col].to_numpy())
-    va_aff = _summ(va[affinity_col].to_numpy())
-    tr_len = _len_stats(tr[seq_col].tolist())
-    va_len = _len_stats(va[seq_col].tolist())
-    # bin drift
-    bin_ct = (
-        df2.groupby([split_col, "_aff_bin_dbg"])
-        .size()
-        .groupby(level=0)
-        .apply(lambda s: s / s.sum())
-    )
-    tr_bins = bin_ct.loc["train"]
-    va_bins = bin_ct.loc["val"]
-    all_bins = tr_bins.index.union(va_bins.index)
-    tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
-    va_bins = va_bins.reindex(all_bins, fill_value=0.0)
-    max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
-    msg = (
-        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
-        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
-        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
-        f"max_bin_diff={max_bin_diff:.4f}"
-    )
-    log(msg)
-    if verbose and (not QUIET):
-        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
-        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
-        print("\n[verbose] affinity_class counts:\n", class_ct)
-        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
-    if save_report_prefix is not None:
-        out = Path(save_report_prefix)
-        out.parent.mkdir(parents=True, exist_ok=True)
-        stats_df = pd.DataFrame([
-            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
-            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
-        ])
-        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
-        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
-        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
-        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
-# -------------------------
-# WT pooled (ESM2)
-# -------------------------
-@torch.no_grad()
-def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
-    embs = []
-    for i in pbar(range(0, len(seqs), batch_size)):
-        batch = seqs[i:i + batch_size]
-        inputs = tokenizer(
-            batch,
-            padding=True,
-            truncation=True,
-            max_length=max_length,
-            return_tensors="pt",
-        )
-        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-        out = model(**inputs)
-        h = out.last_hidden_state  # (B, L, H)
-# -------------------------
-# WT unpooled (ESM2)
-# -------------------------
@torch.no_grad()
-def
-    return
-def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
-    """
-    Expects df_split to have:
-      - target_sequence (seq1)
-      - sequence (binder seq2; WT binder)
-      - label, affinity_class, COL_AFF, COL_WT_IPTM
-    Saves a dataset where each row contains BOTH:
-      - target_embedding (Lt,H), target_attention_mask, target_length
-      - binder_embedding (Lb,H), binder_attention_mask, binder_length
-    """
-    cls_id = tokenizer.cls_token_id
-    eos_id = tokenizer.eos_token_id
-    H = model.config.hidden_size
-    features = Features({
-        "target_sequence": Value("string"),
-        "sequence": Value("string"),
-        "label": Value("float32"),
-        "affinity": Value("float32"),
-        "affinity_class": Value("string"),
-        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
-        "target_attention_mask": HFSequence(Value("int8")),
-        "target_length": Value("int64"),
-        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
-        "binder_attention_mask": HFSequence(Value("int8")),
-        "binder_length": Value("int64"),
-        COL_WT_IPTM: Value("float32"),
-        COL_AFF: Value("float32"),
-    })
-        bnd = str(getattr(r, "sequence")).strip()
-        y = float(getattr(r, "label"))
-        aff = float(getattr(r, COL_AFF))
-        acls = str(getattr(r, "affinity_class"))
-        iptm = getattr(r, COL_WT_IPTM)
-        iptm = float(iptm) if pd.notna(iptm) else np.nan
-        # token embeddings for target + binder (both ESM)
-        t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lt,H)
-        b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lb,H)
-        t_list = t_emb.tolist()
-        b_list = b_emb.tolist()
-        Lt = len(t_list)
-        Lb = len(b_list)
-        yield {
-            "target_sequence": tgt,
-            "sequence": bnd,
-            "label": np.float32(y),
-            "affinity": np.float32(aff),
-            "affinity_class": acls,
-            "target_embedding": t_list,
-            "target_attention_mask": [1] * Lt,
-            "target_length": int(Lt),
-            "binder_embedding": b_list,
-            "binder_attention_mask": [1] * Lb,
-            "binder_length": int(Lb),
-            COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
-            COL_AFF: np.float32(aff),
-        }
-    out_dir.mkdir(parents=True, exist_ok=True)
-    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
-    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
-    return ds
-def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
-                                         smi_tok, smi_roformer):
    """
-    Saves rows with:
-      target_embedding (Lt,Ht) from ESM
-      binder_embedding (Lb,Hb) from PeptideCLM
    """
-    cls_id =
-    eos_id =
    })
-    def gen_rows(df: pd.DataFrame):
-        for r in pbar(df.itertuples(index=False), total=len(df)):
-            tgt = str(getattr(r, "target_sequence")).strip()
-            bnd = str(getattr(r, "sequence")).strip()
-            y = float(getattr(r, "label"))
-            aff = float(getattr(r, COL_AFF))
-            acls = str(getattr(r, "affinity_class"))
-            iptm = getattr(r, COL_SMI_IPTM)
-            iptm = float(iptm) if pd.notna(iptm) else np.nan
-            # target token embeddings (ESM)
-            t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
-            t_list = t_emb.tolist()
-            Lt = len(t_list)
-            # binder token embeddings (PeptideCLM)
-            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
-                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
-            )
-            b_emb = tok_list[0]
-            b_list = b_emb.tolist()
-            Lb = int(lengths[0])
-            b_mask = mask_list[0].astype(np.int8).tolist()
-            yield {
-                "target_sequence": tgt,
-                "sequence": bnd,
-                "label": np.float32(y),
-                "affinity": np.float32(aff),
-                "affinity_class": acls,
-                "target_embedding": t_list,
-                "target_attention_mask": [1] * Lt,
-                "target_length": int(Lt),
-                "binder_embedding": b_list,
-                "binder_attention_mask": [int(x) for x in b_mask],
-                "binder_length": int(Lb),
-                COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
-                COL_AFF: np.float32(aff),
-            }
-    out_dir.mkdir(parents=True, exist_ok=True)
-    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
-    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
-    return ds
-# -------------------------
-# SMILES pooled + unpooled (PeptideCLM)
-# -------------------------
-def get_special_ids(tokenizer_obj):
-    cand = [
-        getattr(tokenizer_obj, "pad_token_id", None),
-        getattr(tokenizer_obj, "cls_token_id", None),
-        getattr(tokenizer_obj, "sep_token_id", None),
-        getattr(tokenizer_obj, "bos_token_id", None),
-        getattr(tokenizer_obj, "eos_token_id", None),
-        getattr(tokenizer_obj, "mask_token_id", None),
-    ]
-    return sorted({x for x in cand if x is not None})
-    tok = tokenizer_obj(
-        batch_sequences,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=max_length,
-    )
-    input_ids = tok["input_ids"].to(DEVICE)
-    attention_mask = tok["attention_mask"].to(DEVICE)
-    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
-    last_hidden = outputs.last_hidden_state  # (B, L, H)
-    special_ids = get_special_ids(tokenizer_obj)
    valid = attention_mask.bool()
-    if
-    if hasattr(torch, "isin"):
-        valid = valid & (~torch.isin(input_ids, sid))
-    else:
-        m = torch.zeros_like(valid)
-        for s in special_ids:
-            m |= (input_ids == s)
-        valid = valid & (~m)
    valid_f = valid.unsqueeze(-1).float()
    for b in range(last_hidden.shape[0]):
-        emb = last_hidden[b, valid[b]]
-        lengths.append(
# Main
-#
def main():
    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
-        # numeric affinity for both branches
        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
-        #
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
-        df_wt = df_wt.dropna(subset=[COL_AFF])
-        df_wt = df_wt[df_wt[
-        df_wt = df_wt[
        df_smi = df.copy()
-        df_smi = df_smi.dropna(subset=[COL_AFF])
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
-        ]
-        df_smi["smiles_sequence"] = df_smi
-        df_smi = df_smi[df_smi["smiles_sequence"].notna()
        df_smi2 = make_distribution_matched_split(df_smi)
-        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
-        df_wt2.to_csv(wt_split_csv, index=False)
-        df_smi2.to_csv(smi_split_csv, index=False)
-        log(f"Saved WT split meta: {wt_split_csv}")
-        log(f"Saved SMILES split meta: {smi_split_csv}")
-        verify_split_before_embedding(
-            df2=df_wt2,
-            affinity_col=COL_AFF,
-            split_col="split",
-            seq_col="wt_sequence",
-            iptm_col=COL_WT_IPTM,
-            aff_class_col="affinity_class",
-            aff_bins=AFFINITY_Q_BINS,
-            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
-            verbose=False,
-        )
-        verify_split_before_embedding(
-            df2=df_smi2,
-            affinity_col=COL_AFF,
-            split_col="split",
-            seq_col="smiles_sequence",
-            iptm_col=COL_SMI_IPTM,
-            aff_class_col="affinity_class",
-            aff_bins=AFFINITY_Q_BINS,
-            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
-            verbose=False,
-        )
-        #
-        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
-        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
-    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
-    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
-    # -------------------------
-    # Split views
-    # -------------------------
-    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
-    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
-    # ---- SMILES targets ----
-    smi_train_tgt_emb = wt_pooled_embeddings(
-        smi_train["target_sequence"].astype(str).str.strip().tolist(),
-        wt_tok, wt_esm,
-        batch_size=WT_BATCH,
-        max_length=WT_MAX_LEN,
-    ).astype(np.float32)
-    smi_val_tgt_emb = wt_pooled_embeddings(
-        smi_val["target_sequence"].astype(str).str.strip().tolist(),
-        wt_tok, wt_esm,
-        batch_size=WT_BATCH,
-        max_length=WT_MAX_LEN,
-    ).astype(np.float32)
-    # =========================
-    # WT pooled binder embeddings (binder = WT peptide)
-    # =========================
-    with section("WT pooled binder embeddings + save"):
-        wt_train_emb = wt_pooled_embeddings(
-            wt_train["sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-        wt_val_emb = wt_pooled_embeddings(
-            wt_val["sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-        wt_train_ds = Dataset.from_dict({
-            "target_sequence": wt_train["target_sequence"].tolist(),
-            "sequence": wt_train["sequence"].tolist(),
-            "label": wt_train["label"].astype(float).tolist(),
-            "target_embedding": wt_train_tgt_emb,
-            "embedding": wt_train_emb,
-            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
-            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
-            "affinity_class": wt_train["affinity_class"].tolist(),
-        })
-        wt_val_ds = Dataset.from_dict({
-            "target_sequence": wt_val["target_sequence"].tolist(),
-            "sequence": wt_val["sequence"].tolist(),
-            "label": wt_val["label"].astype(float).tolist(),
-            "target_embedding": wt_val_tgt_emb,
-            "embedding": wt_val_emb,
-            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
-            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
-            "affinity_class": wt_val["affinity_class"].tolist(),
-        })
-        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
-        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
-        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
-        log(f"Saved WT pooled -> {wt_pooled_out}")
-    # =========================
-    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
-    # =========================
-    with section("SMILES pooled binder embeddings + save"):
-        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
-        smi_roformer = (
-            AutoModelForMaskedLM
-            .from_pretrained(SMI_MODEL_NAME)
-            .roformer
-            .to(DEVICE)
-            .eval()
-        )
-        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
-            smi_train["sequence"].astype(str).str.strip().tolist(),
-            smi_tok, smi_roformer,
-            batch_size=SMI_BATCH,
-            max_length=SMI_MAX_LEN,
        )
        )
-            "target_sequence": smi_train["target_sequence"].tolist(),
-            "sequence": smi_train["sequence"].tolist(),
-            "label": smi_train["label"].astype(float).tolist(),
-            "target_embedding": smi_train_tgt_emb,
-            "embedding": smi_train_pooled.astype(np.float32),
-            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
-            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
-            "affinity_class": smi_train["affinity_class"].tolist(),
-        })
-        smi_val_ds = Dataset.from_dict({
-            "target_sequence": smi_val["target_sequence"].tolist(),
-            "sequence": smi_val["sequence"].tolist(),
-            "label": smi_val["label"].astype(float).tolist(),
-            "target_embedding": smi_val_tgt_emb,
-            "embedding": smi_val_pooled.astype(np.float32),
-            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
-            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
-            "affinity_class": smi_val["affinity_class"].tolist(),
-        })
-        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
-        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
-        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
-        log(f"Saved SMILES pooled -> {smi_pooled_out}")
-    # =========================
-    # WT unpooled paired (ESM target + ESM binder) + save
-    # =========================
-    with section("WT unpooled paired embeddings + save"):
-        wt_tok_unpooled = wt_tok    # reuse tokenizer
-        wt_esm_unpooled = wt_esm    # reuse model
-        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
-        wt_unpooled_dd = DatasetDict({
-            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
-                                               wt_tok_unpooled, wt_esm_unpooled),
-            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
-                                             wt_tok_unpooled, wt_esm_unpooled),
-        })
-        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
-        log(f"Saved WT unpooled -> {wt_unpooled_out}")
-    # =========================
-    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
-    # =========================
-    with section("SMILES unpooled paired embeddings + save"):
-        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
-        smi_unpooled_dd = DatasetDict({
-            "train": build_smiles_unpooled_paired_dataset(
-                smi_train, smi_unpooled_out / "train",
-                wt_tok, wt_esm,
-                smi_tok, smi_roformer
-            ),
-            "val": build_smiles_unpooled_paired_dataset(
-                smi_val, smi_unpooled_out / "val",
-                wt_tok, wt_esm,
-                smi_tok, smi_roformer
-            ),
-        })
-        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
-        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
-    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
if __name__ == "__main__":
-    main()
| 1 |
import math
|
|
|
|
| 2 |
import sys
|
| 3 |
from contextlib import contextmanager
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
import torch
|
| 9 |
+
from datasets import Dataset, DatasetDict
|
| 10 |
from tqdm import tqdm
|
| 11 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# ======================
|
| 14 |
+
# CONFIG
|
| 15 |
+
# ======================
|
| 16 |
|
| 17 |
+
ROOT = Path("<>") # CHANGE HERE
|
| 18 |
+
PROJ_ROOT = ROOT / "PeptiVerse"
|
|
|
|
| 19 |
|
| 20 |
+
CSV_PATH = PROJ_ROOT / "training_data" / "c-binding.csv"
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
OUT_ROOT = PROJ_ROOT / "training_data_cleaned" / "binding_affinity"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
# ESM2 - target encoder (shared across all branches)
|
| 25 |
+
ESM_MODEL = "facebook/esm2_t33_650M_UR50D"
|
| 26 |
+
ESM_MAX_LEN = 1022
|
| 27 |
+
ESM_BATCH = 32
|
| 28 |
+
|
| 29 |
+
# PeptideCLM - SMILES binder encoder
|
| 30 |
+
sys.path.append(str(PROJ_ROOT))
|
| 31 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 32 |
+
|
| 33 |
+
PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
|
| 34 |
+
TOKENIZER_VOCAB = str(PROJ_ROOT / "tokenizer" / "new_vocab.txt")
|
| 35 |
+
TOKENIZER_SPLITS = str(PROJ_ROOT / "tokenizer" / "new_splits.txt")
|
| 36 |
+
PEPTIDECLM_MAX_LEN = 768
|
| 37 |
+
PEPTIDECLM_BATCH = 128
|
| 38 |
+
|
| 39 |
+
# ChemBERTa - SMILES binder encoder
|
| 40 |
+
CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
|
| 41 |
+
CHEMBERTA_MAX_LEN = 512
|
| 42 |
+
CHEMBERTA_BATCH = 128
|
| 43 |
+
|
| 44 |
+
# Which SMILES binder models to run
|
| 45 |
+
RUN_PEPTIDECLM = True
|
| 46 |
+
RUN_CHEMBERTA = True
|
| 47 |
+
|
| 48 |
+
# CSV column names
|
| 49 |
+
COL_SEQ1 = "seq1"
|
| 50 |
+
COL_SEQ2 = "seq2"
|
| 51 |
+
COL_AFF = "affinity"
|
| 52 |
+
COL_F2S = "Fasta2SMILES"
|
| 53 |
+
COL_REACT = "REACT_SMILES"
|
| 54 |
+
COL_MERGE = "Merge_SMILES"
|
| 55 |
+
COL_WT_IPTM = "wt_iptm_score"
|
| 56 |
COL_SMI_IPTM = "smiles_iptm_score"
|
| 57 |
|
| 58 |
+
# Split config
|
| 59 |
+
TRAIN_FRAC = 0.80
|
| 60 |
+
RANDOM_SEED = 1986
|
| 61 |
+
AFFINITY_Q_BINS = 30
|
| 62 |
+
|
| 63 |
+
# Logging
|
| 64 |
+
QUIET = True
|
| 65 |
+
USE_TQDM = False
|
| 66 |
+
|
| 67 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 68 |
|
| 69 |
|
| 70 |
+
# ======================
|
| 71 |
+
# Logging / progress
|
| 72 |
+
# ======================
|
| 73 |
|
| 74 |
def log(msg: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
if not QUIET:
|
| 76 |
print(msg)
|
| 77 |
|
|
|
|
| 85 |
log(f"=== done: {title} ===")
|
| 86 |
|
| 87 |
|
| 88 |
+
# ======================
|
| 89 |
+
# Data Handling
|
| 90 |
+
# ======================
|
| 91 |
+
|
| 92 |
def has_uaa(seq: str) -> bool:
|
| 93 |
return "X" in str(seq).upper()
|
| 94 |
|
| 95 |
+
def pick_smiles(row) -> str | None:
|
| 96 |
+
"""Column Priority: Fasta2SMILES > REACT_SMILES > Merge_SMILES."""
|
| 97 |
+
for col in [COL_F2S, COL_REACT, COL_MERGE]:
|
| 98 |
+
val = row.get(col, None)
|
| 99 |
+
if val is not None and str(val).strip() not in ("", "nan", "None"):
|
| 100 |
+
return str(val).strip()
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
def affinity_to_class(a: float) -> str:
|
|
|
|
| 104 |
if a >= 9.0:
|
| 105 |
return "High"
|
| 106 |
elif a >= 7.0:
|
|
|
|
| 110 |
|
| 111 |
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
|
| 112 |
df = df.copy()
|
|
|
|
| 113 |
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 114 |
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
|
|
|
| 115 |
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
|
| 116 |
|
| 117 |
try:
|
|
|
|
| 122 |
strat_col = "aff_bin"
|
| 123 |
|
| 124 |
rng = np.random.RandomState(RANDOM_SEED)
|
|
|
|
| 125 |
df["split"] = None
|
| 126 |
for _, g in df.groupby(strat_col, observed=True):
|
| 127 |
idx = g.index.to_numpy()
|
| 128 |
rng.shuffle(idx)
|
| 129 |
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
|
| 130 |
df.loc[idx[:n_train], "split"] = "train"
|
| 131 |
+
df.loc[idx[n_train:], "split"] = "val"
|
|
|
|
| 132 |
df["split"] = df["split"].fillna("train")
|
| 133 |
return df
|
| 134 |
|
| 135 |
+
def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
|
| 136 |
+
out = df_in.copy()
|
| 137 |
+
out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
|
| 138 |
+
out["sequence"] = out[binder_seq_col].astype(str).str.strip()
|
| 139 |
+
out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 140 |
+
out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
|
| 141 |
+
out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 142 |
+
out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
|
| 143 |
+
return out[["target_sequence", "sequence", "label", "split",
|
| 144 |
+
iptm_col, COL_AFF, "affinity_class"]]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ======================
|
| 148 |
+
# Dataset builders
|
| 149 |
+
# ======================
|
| 150 |
+
|
| 151 |
+
def build_pooled_ds(view: pd.DataFrame, iptm_col: str,
|
| 152 |
+
tgt_embs: np.ndarray, bnd_embs: np.ndarray) -> Dataset:
|
| 153 |
+
"""Both target and binder are (N, H) pooled float32 arrays."""
|
| 154 |
+
return Dataset.from_dict({
|
| 155 |
+
"target_sequence": view["target_sequence"].tolist(),
|
| 156 |
+
"sequence": view["sequence"].tolist(),
|
| 157 |
+
"label": view["label"].astype(float).tolist(),
|
| 158 |
+
"target_embedding": tgt_embs, # (N, H_esm) float32
|
| 159 |
+
"binder_embedding": bnd_embs, # (N, H_binder) float32
|
| 160 |
+
"affinity": view[COL_AFF].astype(float).tolist(),
|
| 161 |
+
"affinity_class": view["affinity_class"].tolist(),
|
| 162 |
+
iptm_col: view[iptm_col].astype(float).tolist(),
|
| 163 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
def build_unpooled_ds(view: pd.DataFrame, iptm_col: str,
|
| 167 |
+
tgt_tok_embs, tgt_masks, tgt_lengths,
|
| 168 |
+
bnd_tok_embs, bnd_masks, bnd_lengths) -> Dataset:
|
| 169 |
+
"""
|
| 170 |
+
Per-token lists for both sides.
|
| 171 |
+
target_embedding[i] : (Lt_i, H_esm) float16 ndarray
|
| 172 |
+
binder_embedding[i] : (Lb_i, H_binder) float16 ndarray
|
| 173 |
+
"""
|
| 174 |
+
return Dataset.from_dict({
|
| 175 |
+
"target_sequence": view["target_sequence"].tolist(),
|
| 176 |
+
"sequence": view["sequence"].tolist(),
|
| 177 |
+
"label": view["label"].astype(float).tolist(),
|
| 178 |
+
|
| 179 |
+
"target_embedding": tgt_tok_embs,
|
| 180 |
+
"target_attention_mask": tgt_masks,
|
| 181 |
+
"target_length": tgt_lengths,
|
| 182 |
+
|
| 183 |
+
"binder_embedding": bnd_tok_embs,
|
| 184 |
+
"binder_attention_mask": bnd_masks,
|
| 185 |
+
"binder_length": bnd_lengths,
|
| 186 |
+
|
| 187 |
+
"affinity": view[COL_AFF].astype(float).tolist(),
|
| 188 |
+
"affinity_class": view["affinity_class"].tolist(),
|
| 189 |
+
iptm_col: view[iptm_col].astype(float).tolist(),
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ======================
|
| 194 |
+
# ESM2 - shared target encoder
|
| 195 |
+
# ======================
|
| 196 |
|
| 197 |
+
def load_esm():
|
| 198 |
+
print(f" Loading ESM2: {ESM_MODEL}")
|
| 199 |
+
tok = AutoTokenizer.from_pretrained(ESM_MODEL)
|
| 200 |
+
model = EsmModel.from_pretrained(ESM_MODEL).to(DEVICE).eval()
|
| 201 |
+
return tok, model
|
| 202 |
|
| 203 |
|
|
|
|
|
|
|
|
|
|
| 204 |
@torch.no_grad()
|
| 205 |
+
def embed_esm_pooled(seqs, tok, model) -> np.ndarray:
|
| 206 |
+
"""Returns (N, H) float32 - mean-pooled over non-pad tokens."""
|
| 207 |
+
all_embs = []
|
| 208 |
+
for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 pooled"):
|
| 209 |
+
batch = seqs[i:i + ESM_BATCH]
|
| 210 |
+
enc = tok(batch, return_tensors="pt", padding=True,
|
| 211 |
+
truncation=True, max_length=ESM_MAX_LEN)
|
| 212 |
+
ids = enc["input_ids"].to(DEVICE)
|
| 213 |
+
mask = enc["attention_mask"].to(DEVICE)
|
| 214 |
+
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
|
| 215 |
+
attn_f = mask.unsqueeze(-1).float()
|
| 216 |
+
pooled = ((h * attn_f).sum(dim=1) /
|
| 217 |
+
attn_f.sum(dim=1).clamp(min=1e-9)).cpu().numpy().astype(np.float32)
|
| 218 |
+
all_embs.append(pooled)
|
| 219 |
+
return np.vstack(all_embs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
+
|
| 222 |
+
@torch.no_grad()
|
| 223 |
+
def embed_esm_unpooled(seqs, tok, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
"""
|
| 225 |
+
Returns per-token lists (CLS/EOS/pad excluded).
|
| 226 |
+
tok_embs : list of (Lt_i, H) float16 arrays
|
| 227 |
+
masks : list of (Lt_i,) int8 arrays (all-ones)
|
| 228 |
+
lengths : list of int
|
|
|
|
|
|
|
|
|
|
| 229 |
"""
|
| 230 |
+
cls_id = tok.cls_token_id
|
| 231 |
+
eos_id = tok.eos_token_id
|
| 232 |
+
|
| 233 |
+
tok_embs, masks, lengths = [], [], []
|
| 234 |
+
for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 unpooled"):
|
| 235 |
+
batch = seqs[i:i + ESM_BATCH]
|
| 236 |
+
enc = tok(batch, return_tensors="pt", padding=True,
|
| 237 |
+
truncation=True, max_length=ESM_MAX_LEN)
|
| 238 |
+
ids = enc["input_ids"].to(DEVICE)
|
| 239 |
+
mask = enc["attention_mask"].to(DEVICE)
|
| 240 |
+
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
|
| 241 |
+
|
| 242 |
+
for b in range(h.shape[0]):
|
| 243 |
+
keep = mask[b].bool()
|
| 244 |
+
if cls_id is not None:
|
| 245 |
+
keep = keep & (ids[b] != cls_id)
|
| 246 |
+
if eos_id is not None:
|
| 247 |
+
keep = keep & (ids[b] != eos_id)
|
| 248 |
+
emb = h[b, keep].cpu().to(torch.float16).numpy()
|
| 249 |
+
tok_embs.append(emb)
|
| 250 |
+
masks.append(np.ones(emb.shape[0], dtype=np.int8))
|
| 251 |
+
lengths.append(emb.shape[0])
|
| 252 |
+
return tok_embs, masks, lengths
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# ======================
|
| 256 |
+
# Generic binder embedding helpers
|
| 257 |
+
# ======================
|
| 258 |
+
|
| 259 |
+
def _get_special_ids_t(tokenizer):
|
| 260 |
+
special_ids = sorted({
|
| 261 |
+
x for x in [
|
| 262 |
+
getattr(tokenizer, attr, None)
|
| 263 |
+
for attr in ("pad_token_id", "cls_token_id", "sep_token_id",
|
| 264 |
+
"bos_token_id", "eos_token_id", "mask_token_id")
|
| 265 |
+
] if x is not None
|
| 266 |
})
|
| 267 |
+
return (torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
|
| 268 |
+
if special_ids else None)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
+
def _pool_and_unpool(last_hidden, input_ids, attention_mask, special_ids_t):
|
| 272 |
+
"""Mean-pool over non-special valid tokens; also return per-token arrays."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
valid = attention_mask.bool()
|
| 274 |
+
if special_ids_t is not None:
|
| 275 |
+
valid = valid & (~torch.isin(input_ids, special_ids_t))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
valid_f = valid.unsqueeze(-1).float()
|
| 278 |
+
pooled = (
|
| 279 |
+
torch.sum(last_hidden * valid_f, dim=1) /
|
| 280 |
+
torch.clamp(valid_f.sum(dim=1), min=1e-9)
|
| 281 |
+
).cpu().numpy().astype(np.float32)
|
| 282 |
|
| 283 |
+
tok_embs, masks, lengths = [], [], []
|
| 284 |
for b in range(last_hidden.shape[0]):
|
| 285 |
+
emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
|
| 286 |
+
tok_embs.append(emb)
|
| 287 |
+
masks.append(np.ones(emb.shape[0], dtype=np.int8))
|
| 288 |
+
lengths.append(emb.shape[0])
|
| 289 |
+
return pooled, tok_embs, masks, lengths
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ======================
|
| 293 |
+
# PeptideCLM - SMILES binder encoder
|
| 294 |
+
# ======================
|
| 295 |
+
|
| 296 |
+
def load_peptideclm():
|
| 297 |
+
print(f" Loading PeptideCLM: {PEPTIDECLM_MODEL}")
|
| 298 |
+
tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
|
| 299 |
+
model = (AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
|
| 300 |
+
.roformer.to(DEVICE).eval())
|
| 301 |
+
return tok, model, _get_special_ids_t(tok)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@torch.no_grad()
|
| 305 |
+
def embed_peptideclm(seqs, tok, model, sid_t):
|
| 306 |
+
pooled_all, tok_all, mask_all, len_all = [], [], [], []
|
| 307 |
+
for i in pbar(range(0, len(seqs), PEPTIDECLM_BATCH), desc=" PeptideCLM binder"):
|
| 308 |
+
batch = seqs[i:i + PEPTIDECLM_BATCH]
|
| 309 |
+
enc = tok(batch, return_tensors="pt", padding=True,
|
| 310 |
+
truncation=True, max_length=PEPTIDECLM_MAX_LEN)
|
| 311 |
+
ids = enc["input_ids"].to(DEVICE)
|
| 312 |
+
mask = enc["attention_mask"].to(DEVICE)
|
| 313 |
+
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
|
| 314 |
+
p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
|
| 315 |
+
pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
|
| 316 |
+
return np.vstack(pooled_all), tok_all, mask_all, len_all
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ======================
|
| 320 |
+
# ChemBERTa - SMILES binder encoder
|
| 321 |
+
# ======================
|
| 322 |
+
|
| 323 |
+
def load_chemberta():
|
| 324 |
+
print(f" Loading ChemBERTa: {CHEMBERTA_MODEL}")
|
| 325 |
+
tok = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
|
| 326 |
+
model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(DEVICE).eval()
|
| 327 |
+
return tok, model, _get_special_ids_t(tok)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
@torch.no_grad()
|
| 331 |
+
def embed_chemberta(seqs, tok, model, sid_t):
|
| 332 |
+
pooled_all, tok_all, mask_all, len_all = [], [], [], []
|
| 333 |
+
for i in pbar(range(0, len(seqs), CHEMBERTA_BATCH), desc=" ChemBERTa binder"):
|
| 334 |
+
batch = seqs[i:i + CHEMBERTA_BATCH]
|
| 335 |
+
enc = tok(batch, return_tensors="pt", padding=True,
|
| 336 |
+
truncation=True, max_length=CHEMBERTA_MAX_LEN)
|
| 337 |
+
ids = enc["input_ids"].to(DEVICE)
|
| 338 |
+
mask = enc["attention_mask"].to(DEVICE)
|
| 339 |
+
h = model(input_ids=ids, attention_mask=mask).last_hidden_state
|
| 340 |
+
p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
|
| 341 |
+
pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
|
| 342 |
+
return np.vstack(pooled_all), tok_all, mask_all, len_all
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# ======================
|
| 346 |
+
# WT branch (ESM2 × ESM2)
|
| 347 |
+
# ======================
|
| 348 |
+
|
| 349 |
+
def run_wt_branch(wt_train: pd.DataFrame, wt_val: pd.DataFrame,
|
| 350 |
+
esm_tok, esm_model):
|
| 351 |
+
print("\n" + "="*55)
|
| 352 |
+
print(" Branch : WT (ESM2 target × ESM2 binder)")
|
| 353 |
+
print("="*55)
|
| 354 |
+
|
| 355 |
+
pooled_splits, unpooled_splits = {}, {}
|
| 356 |
+
|
| 357 |
+
for split_name, view in [("train", wt_train), ("val", wt_val)]:
|
| 358 |
+
print(f"\n [{split_name}] {len(view)} rows")
|
| 359 |
+
targets = view["target_sequence"].tolist()
|
| 360 |
+
binders = view["sequence"].tolist()
|
| 361 |
+
|
| 362 |
+
tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)
|
| 363 |
+
bnd_pooled = embed_esm_pooled(binders, esm_tok, esm_model)
|
| 364 |
+
|
| 365 |
+
tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(targets, esm_tok, esm_model)
|
| 366 |
+
bnd_tok_embs, bnd_masks, bnd_lengths = embed_esm_unpooled(binders, esm_tok, esm_model)
|
| 367 |
+
|
| 368 |
+
pooled_splits[split_name] = build_pooled_ds(
|
| 369 |
+
view, COL_WT_IPTM, tgt_pooled, bnd_pooled)
|
| 370 |
+
unpooled_splits[split_name] = build_unpooled_ds(
|
| 371 |
+
view, COL_WT_IPTM,
|
| 372 |
+
tgt_tok_embs, tgt_masks, tgt_lengths,
|
| 373 |
+
bnd_tok_embs, bnd_masks, bnd_lengths)
|
| 374 |
+
|
| 375 |
+
pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
|
| 376 |
+
unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
|
| 377 |
+
DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
|
| 378 |
+
DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
|
| 379 |
+
print(f"\n WT pooled to {pooled_out}")
|
| 380 |
+
print(f" WT unpooled to {unpooled_out}")


# ======================
# SMILES branch (ESM2 × {PeptideCLM | ChemBERTa})
# ======================

def run_smiles_binder_model(name: str,
                            smi_train: pd.DataFrame, smi_val: pd.DataFrame,
                            esm_tok, esm_model,
                            load_fn, embed_fn):
    print("\n" + "="*55)
    print(f" Branch : SMILES (ESM2 target × {name} binder)")
    print("="*55)

    binder_tok, binder_model, sid_t = load_fn()
    pooled_splits, unpooled_splits = {}, {}

    for split_name, view in [("train", smi_train), ("val", smi_val)]:
        print(f"\n [{split_name}] {len(view)} rows")
        targets = view["target_sequence"].tolist()
        binders = view["sequence"].tolist()

        print(" ESM2 target - pooled ...")
        tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)

        print(" ESM2 target - unpooled ...")
        tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(
            targets, esm_tok, esm_model)

        print(f" {name} binder - pooled + unpooled ...")
        bnd_pooled, bnd_tok_embs, bnd_masks, bnd_lengths = embed_fn(
            binders, binder_tok, binder_model, sid_t)

        pooled_splits[split_name] = build_pooled_ds(
            view, COL_SMI_IPTM, tgt_pooled, bnd_pooled)
        unpooled_splits[split_name] = build_unpooled_ds(
            view, COL_SMI_IPTM,
            tgt_tok_embs, tgt_masks, tgt_lengths,
            bnd_tok_embs, bnd_masks, bnd_lengths)

    suffix = "" if name.lower() == "peptideclm" else f"_{name.lower()}"
    pooled_out = OUT_ROOT / f"pair_wt_smiles_pooled{suffix}"
    unpooled_out = OUT_ROOT / f"pair_wt_smiles_unpooled{suffix}"
    DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
    DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
    print(f"\n {name} pooled to {pooled_out}")
    print(f" {name} unpooled to {unpooled_out}")

    del binder_model
    torch.cuda.empty_cache()
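load_fn is expected to return a (tokenizer, model, special-token-id tensor) triple and embed_fn to return pooled plus unpooled binder embeddings, as used above. A hedged sketch of what a ChemBERTa load_fn could look like, modeled on run_chemberta and _get_special_ids_tensor from embed_smiles.py later in this commit; the committed load_chemberta may differ in details.

import torch
from transformers import AutoTokenizer, AutoModel

def load_chemberta_sketch(model_name="DeepChem/ChemBERTa-77M-MLM", device="cuda"):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    # Collect special-token ids so pooling can exclude CLS/SEP/PAD positions later.
    attrs = ["pad_token_id", "cls_token_id", "sep_token_id",
             "bos_token_id", "eos_token_id", "mask_token_id"]
    ids = sorted({getattr(tok, a, None) for a in attrs} - {None})
    sid_t = torch.tensor(ids, device=device, dtype=torch.long) if ids else None
    return tok, model, sid_t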


# ======================
# Main
# ======================

def main():
    print(f"Device : {DEVICE}")
    print(f"CSV : {CSV_PATH}")
    print(f"Out : {OUT_ROOT}\n")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # 1. Load + dedup
    # ------------------------------------------------------------------
    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        print(f"Raw rows: {len(df)}")
        df["orig_idx"] = df.index  # traceability only

        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        for col in [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]:
            if col not in df.columns:
                raise ValueError(f"Missing required column: '{col}'")

        dedup_cols = [c for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]
                      if c in df.columns]
        before = len(df)
        df = df.drop_duplicates(subset=dedup_cols, keep="first").reset_index(drop=True)
        print(f"After dedup pass 1 (raw columns) : {len(df)} (-{before - len(df)})")

        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    # ------------------------------------------------------------------
    # 2. Prepare per-branch subsets
    # ------------------------------------------------------------------
    with section("prepare WT / SMILES subsets"):
        # ── WT branch ──────────────────────────────────────────────────
        # Both seq1 and seq2 must be canonical (no X) for ESM2
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF])
        df_wt = df_wt[~df_wt[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
        df_wt = df_wt[df_wt["wt_sequence"] != ""]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)]
        df_wt = df_wt.reset_index(drop=True)

        # ── SMILES branch ──────────────────────────────────────────────
        # seq1 must be canonical (no X) for ESM2; binder SMILES picked
        # by priority (Fasta2SMILES > REACT_SMILES > Merge_SMILES), then
        # dedup pass 2 on (seq1, picked smiles_sequence)
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF])
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ]
        df_smi = df_smi[~df_smi[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
        df_smi = df_smi.reset_index(drop=True)

        df_smi["smiles_sequence"] = df_smi.apply(pick_smiles, axis=1)
        df_smi = df_smi[df_smi["smiles_sequence"].notna()].reset_index(drop=True)
        print(f"After requiring ≥1 valid SMILES : {len(df_smi)}")

        # Dedup pass 2: (seq1, picked smiles_sequence)
        before = len(df_smi)
        df_smi = df_smi.drop_duplicates(
            subset=[COL_SEQ1, "smiles_sequence"], keep="first"
        ).reset_index(drop=True)
        print(f"After dedup pass 2 (seq1, smiles_sequence): {len(df_smi)} (-{before - len(df_smi)})")

        assert not df_smi.duplicated(subset=[COL_SEQ1, "smiles_sequence"]).any(), \
            "BUG: duplicate (seq1, smiles_sequence) pairs remain!"

        print(f"\n[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)}")

    # ------------------------------------------------------------------
    # 3. Split
    # ------------------------------------------------------------------
    with section("split WT and SMILES separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        df_wt2.to_csv(OUT_ROOT / "binding_affinity_wt_meta_with_split.csv", index=False)
        df_smi2.to_csv(OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv", index=False)

    # ------------------------------------------------------------------
    # 4. Build split views
    # ------------------------------------------------------------------
    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    print(f"\nSplit sizes - WT: train={len(wt_train)} val={len(wt_val)}")
    print(f"Split sizes - SMILES: train={len(smi_train)} val={len(smi_val)}")

    # ------------------------------------------------------------------
    # 5. Load ESM2 once - shared across all branches
    # ------------------------------------------------------------------
    print("\nLoading ESM2 (shared target encoder) ...")
    esm_tok, esm_model = load_esm()

    # ------------------------------------------------------------------
    # 6. WT branch
    # ------------------------------------------------------------------
    run_wt_branch(wt_train, wt_val, esm_tok, esm_model)

    # ------------------------------------------------------------------
    # 7. SMILES branches
    # ------------------------------------------------------------------
    if RUN_PEPTIDECLM:
        run_smiles_binder_model(
            "peptideclm", smi_train, smi_val,
            esm_tok, esm_model,
            load_fn=load_peptideclm,
            embed_fn=embed_peptideclm,
        )

    if RUN_CHEMBERTA:
        run_smiles_binder_model(
            "chemberta", smi_train, smi_val,
            esm_tok, esm_model,
            load_fn=load_chemberta,
            embed_fn=embed_chemberta,
        )

    print(f"\n All done. Datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
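Because the WT and SMILES subsets are split independently with make_distribution_matched_split, it is worth sanity-checking the written meta CSVs before embedding. A small hedged sketch of such a check; the label column name below is a placeholder for whatever COL_AFF resolves to in the committed script, and the CSV is read from wherever OUT_ROOT points.

import pandas as pd

meta = pd.read_csv("binding_affinity_wt_meta_with_split.csv")  # as written by main() above
LABEL = "binding_affinity"  # assumption: replace with the actual affinity column
for split in ["train", "val"]:
    vals = pd.to_numeric(meta.loc[meta["split"] == split, LABEL], errors="coerce").dropna()
    print(split, len(vals), vals.quantile([0.1, 0.5, 0.9]).round(3).to_dict())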
training_data_cleaned/embed_smiles.py
ADDED
@@ -0,0 +1,319 @@
"""
|
| 2 |
+
Pipeline:
|
| 3 |
+
1. Read *_meta_with_split.csv (sequence, label, id, split)
|
| 4 |
+
2. Convert wt sequences to SMILES via: fasta2smi -i peptides.fasta -o peptides.p2smi
|
| 5 |
+
3. Parse .p2smi format: "{seq}-linear: {SMILES}"
|
| 6 |
+
4. Embed SMILES with ChemBERTa to save pooled + unpooled DatasetDicts
|
| 7 |
+
5. Embed SMILES with PeptideCLM to save pooled + unpooled DatasetDicts
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import subprocess
|
| 12 |
+
import tempfile
|
| 13 |
+
import sys
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from datasets import Dataset, DatasetDict
|
| 19 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
|
| 20 |
+
|
| 21 |
+
PROJECT_ROOT = "<>" # change here
|
| 22 |
+
|
| 23 |
+
# using permeability as example
|
| 24 |
+
META_CSV = (
|
| 25 |
+
f"{PROJECT_ROOT}/training_data_cleaned/"
|
| 26 |
+
"permeability_penetrance/permeability_meta_with_split.csv"
|
| 27 |
+
)
|
| 28 |
+
BASE_OUT = f"{PROJECT_ROOT}/alternative_embeddings"
|
| 29 |
+
|
| 30 |
+
# ChemBERTa
|
| 31 |
+
CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
|
| 32 |
+
CHEMBERTA_OUT = f"{BASE_OUT}/permeability_chemberta/perm_smiles_with_embeddings"
|
| 33 |
+
|
| 34 |
+
# PeptideCLM
|
| 35 |
+
sys.path.append(PROJECT_ROOT)
|
| 36 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 37 |
+
|
| 38 |
+
PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
|
| 39 |
+
PEPTIDECLM_TOKENIZER = f"{PROJECT_ROOT}/tokenizer/new_vocab.txt"
|
| 40 |
+
PEPTIDECLM_SPLITS = f"{PROJECT_ROOT}/tokenizer/new_splits.txt"
|
| 41 |
+
PEPTIDECLM_OUT = f"{BASE_OUT}/permeability_peptideclm/perm_smiles_with_embeddings"
|
| 42 |
+
|
| 43 |
+
# Column names in the CSV
|
| 44 |
+
SEQ_COL = "sequence"
|
| 45 |
+
LABEL_COL = "label"
|
| 46 |
+
SPLIT_COL = "split"
|
| 47 |
+
ID_COL = "id" # used as FASTA header; must be unique
|
| 48 |
+
|
| 49 |
+
# fasta2smi settings
|
| 50 |
+
FASTA2SMI_BIN = "fasta2smi" # install via github
|
| 51 |
+
|
| 52 |
+
# Embedding settings
|
| 53 |
+
MAX_LENGTH_CHEMBERTA = 512
|
| 54 |
+
MAX_LENGTH_PEPTIDECLM = 768
|
| 55 |
+
BATCH_SIZE = 128
|
| 56 |
+
|
| 57 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ===========================================================================================
|
| 61 |
+
# Step 1 — fasta2smi conversion, do not apply to properties that only have SMILES sequences
|
| 62 |
+
# ===========================================================================================
|
| 63 |
+
def sequences_to_smiles(sequences: list[str], ids: list[str]) -> dict[str, str]:
|
| 64 |
+
"""
|
| 65 |
+
.p2smi format produced by fasta2smi:
|
| 66 |
+
MIIFAIAASHKK-linear: N[C@@H](CCSC)C(=O)...
|
| 67 |
+
KIAKLKAKIQ...-linear: N[C@@H](CCCCN)C(=O)...
|
| 68 |
+
"""
|
| 69 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 70 |
+
fasta_path = os.path.join(tmpdir, "peptides.fasta")
|
| 71 |
+
p2smi_path = os.path.join(tmpdir, "peptides.p2smi")
|
| 72 |
+
|
| 73 |
+
with open(fasta_path, "w") as fh:
|
| 74 |
+
for sid, seq in zip(ids, sequences):
|
| 75 |
+
fh.write(f">{sid}\n{seq}\n")
|
| 76 |
+
|
| 77 |
+
cmd = [FASTA2SMI_BIN, "-i", fasta_path, "-o", p2smi_path]
|
| 78 |
+
print(f" Running: {' '.join(cmd)}")
|
| 79 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 80 |
+
if result.returncode != 0:
|
| 81 |
+
raise RuntimeError(
|
| 82 |
+
f"fasta2smi failed (exit {result.returncode}):\n"
|
| 83 |
+
f" stdout: {result.stdout}\n stderr: {result.stderr}"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
seq2smi = _parse_p2smi(p2smi_path)
|
| 87 |
+
|
| 88 |
+
n_ok = len(seq2smi)
|
| 89 |
+
n_fail = len(sequences) - n_ok
|
| 90 |
+
print(f" fasta2smi: {n_ok}/{len(sequences)} converted ({n_fail} failed/skipped)")
|
| 91 |
+
return seq2smi
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _parse_p2smi(path: str) -> dict[str, str]:
|
| 95 |
+
seq2smi: dict[str, str] = {}
|
| 96 |
+
with open(path) as fh:
|
| 97 |
+
for line in fh:
|
| 98 |
+
line = line.strip()
|
| 99 |
+
if not line or line.startswith("#"):
|
| 100 |
+
continue
|
| 101 |
+
# Split on "-linear: " — the separator fasta2smi uses
|
| 102 |
+
if "-linear: " not in line:
|
| 103 |
+
print(f" [WARN] Unexpected p2smi line, skipping: {line[:80]}")
|
| 104 |
+
continue
|
| 105 |
+
aa_seq, smi = line.split("-linear: ", maxsplit=1)
|
| 106 |
+
smi = smi.strip()
|
| 107 |
+
if smi and smi.lower() not in ("none", "null", "n/a"):
|
| 108 |
+
seq2smi[aa_seq] = smi
|
| 109 |
+
return seq2smi
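A quick usage example for the conversion helpers above; it assumes the fasta2smi CLI is installed and on PATH, and the two peptide sequences are made-up placeholders rather than data from this repository.

peptides = ["ACDEFGHIK", "KLMNPQRST"]
ids = ["pep1", "pep2"]
seq2smi = sequences_to_smiles(peptides, ids)
for seq, smi in seq2smi.items():
    print(f"{seq} -> {smi[:60]}...")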


# ============================================================
# Setups
# ============================================================
def _get_special_ids_tensor(tokenizer):
    attrs = [
        "pad_token_id", "cls_token_id", "sep_token_id",
        "bos_token_id", "eos_token_id", "mask_token_id",
    ]
    ids = sorted({getattr(tokenizer, a, None) for a in attrs} - {None})
    return torch.tensor(ids, device=device, dtype=torch.long) if ids else None


@torch.no_grad()
def _embed_batch(tokenizer, model, special_ids_t, sequences, max_length):
    tok = tokenizer(
        sequences, return_tensors="pt",
        padding=True, max_length=max_length, truncation=True,
    )
    input_ids = tok["input_ids"].to(device)
    attention_mask = tok["attention_mask"].to(device)

    out = model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = out.last_hidden_state  # (B, L, H)

    valid = attention_mask.bool()
    if special_ids_t is not None:
        valid = valid & (~torch.isin(input_ids, special_ids_t))

    valid_f = valid.unsqueeze(-1).float()
    pooled = (
        torch.sum(last_hidden * valid_f, dim=1)
        / torch.clamp(valid_f.sum(dim=1), min=1e-9)
    ).cpu().numpy()  # (B, H) float32

    token_embs, masks, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
        token_embs.append(emb)
        masks.append(np.ones(emb.shape[0], dtype=np.int8))
        lengths.append(emb.shape[0])

    return pooled, token_embs, masks, lengths


def _embed_all(tokenizer, model, special_ids_t, sequences, max_length):
    pooled_all, token_all, mask_all, len_all = [], [], [], []
    for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc=" batches"):
        p, t, m, l = _embed_batch(
            tokenizer, model, special_ids_t,
            sequences[i:i+BATCH_SIZE], max_length,
        )
        pooled_all.append(p)
        token_all.extend(t)
        mask_all.extend(m)
        len_all.extend(l)
    return np.vstack(pooled_all), token_all, mask_all, len_all


def _build_datasets(wt_seqs, smiles, labels, tokenizer, model, special_ids_t, max_length):
    pooled, tok_embs, masks, lengths = _embed_all(
        tokenizer, model, special_ids_t, smiles, max_length
    )
    pooled_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": pooled,
    })
    full_ds = Dataset.from_dict({
        "sequence": wt_seqs,
        "smiles": smiles,
        "label": labels,
        "embedding": tok_embs,
        "attention_mask": masks,
        "length": lengths,
    })
    return pooled_ds, full_ds


def _save(splits: dict, out_path: str):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    DatasetDict({k: v[0] for k, v in splits.items()}).save_to_disk(out_path)
    DatasetDict({k: v[1] for k, v in splits.items()}).save_to_disk(out_path + "_unpooled")
    print(f" Saved pooled to {out_path}")
    print(f" Saved unpooled to {out_path}_unpooled")


# ============================================================
# ChemBERTa
# ============================================================
def run_chemberta(meta: pd.DataFrame):
    print(f"\n{'='*60}")
    print(" Encoder: ChemBERTa")
    print(f"{'='*60}")

    print(f" Loading {CHEMBERTA_MODEL} ...")
    tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
    model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(device).eval()
    special_ids_t = _get_special_ids_tensor(tokenizer)

    splits: dict[str, tuple] = {}
    for split_name in ["train", "val"]:
        df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
        print(f"\n [{split_name}] {len(df)} rows")
        if df.empty:
            print(" [WARN] Empty split, skipping.")
            continue
        pooled_ds, full_ds = _build_datasets(
            df[SEQ_COL].tolist(), df["smiles"].tolist(),
            df[LABEL_COL].tolist(),
            tokenizer, model, special_ids_t, MAX_LENGTH_CHEMBERTA,
        )
        splits[split_name] = (pooled_ds, full_ds)

    _save(splits, CHEMBERTA_OUT)

    # free GPU memory before loading next model
    del model
    torch.cuda.empty_cache()


# ============================================================
# PeptideCLM
# ============================================================
def run_peptideclm(meta: pd.DataFrame):
    print(f"\n{'='*60}")
    print(" Encoder: PeptideCLM")
    print(f"{'='*60}")

    print(f" Loading tokenizer from {PEPTIDECLM_TOKENIZER} ...")
    tokenizer = SMILES_SPE_Tokenizer(PEPTIDECLM_TOKENIZER, PEPTIDECLM_SPLITS)

    print(f" Loading {PEPTIDECLM_MODEL} ...")
    full_model = AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
    model = full_model.roformer.to(device).eval()
    special_ids_t = _get_special_ids_tensor(tokenizer)

    splits: dict[str, tuple] = {}
    for split_name in ["train", "val"]:
        df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
        print(f"\n [{split_name}] {len(df)} rows")
        if df.empty:
            print(" [WARN] Empty split, skipping.")
            continue
        pooled_ds, full_ds = _build_datasets(
            df[SEQ_COL].tolist(), df["smiles"].tolist(),
            df[LABEL_COL].tolist(),
            tokenizer, model, special_ids_t, MAX_LENGTH_PEPTIDECLM,
        )
        splits[split_name] = (pooled_ds, full_ds)

    _save(splits, PEPTIDECLM_OUT)

    del model
    torch.cuda.empty_cache()


# ============================================================
# Main
# ============================================================
def main():
    print(f"\nDevice : {device}")
    print(f"Meta : {META_CSV}")

    # Load metadata
    meta = pd.read_csv(META_CSV, sep=None, engine="python")
    print(f"Loaded {len(meta)} rows. Columns: {meta.columns.tolist()}")
    for col in [SEQ_COL, LABEL_COL, SPLIT_COL]:
        if col not in meta.columns:
            raise ValueError(f"Expected column '{col}' not found. Available: {meta.columns.tolist()}")

    # Ensure numeric labels
    meta[LABEL_COL] = pd.to_numeric(meta[LABEL_COL], errors="coerce")
    meta = meta.dropna(subset=[SEQ_COL, LABEL_COL]).reset_index(drop=True)

    # Build id list for FASTA headers
    if ID_COL in meta.columns:
        ids = meta[ID_COL].astype(str).tolist()
    else:
        ids = [f"seq_{i}" for i in range(len(meta))]

    # Note: for properties that already start from SMILES sequences, fasta2smi is not needed
    # Convert wt to SMILES (single fasta2smi call for the whole dataset)
    print("\nConverting peptide sequences to SMILES ...")
    seqs = meta[SEQ_COL].astype(str).tolist()
    seq2smi = sequences_to_smiles(seqs, ids)

    meta["smiles"] = meta[SEQ_COL].astype(str).map(seq2smi)
    n_missing = meta["smiles"].isna().sum()
    if n_missing:
        print(f" [WARN] {n_missing} sequences had no SMILES — dropping.")
    meta = meta.dropna(subset=["smiles"]).reset_index(drop=True)
    print(f" Retained {len(meta)} rows with valid SMILES.")

    # Save SMILES-enriched meta CSV
    smiles_meta_path = os.path.join(BASE_OUT, "permeability_smiles_meta_with_split.csv")
    os.makedirs(BASE_OUT, exist_ok=True)
    meta.to_csv(smiles_meta_path, index=False)
    print(f" Saved SMILES meta to {smiles_meta_path}")

    # Run both encoders sequentially (share the same converted SMILES)
    #run_chemberta(meta)
    #run_peptideclm(meta)

    print("\nAll done.")


if __name__ == "__main__":
    main()
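Note that, as committed, the two encoder calls at the bottom of main() are commented out, so running embed_smiles.py only performs the fasta2smi conversion and writes the SMILES-enriched meta CSV. To actually produce the ChemBERTa and PeptideCLM DatasetDicts, re-enable them:

    # inside main(), after the SMILES meta CSV is saved
    run_chemberta(meta)
    run_peptideclm(meta)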
training_data_cleaned/permeability_penetrance/permeability_smiles_meta_with_split.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cbece0b3b8345cae1ce6fe2e9a1a10ddd5320bae18c3a7a3f958b97b98979796
size 947525
training_data_cleaned/smiles_data_split.py
CHANGED
@@ -15,6 +15,7 @@ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
 
 seed_everything(1986)
 
+# Starting with a raw dataframe, using caco2 as example.
 df = pd.read_csv("caco2.csv")
 
 mols = []
@@ -87,151 +88,4 @@ df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
 df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
 df.to_csv("caco2_meta_with_split.csv", index=False)
 
-print(df["split"].value_counts())
-
-# ======================
-# Config
-# ======================
-MAX_LENGTH = 768
-BATCH_SIZE = 128
-
-TRAIN_CSV = "caco2_train.csv"
-VAL_CSV = "caco2_val.csv"
-
-SMILES_COL = "SMILES"
-LABEL_COL = "Caco2"
-
-OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"
-
-# GPU device
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
-# ======================
-# Load tokenizer + model
-# ======================
-print("Loading tokenizer and model...")
-tokenizer = SMILES_SPE_Tokenizer(
-    "./Classifier_Weight/tokenizer/new_vocab.txt",
-    "./Classifier_Weight/tokenizer/new_splits.txt",
-)
-
-embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer
-embedding_model.to(device)
-embedding_model.eval()
-
-HIDDEN_KEY = "last_hidden_state"
-
-def get_special_ids(tokenizer):
-    cand = [
-        getattr(tokenizer, "pad_token_id", None),
-        getattr(tokenizer, "cls_token_id", None),
-        getattr(tokenizer, "sep_token_id", None),
-        getattr(tokenizer, "bos_token_id", None),
-        getattr(tokenizer, "eos_token_id", None),
-        getattr(tokenizer, "mask_token_id", None),
-    ]
-    special_ids = sorted({x for x in cand if x is not None})
-    if len(special_ids) == 0:
-        print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
-    return special_ids
-
-SPECIAL_IDS = get_special_ids(tokenizer)
-SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long) if len(SPECIAL_IDS) else None
-
-@torch.no_grad()
-def embed_batch_return_both(batch_sequences, max_length, device):
-    tok = tokenizer(
-        batch_sequences,
-        return_tensors="pt",
-        padding=True,
-        max_length=max_length,
-        truncation=True,
-    )
-    input_ids = tok["input_ids"].to(device)            # (B, L)
-    attention_mask = tok["attention_mask"].to(device)  # (B, L)
-
-    outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-    last_hidden = outputs.last_hidden_state  # (B, L, H)
-
-    valid = attention_mask.bool()
-    if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
-        valid = valid & (~torch.isin(input_ids, SPECIAL_IDS_T))
-
-    # --- pooled embeddings (exclude specials) ---
-    valid_f = valid.unsqueeze(-1).float()              # (B, L, 1)
-    summed = torch.sum(last_hidden * valid_f, dim=1)   # (B, H)
-    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)  # (B, 1)
-    pooled = (summed / denom).detach().cpu().numpy()   # (B, H), float32
-
-    # --- unpooled per-example token embeddings (exclude specials) ---
-    token_emb_list = []
-    mask_list = []
-    lengths = []
-    for b in range(last_hidden.shape[0]):
-        emb = last_hidden[b, valid[b]]  # (L_i, H)
-        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())  # float16
-        L_i = emb.shape[0]
-        lengths.append(int(L_i))
-        mask_list.append(np.ones((L_i,), dtype=np.int8))
-
-    return pooled, token_emb_list, mask_list, lengths
-
-def generate_embeddings_batched_both(sequences, batch_size, max_length):
-    pooled_all = []
-    token_emb_all = []
-    mask_all = []
-    lengths_all = []
-
-    for i in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
-        batch = sequences[i:i + batch_size]
-        pooled, token_list, m_list, lens = embed_batch_return_both(batch, max_length, device)
-        pooled_all.append(pooled)
-        token_emb_all.extend(token_list)
-        mask_all.extend(m_list)
-        lengths_all.extend(lens)
-
-    pooled_all = np.vstack(pooled_all)  # (N, H)
-    return pooled_all, token_emb_all, mask_all, lengths_all
-
-from datasets import Dataset, DatasetDict
-
-def make_split_datasets(csv_path, split_name):
-    df = pd.read_csv(csv_path)
-    df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
-    df["sequence"] = df[SMILES_COL].astype(str)
-
-    labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
-    df = df.loc[~labels.isna()].reset_index(drop=True)
-    sequences = df["sequence"].tolist()
-    labels = pd.to_numeric(df[LABEL_COL], errors="coerce").tolist()
-
-    # (pooled_embs: (N,H), token_emb_list: list of (L_i,H), mask_list: list of (L_i,), lengths: list[int])
-    pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
-        sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
-    )
-
-    pooled_ds = Dataset.from_dict({
-        "sequence": sequences,
-        "label": labels,
-        "embedding": pooled_embs,  # (N,H)
-    })
-
-    full_ds = Dataset.from_dict({
-        "sequence": sequences,
-        "label": labels,
-        "embedding": token_emb_list,   # each (L_i,H) float16
-        "attention_mask": mask_list,   # each (L_i,) int8 ones
-        "length": lengths,
-    })
-
-    return pooled_ds, full_ds
-
-train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
-val_pooled, val_full = make_split_datasets(VAL_CSV, "val")
-
-ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
-ds_full = DatasetDict({"train": train_full, "val": val_full})
-
-ds_pooled.save_to_disk(OUT_PATH)
-ds_full.save_to_disk(OUT_PATH + "_unpooled")
+print(df["split"].value_counts())