Joblib
ynuozhang commited on
Commit
04c2975
·
1 Parent(s): d2421cb

major update

Browse files
basic_models.txt CHANGED
@@ -1,10 +1,10 @@
1
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
- Hemolysis, XGB, Transformer, Classifier, 0.2521, 0.4343 ,
3
- Non-Fouling, MLP, XGB, Classifier, 0.57, 0.6969,
4
- Solubility, CNN, -, Classifier, 0.377, -,
5
- Permeability (Penetrance), XGB, -, Classifier, 0.5493, -,
6
- Toxicity, -, Transformer, Classifier, -, 0.3401,
7
- Binding_affinity, unpooled, unpooled, Regression, -, -,
8
- Permeability_PAMPA, -, CNN, Regression, -, -,
9
- Permeability_CACO2, -, SVR, Regression, -, -,
10
- Halflife, Transformer, XGB, Regression, -, -,
 
1
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
+ Hemolysis, XGB, CNN (chemberta), Classifier, 0.2801, 0.564,
3
+ Non-Fouling, Transformer, XGB (peptideclm), Classifier, 0.57, 0.3892,
4
+ Solubility, CNN, Transformer (peptideclm), Classifier, 0.377, 0.329,
5
+ Permeability (Penetrance), XGB, XGB (chemberta), Classifier, 0.4301, 0.5028,
6
+ Toxicity, -, CNN (chemberta), Classifier, -, 0.49,
7
+ Binding_affinity, wt_wt_pooled, chemberta_smiles_pooled, Regression, -, -,
8
+ Permeability_PAMPA, -, CNN (chemberta), Regression, -, -,
9
+ Permeability_CACO2, -, SVR (chemberta), Regression, -, -,
10
+ Halflife, Transformer, XGB (peptideclm), Regression, -, -,
best_models.txt CHANGED
@@ -1,10 +1,10 @@
1
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
- Hemolysis, SVM, Transformer, Classifier, 0.2521, 0.4343 ,
3
- Non-Fouling, MLP, ENET, Classifier, 0.57, 0.6969,
4
- Solubility, CNN, -, Classifier, 0.377, -,
5
- Permeability (Penetrance), SVM, -, Classifier, 0.5493, -,
6
- Toxicity, -, Transformer, Classifier, -, 0.3401,
7
- Binding_affinity, unpooled, unpooled, Regression, -, -,
8
- Permeability_PAMPA, -, CNN, Regression, -, -,
9
- Permeability_CACO2, -, SVR, Regression, -, -,
10
- Halflife, Transformer, XGB, Regression, -, -,
 
1
  Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
+ Hemolysis, SVM, CNN (chemberta), Classifier, 0.2521, 0.564,
3
+ Non-Fouling, Transformer, ENET (peptideclm), Classifier, 0.57, 0.6969,
4
+ Solubility, CNN, Transformer (peptideclm), Classifier, 0.377, 0.329,
5
+ Permeability (Penetrance), SVM, SVM (chemberta), Classifier, 0.5493, 0.573,
6
+ Toxicity, -, CNN (chemberta), Classifier, -, 0.49,
7
+ Binding_affinity, wt_wt_pooled, chemberta_smiles_pooled, Regression, -, -,
8
+ Permeability_PAMPA, -, CNN (chemberta), Regression, -, -,
9
+ Permeability_CACO2, -, SVR (chemberta), Regression, -, -,
10
+ Halflife, Transformer, XGB (peptideclm), Regression, -, -,
inference.py CHANGED
@@ -1,16 +1,13 @@
1
  from __future__ import annotations
2
-
3
  import csv, re, json
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Dict, Optional, Tuple, Any, List
7
-
8
  import numpy as np
9
  import torch
10
  import torch.nn as nn
11
  import joblib
12
  import xgboost as xgb
13
-
14
  from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
15
  from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
16
  from lightning.pytorch import seed_everything
@@ -19,13 +16,31 @@ seed_everything(1986)
19
  # -----------------------------
20
  # Manifest
21
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @dataclass(frozen=True)
23
  class BestRow:
24
  property_key: str
25
- best_wt: Optional[str]
26
- best_smiles: Optional[str]
27
- task_type: str # "Classifier" or "Regression"
28
- thr_wt: Optional[float]
29
  thr_smiles: Optional[float]
30
 
31
 
@@ -34,21 +49,16 @@ def _clean(s: str) -> str:
34
 
35
  def _none_if_dash(s: str) -> Optional[str]:
36
  s = _clean(s)
37
- if s in {"", "-", "", "NA", "N/A"}:
38
- return None
39
- return s
40
 
41
  def _float_or_none(s: str) -> Optional[float]:
42
  s = _clean(s)
43
- if s in {"", "-", "", "NA", "N/A"}:
44
- return None
45
- return float(s)
46
 
47
  def normalize_property_key(name: str) -> str:
48
  n = name.strip().lower()
49
  n = re.sub(r"\s*\(.*?\)\s*", "", n)
50
  n = n.replace("-", "_").replace(" ", "_")
51
-
52
  if "permeability" in n and "pampa" not in n and "caco" not in n:
53
  return "permeability_penetrance"
54
  if n == "binding_affinity":
@@ -60,11 +70,40 @@ def normalize_property_key(name: str) -> str:
60
  return n
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
64
- """
65
- Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
66
- Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,
67
- """
68
  p = Path(path)
69
  out: Dict[str, BestRow] = {}
70
 
@@ -90,10 +129,13 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
90
  continue
91
  prop_key = normalize_property_key(prop_raw)
92
 
 
 
 
93
  row = BestRow(
94
  property_key=prop_key,
95
- best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
96
- best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
97
  task_type=_clean(rec.get("Type", "Classifier")),
98
  thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
99
  thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
@@ -103,53 +145,32 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
103
  return out
104
 
105
 
106
- MODEL_ALIAS = {
107
- "SVM": "svm_gpu",
108
- "SVR": "svr",
109
- "ENET": "enet_gpu",
110
- "CNN": "cnn",
111
- "MLP": "mlp",
112
- "TRANSFORMER": "transformer",
113
- "XGB": "xgb",
114
- "XGB_REG": "xgb_reg",
115
- "POOLED": "pooled",
116
- "UNPOOLED": "unpooled",
117
- "TRANSFORMER_WT_LOG": "transformer_wt_log",
118
- }
119
- def canon_model(label: Optional[str]) -> Optional[str]:
120
- if label is None:
121
- return None
122
- k = label.strip().upper()
123
- return MODEL_ALIAS.get(k, label.strip().lower())
124
-
125
-
126
  # -----------------------------
127
  # Generic artifact loading
128
  # -----------------------------
129
  def find_best_artifact(model_dir: Path) -> Path:
130
- for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]:
 
131
  hits = sorted(model_dir.glob(pat))
132
  if hits:
133
  return hits[0]
 
 
 
134
  raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
135
 
136
  def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
137
  art = find_best_artifact(model_dir)
138
-
139
  if art.suffix == ".json":
140
  booster = xgb.Booster()
141
- #print(str(art))
142
  booster.load_model(str(art))
143
  return "xgb", booster, art
144
-
145
  if art.suffix == ".joblib":
146
  obj = joblib.load(art)
147
  return "joblib", obj, art
148
-
149
  if art.suffix == ".pt":
150
  ckpt = torch.load(art, map_location=device, weights_only=False)
151
  return "torch_ckpt", ckpt, art
152
-
153
  raise ValueError(f"Unknown artifact type: {art}")
154
 
155
 
@@ -157,7 +178,7 @@ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path
157
  # NN architectures
158
  # -----------------------------
159
  class MaskedMeanPool(nn.Module):
160
- def forward(self, X, M): # X:(B,L,H), M:(B,L)
161
  Mf = M.unsqueeze(-1).float()
162
  denom = Mf.sum(dim=1).clamp(min=1.0)
163
  return (X * Mf).sum(dim=1) / denom
@@ -167,34 +188,25 @@ class MLPHead(nn.Module):
167
  super().__init__()
168
  self.pool = MaskedMeanPool()
169
  self.net = nn.Sequential(
170
- nn.Linear(in_dim, hidden),
171
- nn.GELU(),
172
- nn.Dropout(dropout),
173
  nn.Linear(hidden, 1),
174
  )
175
  def forward(self, X, M):
176
- z = self.pool(X, M)
177
- return self.net(z).squeeze(-1)
178
 
179
  class CNNHead(nn.Module):
180
  def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
181
  super().__init__()
182
- blocks = []
183
- ch = in_ch
184
  for _ in range(layers):
185
- blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
186
- nn.GELU(),
187
- nn.Dropout(dropout)]
188
  ch = c
189
  self.conv = nn.Sequential(*blocks)
190
  self.head = nn.Linear(c, 1)
191
-
192
  def forward(self, X, M):
193
- Xc = X.transpose(1, 2) # (B,H,L)
194
- Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
195
  Mf = M.unsqueeze(-1).float()
196
- denom = Mf.sum(dim=1).clamp(min=1.0)
197
- pooled = (Y * Mf).sum(dim=1) / denom
198
  return self.head(pooled).squeeze(-1)
199
 
200
  class TransformerHead(nn.Module):
@@ -207,55 +219,36 @@ class TransformerHead(nn.Module):
207
  )
208
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
209
  self.head = nn.Linear(d_model, 1)
210
-
211
  def forward(self, X, M):
212
- pad_mask = ~M
213
- Z = self.proj(X)
214
- Z = self.enc(Z, src_key_padding_mask=pad_mask)
215
  Mf = M.unsqueeze(-1).float()
216
- denom = Mf.sum(dim=1).clamp(min=1.0)
217
- pooled = (Z * Mf).sum(dim=1) / denom
218
  return self.head(pooled).squeeze(-1)
219
 
220
  def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
221
- if model_name == "mlp":
222
- return int(sd["net.0.weight"].shape[1])
223
- if model_name == "cnn":
224
- return int(sd["conv.0.weight"].shape[1])
225
- if model_name == "transformer":
226
- return int(sd["proj.weight"].shape[1])
227
  raise ValueError(model_name)
228
 
229
  def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
230
- # enc.layers.0.*, enc.layers.1.*, ...
231
  idxs = set()
232
  for k in sd.keys():
233
  if k.startswith(prefix):
234
- rest = k[len(prefix):]
235
- m = re.match(r"(\d+)\.", rest)
236
  if m:
237
  idxs.add(int(m.group(1)))
238
  return (max(idxs) + 1) if idxs else 1
239
 
240
  def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
241
- """
242
- Returns (d_model, layers, ff) inferred from weights.
243
- - d_model from proj.weight (shape: [d_model, in_dim])
244
- - layers from count of enc.layers.*
245
- - ff from enc.layers.0.linear1.weight (shape: [ff, d_model])
246
- """
247
  if "proj.weight" not in sd:
248
- raise KeyError("Missing proj.weight in state_dict; cannot infer transformer d_model.")
249
  d_model = int(sd["proj.weight"].shape[0])
250
- layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
251
- if "enc.layers.0.linear1.weight" in sd:
252
- ff = int(sd["enc.layers.0.linear1.weight"].shape[0])
253
- else:
254
- ff = 4 * d_model
255
  return d_model, layers, ff
256
 
257
  def _pick_nhead(d_model: int) -> int:
258
- # prefer common head counts; must divide d_model
259
  for h in (8, 6, 4, 3, 2, 1):
260
  if d_model % h == 0:
261
  return h
@@ -263,7 +256,7 @@ def _pick_nhead(d_model: int) -> int:
263
 
264
  def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
265
  params = ckpt["best_params"]
266
- sd = ckpt["state_dict"]
267
  in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
268
  dropout = float(params.get("dropout", 0.1))
269
 
@@ -273,44 +266,127 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.devic
273
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
274
  layers=int(params["layers"]), dropout=dropout)
275
  elif model_name == "transformer":
276
- # if transfer-learning ckpt omits arch params, infer from state_dict. special case for transformer_wt_log
277
  d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
278
-
279
  if d_model is None:
280
  d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
281
  nhead_i = _pick_nhead(d_model_i)
282
  model = TransformerHead(
283
- in_dim=in_dim,
284
- d_model=int(d_model_i),
285
- nhead=int(params.get("nhead", nhead_i)),
286
- layers=int(params.get("layers", layers_i)),
287
- ff=int(params.get("ff", ff_i)),
288
  dropout=float(params.get("dropout", dropout)),
289
  )
290
  else:
291
  d_model = int(d_model)
292
  model = TransformerHead(
293
- in_dim=in_dim,
294
- d_model=d_model,
295
  nhead=int(params.get("nhead", _pick_nhead(d_model))),
296
  layers=int(params.get("layers", 2)),
297
  ff=int(params.get("ff", 4 * d_model)),
298
- dropout=dropout
299
  )
300
  else:
301
  raise ValueError(f"Unknown NN model_name={model_name}")
302
 
303
  model.load_state_dict(sd)
304
- model.to(device)
305
- model.eval()
306
  return model
307
 
308
 
309
  # -----------------------------
310
- # Binding affinity models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  def affinity_to_class(y: float) -> int:
313
- # 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)
314
  if y >= 9.0: return 0
315
  if y < 7.0: return 2
316
  return 1
@@ -320,38 +396,31 @@ class CrossAttnPooled(nn.Module):
320
  super().__init__()
321
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
322
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
323
-
324
  self.layers = nn.ModuleList([])
325
  for _ in range(n_layers):
326
  self.layers.append(nn.ModuleDict({
327
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
328
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
329
- "n1t": nn.LayerNorm(hidden),
330
- "n2t": nn.LayerNorm(hidden),
331
- "n1b": nn.LayerNorm(hidden),
332
- "n2b": nn.LayerNorm(hidden),
333
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
334
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
335
  }))
336
-
337
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
338
  self.reg = nn.Linear(hidden, 1)
339
  self.cls = nn.Linear(hidden, 3)
340
 
341
  def forward(self, t_vec, b_vec):
342
- t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
343
- b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
344
  for L in self.layers:
345
  t_attn, _ = L["attn_tb"](t, b, b)
346
  t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
347
  t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
348
-
349
  b_attn, _ = L["attn_bt"](b, t, t)
350
  b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
351
  b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
352
-
353
- z = torch.cat([t[0], b[0]], dim=-1)
354
- h = self.shared(z)
355
  return self.reg(h).squeeze(-1), self.cls(h)
356
 
357
  class CrossAttnUnpooled(nn.Module):
@@ -359,334 +428,247 @@ class CrossAttnUnpooled(nn.Module):
359
  super().__init__()
360
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
361
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
362
-
363
  self.layers = nn.ModuleList([])
364
  for _ in range(n_layers):
365
  self.layers.append(nn.ModuleDict({
366
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
367
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
368
- "n1t": nn.LayerNorm(hidden),
369
- "n2t": nn.LayerNorm(hidden),
370
- "n1b": nn.LayerNorm(hidden),
371
- "n2b": nn.LayerNorm(hidden),
372
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
373
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
374
  }))
375
-
376
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
377
  self.reg = nn.Linear(hidden, 1)
378
  self.cls = nn.Linear(hidden, 3)
379
 
380
  def _masked_mean(self, X, M):
381
  Mf = M.unsqueeze(-1).float()
382
- denom = Mf.sum(dim=1).clamp(min=1.0)
383
- return (X * Mf).sum(dim=1) / denom
384
 
385
  def forward(self, T, Mt, B, Mb):
386
- T = self.t_proj(T)
387
- Bx = self.b_proj(B)
388
- kp_t = ~Mt
389
- kp_b = ~Mb
390
-
391
  for L in self.layers:
392
  T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
393
- T = L["n1t"](T + T_attn)
394
- T = L["n2t"](T + L["fft"](T))
395
-
396
  B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
397
- Bx = L["n1b"](Bx + B_attn)
398
- Bx = L["n2b"](Bx + L["ffb"](Bx))
399
-
400
- t_pool = self._masked_mean(T, Mt)
401
- b_pool = self._masked_mean(Bx, Mb)
402
- z = torch.cat([t_pool, b_pool], dim=-1)
403
- h = self.shared(z)
404
  return self.reg(h).squeeze(-1), self.cls(h)
405
 
406
  def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
407
  ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
408
  params = ckpt["best_params"]
409
- sd = ckpt["state_dict"]
410
-
411
- # infer Ht/Hb from projection weights
412
  Ht = int(sd["t_proj.0.weight"].shape[1])
413
  Hb = int(sd["b_proj.0.weight"].shape[1])
414
-
415
- common = dict(
416
- Ht=Ht, Hb=Hb,
417
- hidden=int(params["hidden_dim"]),
418
- n_heads=int(params["n_heads"]),
419
- n_layers=int(params["n_layers"]),
420
- dropout=float(params["dropout"]),
421
- )
422
-
423
- if pooled_or_unpooled == "pooled":
424
- model = CrossAttnPooled(**common)
425
- elif pooled_or_unpooled == "unpooled":
426
- model = CrossAttnUnpooled(**common)
427
- else:
428
- raise ValueError(pooled_or_unpooled)
429
-
430
  model.load_state_dict(sd)
431
- model.to(device).eval()
432
- return model
433
 
434
 
435
  # -----------------------------
436
  # Embedding generation
437
  # -----------------------------
438
  def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
439
- """
440
- Pytorch patch
441
- """
442
  if hasattr(torch, "isin"):
443
  return torch.isin(ids, test_ids)
444
- # Fallback: compare against each special id
445
- # (B,L,1) == (1,1,K) -> (B,L,K)
446
  return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
447
-
448
  class SMILESEmbedder:
449
- """
450
- PeptideCLM RoFormer embeddings for SMILES.
451
- - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
452
- - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
453
- plus a 1-mask of length Li (since already filtered).
454
- """
455
- def __init__(
456
- self,
457
- device: torch.device,
458
- vocab_path: str,
459
- splits_path: str,
460
- clm_name: str = "aaronfeller/PeptideCLM-23M-all",
461
- max_len: int = 512,
462
- use_cache: bool = True,
463
- ):
464
  self.device = device
465
  self.max_len = max_len
466
  self.use_cache = use_cache
467
-
468
  self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
469
  self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
470
-
471
  self.special_ids = self._get_special_ids(self.tokenizer)
472
  self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
473
- if len(self.special_ids) else None)
474
-
475
  self._cache_pooled: Dict[str, torch.Tensor] = {}
476
  self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
477
 
478
  @staticmethod
479
  def _get_special_ids(tokenizer) -> List[int]:
480
- cand = [
481
- getattr(tokenizer, "pad_token_id", None),
482
- getattr(tokenizer, "cls_token_id", None),
483
- getattr(tokenizer, "sep_token_id", None),
484
- getattr(tokenizer, "bos_token_id", None),
485
- getattr(tokenizer, "eos_token_id", None),
486
- getattr(tokenizer, "mask_token_id", None),
487
- ]
488
  return sorted({int(x) for x in cand if x is not None})
489
 
490
- def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
491
- tok = self.tokenizer(
492
- smiles_list,
493
- return_tensors="pt",
494
- padding=True,
495
- truncation=True,
496
- max_length=self.max_len,
497
- )
498
- for k in tok:
499
- tok[k] = tok[k].to(self.device)
500
  if "attention_mask" not in tok:
501
  tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
502
  return tok
503
 
 
 
 
 
 
 
504
  @torch.no_grad()
505
  def pooled(self, smiles: str) -> torch.Tensor:
506
  s = smiles.strip()
507
- if self.use_cache and s in self._cache_pooled:
508
- return self._cache_pooled[s]
 
 
 
 
 
 
509
 
 
 
 
 
510
  tok = self._tokenize([s])
511
- ids = tok["input_ids"] # (1,L)
512
- attn = tok["attention_mask"].bool() # (1,L)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
- out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
515
- h = out.last_hidden_state # (1,L,H)
 
 
 
 
 
516
 
517
- valid = attn
 
518
  if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
519
  valid = valid & (~_safe_isin(ids, self.special_ids_t))
 
520
 
 
 
 
 
 
 
 
521
  vf = valid.unsqueeze(-1).float()
522
- summed = (h * vf).sum(dim=1) # (1,H)
523
- denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
524
- pooled = summed / denom # (1,H)
525
-
526
- if self.use_cache:
527
- self._cache_pooled[s] = pooled
528
  return pooled
529
 
530
  @torch.no_grad()
531
  def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
532
- """
533
- Returns:
534
- X: (1, Li, H) float32 on device
535
- M: (1, Li) bool on device
536
- where Li excludes padding + special tokens.
537
- """
538
  s = smiles.strip()
539
- if self.use_cache and s in self._cache_unpooled:
540
- return self._cache_unpooled[s]
541
-
542
  tok = self._tokenize([s])
543
- ids = tok["input_ids"] # (1,L)
544
- attn = tok["attention_mask"].bool() # (1,L)
545
-
546
- out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
547
- h = out.last_hidden_state # (1,L,H)
548
-
549
- valid = attn
550
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
551
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
552
-
553
- # filter valid tokens
554
- keep = valid[0] # (L,)
555
- X = h[:, keep, :] # (1,Li,H)
556
  M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
557
-
558
- if self.use_cache:
559
- self._cache_unpooled[s] = (X, M)
560
  return X, M
561
 
562
 
563
  class WTEmbedder:
564
- """
565
- ESM2 embeddings for AA sequences.
566
- - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
567
- - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
568
- plus a 1-mask of length Li (since already filtered).
569
- """
570
- def __init__(
571
- self,
572
- device: torch.device,
573
- esm_name: str = "facebook/esm2_t33_650M_UR50D",
574
- max_len: int = 1022,
575
- use_cache: bool = True,
576
- ):
577
  self.device = device
578
  self.max_len = max_len
579
  self.use_cache = use_cache
580
-
581
  self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
582
  self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
583
-
584
  self.special_ids = self._get_special_ids(self.tokenizer)
585
  self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
586
- if len(self.special_ids) else None)
587
-
588
  self._cache_pooled: Dict[str, torch.Tensor] = {}
589
  self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
590
 
591
  @staticmethod
592
  def _get_special_ids(tokenizer) -> List[int]:
593
- cand = [
594
- getattr(tokenizer, "pad_token_id", None),
595
- getattr(tokenizer, "cls_token_id", None),
596
- getattr(tokenizer, "sep_token_id", None),
597
- getattr(tokenizer, "bos_token_id", None),
598
- getattr(tokenizer, "eos_token_id", None),
599
- getattr(tokenizer, "mask_token_id", None),
600
- ]
601
  return sorted({int(x) for x in cand if x is not None})
602
 
603
- def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
604
- tok = self.tokenizer(
605
- seq_list,
606
- return_tensors="pt",
607
- padding=True,
608
- truncation=True,
609
- max_length=self.max_len,
610
- )
611
  tok = {k: v.to(self.device) for k, v in tok.items()}
612
  if "attention_mask" not in tok:
613
  tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
614
  return tok
615
 
 
 
 
 
 
 
616
  @torch.no_grad()
617
  def pooled(self, seq: str) -> torch.Tensor:
618
  s = seq.strip()
619
- if self.use_cache and s in self._cache_pooled:
620
- return self._cache_pooled[s]
621
-
622
  tok = self._tokenize([s])
623
- ids = tok["input_ids"] # (1,L)
624
- attn = tok["attention_mask"].bool() # (1,L)
625
-
626
- out = self.model(**tok)
627
- h = out.last_hidden_state # (1,L,H)
628
-
629
- valid = attn
630
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
631
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
632
-
633
  vf = valid.unsqueeze(-1).float()
634
- summed = (h * vf).sum(dim=1) # (1,H)
635
- denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
636
- pooled = summed / denom # (1,H)
637
-
638
- if self.use_cache:
639
- self._cache_pooled[s] = pooled
640
  return pooled
641
 
642
  @torch.no_grad()
643
  def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
644
- """
645
- Returns:
646
- X: (1, Li, H) float32 on device
647
- M: (1, Li) bool on device
648
- where Li excludes padding + special tokens.
649
- """
650
  s = seq.strip()
651
- if self.use_cache and s in self._cache_unpooled:
652
- return self._cache_unpooled[s]
653
-
654
  tok = self._tokenize([s])
655
- ids = tok["input_ids"] # (1,L)
656
- attn = tok["attention_mask"].bool() # (1,L)
657
-
658
- out = self.model(**tok)
659
- h = out.last_hidden_state # (1,L,H)
660
-
661
- valid = attn
662
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
663
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
664
-
665
- keep = valid[0] # (L,)
666
- X = h[:, keep, :] # (1,Li,H)
667
  M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
668
-
669
- if self.use_cache:
670
- self._cache_unpooled[s] = (X, M)
671
  return X, M
672
 
673
 
674
-
675
  # -----------------------------
676
  # Predictor
677
  # -----------------------------
 
678
  class PeptiVersePredictor:
679
- """
680
- - loads best models from training_classifiers/
681
- - computes embeddings as needed (pooled/unpooled)
682
- - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
683
- """
684
  def __init__(
685
  self,
686
  manifest_path: str | Path,
687
  classifier_weight_root: str | Path,
688
  esm_name="facebook/esm2_t33_650M_UR50D",
689
  clm_name="aaronfeller/PeptideCLM-23M-all",
 
690
  smiles_vocab="tokenizer/new_vocab.txt",
691
  smiles_splits="tokenizer/new_splits.txt",
692
  device: Optional[str] = None,
@@ -697,293 +679,398 @@ class PeptiVersePredictor:
697
 
698
  self.manifest = read_best_manifest_csv(manifest_path)
699
 
700
- self.wt_embedder = WTEmbedder(self.device)
701
- self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
702
- vocab_path=str(self.root / smiles_vocab),
703
- splits_path=str(self.root / smiles_splits))
 
704
 
705
- self.models: Dict[Tuple[str, str], Any] = {}
706
- self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
 
 
707
 
708
  self._load_all_best_models()
709
 
710
- def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
711
- # map halflife -> half_life folder on disk (common layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  disk_prop = "half_life" if prop_key == "halflife" else prop_key
713
  base = self.training_root / disk_prop
714
 
715
- # special handling for halflife xgb_wt_log / xgb_smiles
716
- if prop_key == "halflife" and model_name in {"xgb_wt_log", "xgb_smiles"}:
717
- d = base / model_name
718
- if d.exists():
719
- return d
720
-
721
- # special handling for halflife transformer wt log folder
722
- if prop_key == "halflife" and mode == "wt" and model_name == "transformer":
723
- d = base / "transformer_wt_log"
724
- if d.exists():
725
- return d
726
-
727
- if prop_key == "halflife" and model_name == "xgb":
728
- d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
729
- if d.exists():
730
- return d
731
 
732
  candidates = [
733
- base / f"{model_name}_{mode}",
734
  base / model_name,
735
  ]
736
- if mode == "wt":
737
- candidates += [base / f"{model_name}_wt"]
738
- if mode == "smiles":
739
- candidates += [base / f"{model_name}_smiles"]
740
-
741
  for d in candidates:
742
- if d.exists():
743
- return d
744
 
745
  raise FileNotFoundError(
746
- f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}"
747
  )
748
 
749
-
750
  def _load_all_best_models(self):
751
  for prop_key, row in self.manifest.items():
752
- for mode, label, thr in [
753
- ("wt", row.best_wt, row.thr_wt),
754
- ("smiles", row.best_smiles, row.thr_smiles),
755
  ]:
756
- m = canon_model(label)
757
- if m is None:
758
  continue
 
759
 
760
- # ---- binding affinity special ----
761
  if prop_key == "binding_affinity":
762
- # label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
763
- pooled_or_unpooled = m # "pooled" or "unpooled"
764
- folder = f"wt_{mode}_{pooled_or_unpooled}" # wt_wt_pooled / wt_smiles_unpooled etc.
765
  model_dir = self.training_root / "binding_affinity" / folder
766
  art = find_best_artifact(model_dir)
767
- if art.suffix != ".pt":
768
- raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
769
- model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
770
- self.models[(prop_key, mode)] = model
771
- self.meta[(prop_key, mode)] = {
772
- "task_type": "Regression",
773
- "threshold": None,
774
- "artifact": str(art),
775
- "model_name": pooled_or_unpooled,
 
776
  }
 
 
 
 
 
 
 
 
 
 
777
  continue
778
 
779
- model_dir = self._resolve_dir(prop_key, m, mode)
 
 
 
 
780
  kind, obj, art = load_artifact(model_dir, self.device)
781
 
782
- if kind in {"xgb", "joblib"}:
783
- self.models[(prop_key, mode)] = obj
 
784
  else:
785
- # rebuild NN architecture
786
- arch = m
787
- if arch.startswith("transformer"):
788
- arch = "transformer"
789
- elif arch.startswith("mlp"):
790
- arch = "mlp"
791
- elif arch.startswith("cnn"):
792
- arch = "cnn"
793
-
794
- self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device)
795
-
796
- self.meta[(prop_key, mode)] = {
797
- "task_type": row.task_type,
798
- "threshold": thr,
799
- "artifact": str(art),
800
- "model_name": m,
801
- "kind": kind,
802
- }
803
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
- def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
806
- """
807
- Returns either:
808
- - pooled np array shape (1,H) for xgb/joblib
809
- - unpooled torch tensors (X,M) for NN
810
- """
811
- model = self.models[(prop_key, mode)]
812
- meta = self.meta[(prop_key, mode)]
813
- kind = meta.get("kind", None)
814
- model_name = meta.get("model_name", "")
815
 
816
- if prop_key == "binding_affinity":
817
- raise RuntimeError("Use predict_binding_affinity().")
818
-
819
- # If torch NN: needs unpooled
 
 
 
 
 
 
 
 
820
  if kind == "torch_ckpt":
821
- if mode == "wt":
822
- X, M = self.wt_embedder.unpooled(input_str)
823
- else:
824
- X, M = self.smiles_embedder.unpooled(input_str)
825
- return X, M
826
-
827
- # Otherwise pooled vectors for xgb/joblib
828
- if mode == "wt":
829
- v = self.wt_embedder.pooled(input_str) # (1,H)
830
- else:
831
- v = self.smiles_embedder.pooled(input_str) # (1,H)
832
- feats = v.detach().cpu().numpy().astype(np.float32)
833
- feats = np.nan_to_num(feats, nan=0.0)
834
- feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
835
- return feats
836
-
837
- def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
838
- """
839
- mode: "wt" for AA sequence input, "smiles" for SMILES input
840
- Returns dict with score + label if classifier threshold exists.
841
- """
842
- if (prop_key, mode) not in self.models:
843
- raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")
844
-
845
- meta = self.meta[(prop_key, mode)]
846
- model = self.models[(prop_key, mode)]
847
- task_type = meta["task_type"].lower()
848
- thr = meta.get("threshold", None)
849
- kind = meta.get("kind", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
 
851
  if prop_key == "binding_affinity":
852
  raise RuntimeError("Use predict_binding_affinity().")
853
 
854
- # NN path (logits / regression)
855
  if kind == "torch_ckpt":
856
- X, M = self._get_features_for_model(prop_key, mode, input_str)
857
  with torch.no_grad():
858
- y = model(X, M).squeeze().float().cpu().item()
859
- # invert log1p(hours) ONLY for WT half-life log models
860
- model_name = meta.get("model_name", "")
861
- if (
862
- prop_key == "halflife"
863
- and mode == "wt"
864
- and model_name in {"xgb_wt_log", "transformer_wt_log"}
865
- ):
866
- y = float(np.expm1(y))
867
  if task_type == "classifier":
868
- prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
869
- out = {"property": prop_key, "mode": mode, "score": prob}
 
870
  if thr is not None:
871
- out["label"] = int(prob >= float(thr))
872
- out["threshold"] = float(thr)
873
- return out
874
  else:
875
- return {"property": prop_key, "mode": mode, "score": float(y)}
876
-
877
- if kind == "xgb":
878
- feats = self._get_features_for_model(prop_key, mode, input_str)
879
- dmat = xgb.DMatrix(feats)
880
- pred = float(model.predict(dmat)[0])
881
-
882
- # invert log1p(hours) ONLY for WT half-life log models
883
- model_name = meta.get("model_name", "")
884
- if (
885
- prop_key == "halflife"
886
- and mode == "wt"
887
- and model_name in {"xgb_wt_log", "transformer_wt_log"}
888
- ):
889
  pred = float(np.expm1(pred))
890
-
891
- out = {"property": prop_key, "mode": mode, "score": pred}
892
-
893
- return out
894
-
895
- # joblib path (svm/enet/svr)
896
- if kind == "joblib":
897
- feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
898
- # classifier vs regressor behavior differs by estimator
899
  if task_type == "classifier":
900
  if hasattr(model, "predict_proba"):
901
  pred = float(model.predict_proba(feats)[:, 1][0])
 
 
902
  else:
903
- if hasattr(model, "decision_function"):
904
- logit = float(model.decision_function(feats)[0])
905
- pred = float(1.0 / (1.0 + np.exp(-logit)))
906
- else:
907
- pred = float(model.predict(feats)[0])
908
- out = {"property": prop_key, "mode": mode, "score": pred}
909
  if thr is not None:
910
- out["label"] = int(pred >= float(thr))
911
- out["threshold"] = float(thr)
912
- return out
913
  else:
914
  pred = float(model.predict(feats)[0])
915
- return {"property": prop_key, "mode": mode, "score": pred}
916
-
917
- raise RuntimeError(f"Unknown model kind={kind}")
 
918
 
919
- def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
920
- """
921
- mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled)
922
- "smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled)
923
- """
924
- prop_key = "binding_affinity"
925
- if (prop_key, mode) not in self.models:
926
- raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
927
 
928
- model = self.models[(prop_key, mode)]
929
- pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] # pooled/unpooled
930
 
931
- # target is always WT sequence (ESM)
932
- if pooled_or_unpooled == "pooled":
933
- t_vec = self.wt_embedder.pooled(target_seq) # (1,Ht)
934
- if mode == "wt":
935
- b_vec = self.wt_embedder.pooled(binder_str) # (1,Hb)
936
- else:
937
- b_vec = self.smiles_embedder.pooled(binder_str) # (1,Hb)
 
 
 
 
 
 
 
 
938
  with torch.no_grad():
939
  reg, logits = model(t_vec, b_vec)
940
- affinity = float(reg.squeeze().cpu().item())
941
- cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
942
- cls_thr = affinity_to_class(affinity)
943
  else:
944
  T, Mt = self.wt_embedder.unpooled(target_seq)
945
- if mode == "wt":
946
- B, Mb = self.wt_embedder.unpooled(binder_str)
947
- else:
948
- B, Mb = self.smiles_embedder.unpooled(binder_str)
949
  with torch.no_grad():
950
  reg, logits = model(T, Mt, B, Mb)
951
- affinity = float(reg.squeeze().cpu().item())
952
- cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
953
- cls_thr = affinity_to_class(affinity)
954
-
955
- names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
956
- return {
957
- "property": "binding_affinity",
958
- "mode": mode,
959
- "affinity": affinity,
 
960
  "class_by_threshold": names[cls_thr],
961
- "class_by_logits": names[cls_logit],
962
- "binding_model": pooled_or_unpooled,
963
  }
964
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
965
 
966
  if __name__ == "__main__":
967
- predictor = PeptiVersePredictor(
968
- manifest_path="basic_models.txt",
969
- classifier_weight_root="./"
970
- )
971
- print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
972
- print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
973
 
974
- # Test Embedding #
975
- """
976
- device = torch.device("cuda:0")
977
-
978
- wt = WTEmbedder(device)
979
- sm = SMILESEmbedder(device,
980
- vocab_path="./tokeizner/new_vocab.txt",
981
- splits_path="./tokenizer/new_splits.txt"
982
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
 
984
- p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280)
985
- X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li)
986
-
987
- p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles)
988
- X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li)
989
- """
 
1
  from __future__ import annotations
 
2
  import csv, re, json
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Dict, Optional, Tuple, Any, List
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
9
  import joblib
10
  import xgboost as xgb
 
11
  from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
12
  from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
13
  from lightning.pytorch import seed_everything
 
16
  # -----------------------------
17
  # Manifest
18
  # -----------------------------
19
+
20
# Maps the parenthesised embedding tag from the manifest (e.g. "XGB (chemberta)")
# to the suffix used in on-disk model folder names.
EMB_TAG_TO_FOLDER_SUFFIX = {
    "wt": "wt",
    "peptideclm": "smiles",
    "chemberta": "chemberta",
}

# Maps the embedding tag to the runtime mode key used when selecting an embedder.
EMB_TAG_TO_RUNTIME_MODE = {
    "wt": "wt",
    "peptideclm": "smiles",
    "chemberta": "chemberta",
}

# Regression model names whose uncertainty comes from a MAPIE-style conformal bundle.
MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"}
# Neural-net architectures that are rebuilt from torch checkpoints.
DNN_ARCHS = {"mlp", "cnn", "transformer"}
# Model names served via XGBoost boosters.
XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"}
35
+
36
+
37
@dataclass(frozen=True)
class BestRow:
    """One manifest row: the best model (and embedding tag) per input mode."""
    property_key: str                                 # normalised property name (see normalize_property_key)
    best_wt: Optional[Tuple[str, Optional[str]]]      # (canonical model, emb_tag) or None if no WT model
    best_smiles: Optional[Tuple[str, Optional[str]]]  # (canonical model, emb_tag) or None if no SMILES model
    task_type: str                                    # "Classifier" or "Regression" (from the Type column)
    thr_wt: Optional[float]                           # decision threshold for WT classifiers, else None
    thr_smiles: Optional[float]                       # decision threshold for SMILES classifiers, else None
45
 
46
 
 
49
 
50
def _none_if_dash(s: str) -> Optional[str]:
    """Return None for empty/dash/NA placeholder cells, else the cleaned string."""
    s = _clean(s)
    # NOTE(review): the set literal appears to contain "-" twice; one entry was
    # likely meant to be an en-dash ("–") — confirm against the raw file.
    return None if s in {"", "-", "-", "NA", "N/A"} else s
 
 
53
 
54
  def _float_or_none(s: str) -> Optional[float]:
55
  s = _clean(s)
56
+ return None if s in {"", "-", "-", "NA", "N/A"} else float(s)
 
 
57
 
58
  def normalize_property_key(name: str) -> str:
59
  n = name.strip().lower()
60
  n = re.sub(r"\s*\(.*?\)\s*", "", n)
61
  n = n.replace("-", "_").replace(" ", "_")
 
62
  if "permeability" in n and "pampa" not in n and "caco" not in n:
63
  return "permeability_penetrance"
64
  if n == "binding_affinity":
 
70
  return n
71
 
72
 
73
# Canonical on-disk model names keyed by the upper-cased manifest label.
# Labels not listed here fall through to plain lower-casing (see _parse_model_and_emb).
MODEL_ALIAS = {
    "SVM": "svm_gpu",
    "SVR": "svr",
    "ENET": "enet_gpu",
    "CNN": "cnn",
    "MLP": "mlp",
    "TRANSFORMER": "transformer",
    "XGB": "xgb",
    "XGB_REG": "xgb_reg",
    "POOLED": "pooled",
    "UNPOOLED": "unpooled",
    "TRANSFORMER_WT_LOG": "transformer_wt_log",
}
86
+
87
def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
    """Parse a manifest model cell like "XGB (chemberta)" into (canonical_model, emb_tag).

    Returns None for missing/placeholder cells. `emb_tag` is the lower-cased
    parenthesised embedding tag, or None when the cell has no parentheses.
    """
    if raw is None:
        return None
    raw = _clean(raw)
    if not raw or raw in {"-", "-", "NA", "N/A"}:
        return None

    # Split "MODEL (emb_tag)"; non-greedy groups, whole-cell anchored.
    m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", raw)
    if m:
        model_raw = m.group(1).strip()
        emb_tag = m.group(2).strip().lower()
    else:
        model_raw = raw
        emb_tag = None

    # Canonicalise via the alias table; unknown labels are just lower-cased.
    canon = MODEL_ALIAS.get(model_raw.upper(), model_raw.lower())
    return canon, emb_tag
104
+
105
+
106
  def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
 
 
 
 
107
  p = Path(path)
108
  out: Dict[str, BestRow] = {}
109
 
 
129
  continue
130
  prop_key = normalize_property_key(prop_raw)
131
 
132
+ best_wt = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", "")))
133
+ best_smiles = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", "")))
134
+
135
  row = BestRow(
136
  property_key=prop_key,
137
+ best_wt=best_wt,
138
+ best_smiles=best_smiles,
139
  task_type=_clean(rec.get("Type", "Classifier")),
140
  thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
141
  thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
 
145
  return out
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # -----------------------------
149
  # Generic artifact loading
150
  # -----------------------------
151
def find_best_artifact(model_dir: Path) -> Path:
    """Locate the preferred model artifact inside *model_dir*.

    Glob patterns are tried in priority order; within a pattern the
    lexicographically smallest match wins. As a last resort a
    ``seed_1986/model.pt`` checkpoint is used.

    Raises:
        FileNotFoundError: when no known artifact exists in the directory.
    """
    patterns = (
        "best_model.json",
        "best_model.pt",
        "best_model*.joblib",
        "model.json",
        "model.ubj",
        "final_model.json",
    )
    for pattern in patterns:
        matches = sorted(model_dir.glob(pattern))
        if matches:
            return matches[0]
    seed_checkpoint = model_dir / "seed_1986" / "model.pt"
    if seed_checkpoint.exists():
        return seed_checkpoint
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
161
 
162
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact found in *model_dir*.

    Returns (kind, object, path) where kind is one of:
      - "xgb":        an xgboost Booster (from .json or .ubj)
      - "joblib":     a pickled sklearn-style estimator (from .joblib)
      - "torch_ckpt": a raw checkpoint dict (from .pt) — NOT an nn.Module yet

    Raises:
        ValueError: for unrecognised artifact extensions.
    """
    art = find_best_artifact(model_dir)

    # Bug fix: find_best_artifact can return "model.ubj", but there was no
    # .ubj branch here, so UBJSON boosters hit the ValueError below.
    # Booster.load_model handles both JSON and UBJSON by extension.
    if art.suffix in {".json", ".ubj"}:
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art

    if art.suffix == ".joblib":
        obj = joblib.load(art)
        return "joblib", obj, art

    if art.suffix == ".pt":
        # weights_only=False: checkpoints carry params dicts, not just tensors.
        # NOTE(review): this unpickles arbitrary objects — trusted files only.
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art

    raise ValueError(f"Unknown artifact type: {art}")
175
 
176
 
 
178
  # NN architectures
179
  # -----------------------------
180
class MaskedMeanPool(nn.Module):
    """Mean-pool token features over mask-valid positions only."""

    def forward(self, X, M):
        # weights: 1.0 where the mask is set, 0.0 elsewhere
        weights = M.unsqueeze(-1).float()
        total = (X * weights).sum(dim=1)
        count = weights.sum(dim=1).clamp(min=1.0)  # guard against all-False masks
        return total / count
 
188
  super().__init__()
189
  self.pool = MaskedMeanPool()
190
  self.net = nn.Sequential(
191
+ nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout),
 
 
192
  nn.Linear(hidden, 1),
193
  )
194
  def forward(self, X, M):
195
+ return self.net(self.pool(X, M)).squeeze(-1)
 
196
 
197
  class CNNHead(nn.Module):
198
  def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
199
  super().__init__()
200
+ blocks, ch = [], in_ch
 
201
  for _ in range(layers):
202
+ blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)]
 
 
203
  ch = c
204
  self.conv = nn.Sequential(*blocks)
205
  self.head = nn.Linear(c, 1)
 
206
  def forward(self, X, M):
207
+ Y = self.conv(X.transpose(1, 2)).transpose(1, 2)
 
208
  Mf = M.unsqueeze(-1).float()
209
+ pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
 
210
  return self.head(pooled).squeeze(-1)
211
 
212
  class TransformerHead(nn.Module):
 
219
  )
220
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
221
  self.head = nn.Linear(d_model, 1)
 
222
  def forward(self, X, M):
223
+ Z = self.enc(self.proj(X), src_key_padding_mask=~M)
 
 
224
  Mf = M.unsqueeze(-1).float()
225
+ pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
 
226
  return self.head(pooled).squeeze(-1)
227
 
228
  def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
229
+ if model_name == "mlp": return int(sd["net.0.weight"].shape[1])
230
+ if model_name == "cnn": return int(sd["conv.0.weight"].shape[1])
231
+ if model_name == "transformer": return int(sd["proj.weight"].shape[1])
 
 
 
232
  raise ValueError(model_name)
233
 
234
  def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
 
235
  idxs = set()
236
  for k in sd.keys():
237
  if k.startswith(prefix):
238
+ m = re.match(r"(\d+)\.", k[len(prefix):])
 
239
  if m:
240
  idxs.add(int(m.group(1)))
241
  return (max(idxs) + 1) if idxs else 1
242
 
243
def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
    """Infer (d_model, num_layers, ff_dim) for TransformerHead from its state_dict.

    Raises:
        KeyError: when the checkpoint lacks the projection weight.
    """
    if "proj.weight" not in sd:
        raise KeyError("Missing proj.weight in state_dict")
    d_model = int(sd["proj.weight"].shape[0])
    num_layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
    ff_key = "enc.layers.0.linear1.weight"
    # Fall back to the conventional 4x expansion when the key is absent.
    ff_dim = int(sd[ff_key].shape[0]) if ff_key in sd else 4 * d_model
    return d_model, num_layers, ff_dim
250
 
251
  def _pick_nhead(d_model: int) -> int:
 
252
  for h in (8, 6, 4, 3, 2, 1):
253
  if d_model % h == 0:
254
  return h
 
256
 
257
  def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
258
  params = ckpt["best_params"]
259
+ sd = ckpt["state_dict"]
260
  in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
261
  dropout = float(params.get("dropout", 0.1))
262
 
 
266
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
267
  layers=int(params["layers"]), dropout=dropout)
268
  elif model_name == "transformer":
 
269
  d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
 
270
  if d_model is None:
271
  d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
272
  nhead_i = _pick_nhead(d_model_i)
273
  model = TransformerHead(
274
+ in_dim=in_dim, d_model=int(d_model_i), nhead=int(params.get("nhead", nhead_i)),
275
+ layers=int(params.get("layers", layers_i)), ff=int(params.get("ff", ff_i)),
 
 
 
276
  dropout=float(params.get("dropout", dropout)),
277
  )
278
  else:
279
  d_model = int(d_model)
280
  model = TransformerHead(
281
+ in_dim=in_dim, d_model=d_model,
 
282
  nhead=int(params.get("nhead", _pick_nhead(d_model))),
283
  layers=int(params.get("layers", 2)),
284
  ff=int(params.get("ff", 4 * d_model)),
285
+ dropout=dropout,
286
  )
287
  else:
288
  raise ValueError(f"Unknown NN model_name={model_name}")
289
 
290
  model.load_state_dict(sd)
291
+ model.to(device).eval()
 
292
  return model
293
 
294
 
295
  # -----------------------------
296
+ # Wrappers
297
+ # -----------------------------
298
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
299
+
300
class PassthroughRegressor(BaseEstimator, RegressorMixin):
    """sklearn-compatible shim that replays precomputed predictions."""
    def __init__(self, preds: np.ndarray):
        self.preds = preds  # precomputed predictions, consumed positionally
    def fit(self, X, y): return self  # no-op: predictions are fixed
    def predict(self, X): return self.preds[:len(X)]
305
+
306
class PassthroughClassifier(BaseEstimator, ClassifierMixin):
    """sklearn-compatible shim replaying precomputed positive-class probabilities."""
    def __init__(self, preds: np.ndarray):
        self.preds = preds  # precomputed P(class=1), consumed positionally
        self.classes_ = np.array([0, 1])
    def fit(self, X, y): return self  # no-op: predictions are fixed
    def predict(self, X): return (self.preds[:len(X)] >= 0.5).astype(int)
    def predict_proba(self, X):
        p = self.preds[:len(X)]
        return np.stack([1 - p, p], axis=1)  # columns: [P(0), P(1)]
315
+
316
+
317
  # -----------------------------
318
+ # Uncertainty helpers
319
+ # -----------------------------
320
# Seed sub-directories probed (in this order) when assembling an ensemble.
SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"]

def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
    """Load every available per-seed checkpoint under *model_dir* as an eval-mode model.

    Missing seed directories are skipped silently; the result may be empty.
    """
    ensemble = []
    for sd_name in SEED_DIRS:
        pt = model_dir / sd_name / "model.pt"
        if not pt.exists():
            continue
        # weights_only=False: checkpoint carries the hyper-param dict as well.
        ckpt = torch.load(pt, map_location=device, weights_only=False)
        ensemble.append(build_torch_model_from_ckpt(arch, ckpt, device))
    return ensemble
331
+
332
+ def _binary_entropy(p: float) -> float:
333
+ p = float(np.clip(p, 1e-9, 1 - 1e-9))
334
+ return float(-p * np.log(p) - (1 - p) * np.log(1 - p))
335
+
336
def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
    """Predictive entropy of the seed-ensemble mean probability (binary classifier).

    Each member outputs a logit; probabilities are averaged across members and
    the Shannon entropy of the mean probability is returned as the uncertainty.
    """
    probs = []
    with torch.no_grad():
        for m in ensemble:
            logit = m(X, M).squeeze().float().cpu().item()
            probs.append(1.0 / (1.0 + np.exp(-logit)))  # sigmoid
    return _binary_entropy(float(np.mean(probs)))
343
+
344
+ def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
345
+ preds = []
346
+ with torch.no_grad():
347
+ for m in ensemble:
348
+ preds.append(m(X, M).squeeze().float().cpu().item())
349
+ return float(np.std(preds))
350
+
351
+ def _mapie_uncertainty(mapie_bundle: dict, score: float,
352
+ embedding: Optional[np.ndarray] = None) -> Tuple[float, float]:
353
+ """
354
+ Returns (ci_low, ci_high) from a conformal bundle.
355
+ - adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True}
356
+ Input-dependent: interval = score +/- q * sigma(embedding)
357
+ - plain_quantile: {"quantile": q, "alpha": ...}
358
+ Fixed-width: interval = score +/- q
359
+ """
360
+ # Adaptive format is input-dependent interval
361
+ if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
362
+ q = float(mapie_bundle["quantile"])
363
+ if embedding is not None:
364
+ sigma_model = mapie_bundle["sigma_model"]
365
+ sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
366
+ sigma = max(sigma, 1e-6)
367
+ else:
368
+ # No embedding available - fall back to fixed interval with sigma=1
369
+ sigma = 1.0
370
+ return float(score - q * sigma), float(score + q * sigma)
371
+
372
+ # Plain quantile format
373
+ if "quantile" in mapie_bundle:
374
+ q = float(mapie_bundle["quantile"])
375
+ return float(score - q), float(score + q)
376
+
377
+ X_dummy = np.zeros((1, 1))
378
+ result = mapie.predict(X_dummy)
379
+ if isinstance(result, tuple):
380
+ intervals = np.asarray(result[1])
381
+ if intervals.ndim == 3:
382
+ return float(intervals[0, 0, 0]), float(intervals[0, 1, 0])
383
+ return float(intervals[0, 0]), float(intervals[0, 1])
384
+ raise RuntimeError(
385
+ f"Cannot extract intervals: unknown MAPIE bundle format. "
386
+ f"Bundle keys: {list(mapie_bundle.keys())}."
387
+ )
388
+
389
def affinity_to_class(y: float) -> int:
    """Bucket a predicted affinity: 0 = high (>= 9), 1 = moderate [7, 9), 2 = low (< 7)."""
    if y < 7.0:
        return 2
    return 0 if y >= 9.0 else 1
 
396
  super().__init__()
397
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
398
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
 
399
  self.layers = nn.ModuleList([])
400
  for _ in range(n_layers):
401
  self.layers.append(nn.ModuleDict({
402
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
403
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
404
+ "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
405
+ "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
 
 
406
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
407
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
408
  }))
 
409
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
410
  self.reg = nn.Linear(hidden, 1)
411
  self.cls = nn.Linear(hidden, 3)
412
 
413
    def forward(self, t_vec, b_vec):
        # Project pooled target/binder vectors and prepend a length-1 sequence
        # axis — the attention modules here are batch_first=False, i.e. (seq, batch, dim).
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)
        for L in self.layers:
            # target attends to binder; post-norm residuals (transposes move the
            # batch axis where LayerNorm expects it, then back).
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
            # binder attends to the already-updated target.
            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
        # Drop the sequence axis, fuse both sides, and emit regression value
        # plus 3-class logits from the shared representation.
        h = self.shared(torch.cat([t[0], b[0]], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
425
 
426
  class CrossAttnUnpooled(nn.Module):
 
428
  super().__init__()
429
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
430
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
 
431
  self.layers = nn.ModuleList([])
432
  for _ in range(n_layers):
433
  self.layers.append(nn.ModuleDict({
434
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
435
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
436
+ "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
437
+ "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
 
 
438
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
439
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
440
  }))
 
441
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
442
  self.reg = nn.Linear(hidden, 1)
443
  self.cls = nn.Linear(hidden, 3)
444
 
445
    def _masked_mean(self, X, M):
        # Mean over valid (True) positions only; clamp avoids 0/0 on empty masks.
        Mf = M.unsqueeze(-1).float()
        return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
 
448
 
449
    def forward(self, T, Mt, B, Mb):
        # Per-token projections. Mt/Mb are True on VALID tokens, so they are
        # inverted for key_padding_mask (which expects True = padding).
        T = self.t_proj(T); Bx = self.b_proj(B)
        kp_t, kp_b = ~Mt, ~Mb
        for L in self.layers:
            # target tokens attend over binder tokens; post-norm residual + FFN.
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T))
            # binder tokens attend over the updated target tokens.
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx))
        # Masked mean-pool each side, fuse, and emit regression + 3-class logits.
        h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
459
 
460
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Rebuild a cross-attention binding model from a checkpoint and load its weights.

    Embedding widths Ht/Hb are inferred from the saved projection weights so the
    rebuilt architecture matches the state_dict exactly. Returns the model on
    *device* in eval mode.
    """
    # weights_only=False: checkpoint also stores the hyper-parameter dict.
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    Ht = int(sd["t_proj.0.weight"].shape[1])  # target embedding width
    Hb = int(sd["b_proj.0.weight"].shape[1])  # binder embedding width
    common = dict(Ht=Ht, Hb=Hb, hidden=int(params["hidden_dim"]),
                  n_heads=int(params["n_heads"]), n_layers=int(params["n_layers"]),
                  dropout=float(params["dropout"]))
    # Any value other than "pooled" selects the token-level (unpooled) variant.
    cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled
    model = cls(**common)
    model.load_state_dict(sd)
    return model.to(device).eval()
 
473
 
474
 
475
  # -----------------------------
476
  # Embedding generation
477
  # -----------------------------
478
  def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
 
 
 
479
  if hasattr(torch, "isin"):
480
  return torch.isin(ids, test_ids)
 
 
481
  return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
482
+
483
  class SMILESEmbedder:
484
+ def __init__(self, device, vocab_path, splits_path,
485
+ clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  self.device = device
487
  self.max_len = max_len
488
  self.use_cache = use_cache
 
489
  self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
490
  self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
 
491
  self.special_ids = self._get_special_ids(self.tokenizer)
492
  self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
493
+ if self.special_ids else None)
 
494
  self._cache_pooled: Dict[str, torch.Tensor] = {}
495
  self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
496
 
497
  @staticmethod
498
  def _get_special_ids(tokenizer) -> List[int]:
499
+ cand = [getattr(tokenizer, f"{x}_token_id", None)
500
+ for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
 
 
 
 
 
 
501
  return sorted({int(x) for x in cand if x is not None})
502
 
503
+ def _tokenize(self, smiles_list):
504
+ tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
505
+ truncation=True, max_length=self.max_len)
506
+ for k in tok: tok[k] = tok[k].to(self.device)
 
 
 
 
 
 
507
  if "attention_mask" not in tok:
508
  tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
509
  return tok
510
 
511
+ def _valid_mask(self, ids, attn):
512
+ valid = attn.bool()
513
+ if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
514
+ valid = valid & (~_safe_isin(ids, self.special_ids_t))
515
+ return valid
516
+
517
  @torch.no_grad()
518
  def pooled(self, smiles: str) -> torch.Tensor:
519
  s = smiles.strip()
520
+ if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
521
+ tok = self._tokenize([s])
522
+ h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
523
+ valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
524
+ vf = valid.unsqueeze(-1).float()
525
+ pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
526
+ if self.use_cache: self._cache_pooled[s] = pooled
527
+ return pooled
528
 
529
+ @torch.no_grad()
530
+ def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
531
+ s = smiles.strip()
532
+ if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
533
  tok = self._tokenize([s])
534
+ h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
535
+ valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
536
+ X = h[:, valid[0], :]
537
+ M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
538
+ if self.use_cache: self._cache_unpooled[s] = (X, M)
539
+ return X, M
540
+
541
+
542
+ class ChemBERTaEmbedder:
543
+ def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM",
544
+ max_len=512, use_cache=True):
545
+ from transformers import AutoTokenizer, AutoModel
546
+ self.device = device
547
+ self.max_len = max_len
548
+ self.use_cache = use_cache
549
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
550
+ self.model = AutoModel.from_pretrained(model_name).to(device).eval()
551
+ self.special_ids = self._get_special_ids(self.tokenizer)
552
+ self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
553
+ if self.special_ids else None)
554
+ self._cache_pooled: Dict[str, torch.Tensor] = {}
555
+ self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
556
+
557
+ @staticmethod
558
+ def _get_special_ids(tokenizer) -> List[int]:
559
+ cand = [getattr(tokenizer, f"{x}_token_id", None)
560
+ for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
561
+ return sorted({int(x) for x in cand if x is not None})
562
 
563
+ def _tokenize(self, smiles_list):
564
+ tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
565
+ truncation=True, max_length=self.max_len)
566
+ for k in tok: tok[k] = tok[k].to(self.device)
567
+ if "attention_mask" not in tok:
568
+ tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
569
+ return tok
570
 
571
+ def _valid_mask(self, ids, attn):
572
+ valid = attn.bool()
573
  if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
574
  valid = valid & (~_safe_isin(ids, self.special_ids_t))
575
+ return valid
576
 
577
+ @torch.no_grad()
578
+ def pooled(self, smiles: str) -> torch.Tensor:
579
+ s = smiles.strip()
580
+ if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
581
+ tok = self._tokenize([s])
582
+ h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
583
+ valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
584
  vf = valid.unsqueeze(-1).float()
585
+ pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
586
+ if self.use_cache: self._cache_pooled[s] = pooled
 
 
 
 
587
  return pooled
588
 
589
  @torch.no_grad()
590
  def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
591
  s = smiles.strip()
592
+ if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
 
 
593
  tok = self._tokenize([s])
594
+ h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
595
+ valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
596
+ X = h[:, valid[0], :]
 
 
 
 
 
 
 
 
 
 
597
  M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
598
+ if self.use_cache: self._cache_unpooled[s] = (X, M)
 
 
599
  return X, M
600
 
601
 
602
class WTEmbedder:
    """Wild-type (amino-acid sequence) embedder backed by an ESM-2 encoder.

    Mirrors the SMILES embedders: ``pooled`` returns the mean over valid
    tokens, ``unpooled`` returns per-token states plus a mask; both cache
    by stripped sequence when ``use_cache`` is True.
    """

    def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True):
        self.device = device
        self.max_len = max_len  # 1022 leaves room for ESM's CLS/EOS within its 1024 window
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        # add_pooling_layer=False: pooling is done manually over valid tokens below.
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        # Per-sequence memoisation; keys are stripped input strings.
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect the tokenizer's defined special-token ids, deduplicated and sorted."""
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, seq_list):
        """Tokenize sequences and move every tensor onto ``self.device``."""
        tok = self.tokenizer(seq_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        tok = {k: v.to(self.device) for k, v in tok.items()}
        if "attention_mask" not in tok:
            # Defensive: treat every position as attended if no mask is returned.
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        """Boolean mask of positions that are attended AND not special tokens."""
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        """Mean of the valid-token hidden states, shape (1, hidden)."""
        s = seq.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        # clamp guards against a zero denominator when no token survives masking
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Valid-token hidden states (1, L, hidden) and an all-True mask (1, L)."""
        s = seq.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        # Batch size is always 1 here, hence indexing the first mask row.
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M
658
 
659
 
 
660
  # -----------------------------
661
  # Predictor
662
  # -----------------------------
663
+
664
  class PeptiVersePredictor:
 
 
 
 
 
665
  def __init__(
666
  self,
667
  manifest_path: str | Path,
668
  classifier_weight_root: str | Path,
669
  esm_name="facebook/esm2_t33_650M_UR50D",
670
  clm_name="aaronfeller/PeptideCLM-23M-all",
671
+ chemberta_name="DeepChem/ChemBERTa-77M-MLM",
672
  smiles_vocab="tokenizer/new_vocab.txt",
673
  smiles_splits="tokenizer/new_splits.txt",
674
  device: Optional[str] = None,
 
679
 
680
  self.manifest = read_best_manifest_csv(manifest_path)
681
 
682
+ self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
683
+ self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
684
+ vocab_path=str(self.root / smiles_vocab),
685
+ splits_path=str(self.root / smiles_splits))
686
+ self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name)
687
 
688
+ self.models: Dict[Tuple[str, str], Any] = {}
689
+ self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
690
+ self.mapie: Dict[Tuple[str, str], dict] = {}
691
+ self.ensembles: Dict[Tuple[str, str], List] = {}
692
 
693
  self._load_all_best_models()
694
 
695
+ def _get_embedder(self, emb_tag: str):
696
+ if emb_tag == "wt": return self.wt_embedder
697
+ if emb_tag == "peptideclm": return self.smiles_embedder
698
+ if emb_tag == "chemberta": return self.chemberta_embedder
699
+ raise ValueError(f"Unknown emb_tag={emb_tag!r}")
700
+
701
+ def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray:
702
+ v = self._get_embedder(emb_tag).pooled(input_str)
703
+ feats = v.detach().cpu().numpy().astype(np.float32)
704
+ feats = np.nan_to_num(feats, nan=0.0)
705
+ return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
706
+
707
+ def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]:
708
+ return self._get_embedder(emb_tag).unpooled(input_str)
709
+
710
    def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path:
        """Locate the on-disk training folder for (property, model, embedding).

        Tries ``<model>_<folder_suffix>`` first, then the bare model name,
        and raises FileNotFoundError when neither exists.
        """
        # The halflife property is stored under "half_life" on disk.
        disk_prop = "half_life" if prop_key == "halflife" else prop_key
        base = self.training_root / disk_prop

        # Map the embedding tag to the folder-naming suffix; unknown tags
        # fall back to the tag itself.
        folder_suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag)

        # Special case: WT halflife models prefer the *_log folders
        # (presumably trained on log-transformed targets — see the expm1
        # inverse in predict_property; confirm).
        if prop_key == "halflife" and emb_tag == "wt":
            if model_name == "transformer":
                for d in [base / "transformer_wt_log", base / "transformer_wt"]:
                    if d.exists(): return d
            if model_name in {"xgb", "xgb_reg"}:
                d = base / "xgb_wt_log"
                if d.exists(): return d

        candidates = [
            base / f"{model_name}_{folder_suffix}",
            base / model_name,
        ]
        for d in candidates:
            if d.exists(): return d

        raise FileNotFoundError(
            f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}"
        )
734
 
 
735
    def _load_all_best_models(self):
        """Load every best model named in the manifest, plus its uncertainty assets.

        For each property and each input column ("wt" sequence / "smiles"),
        populates ``self.models``, ``self.meta``, and — where available —
        ``self.mapie`` (conformal calibration bundles) and ``self.ensembles``
        (seed ensembles for DNN checkpoints). Binding affinity is handled
        separately because its artifacts use a paired target/binder model.
        """
        for prop_key, row in self.manifest.items():
            # parsed is (model_name, emb_tag) or None when the manifest has
            # no model for that column ("-").
            for col, parsed, thr in [
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ]:
                if parsed is None:
                    continue
                model_name, emb_tag = parsed

                # binding affinity
                if prop_key == "binding_affinity":
                    folder = model_name
                    # Folder name encodes the architecture variant.
                    pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled"
                    model_dir = self.training_root / "binding_affinity" / folder
                    art = find_best_artifact(model_dir)
                    model = load_binding_model(art, pooled_or_unpooled, self.device)
                    self.models[(prop_key, col)] = model
                    self.meta[(prop_key, col)] = {
                        "task_type": "Regression",
                        "threshold": None,
                        "artifact": str(art),
                        "model_name": pooled_or_unpooled,
                        "emb_tag": emb_tag,
                        "folder": folder,
                        "kind": "binding",
                    }
                    print(f"  [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}")
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f"    MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            # Best-effort: a broken bundle disables uncertainty
                            # but must not abort loading the predictor.
                            print(f"    MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f"    No MAPIE bundle found (uncertainty will be unavailable)")
                    continue

                # infer emb_tag from the input column when the manifest
                # does not specify one explicitly
                if emb_tag is None:
                    emb_tag = col

                model_dir = self._resolve_dir(prop_key, model_name, emb_tag)
                kind, obj, art = load_artifact(model_dir, self.device)

                if kind == "torch_ckpt":
                    # Checkpoints store weights only; rebuild the architecture.
                    arch = self._base_arch(model_name)
                    model = build_torch_model_from_ckpt(arch, obj, self.device)
                else:
                    model = obj

                self.models[(prop_key, col)] = model
                self.meta[(prop_key, col)] = {
                    "task_type": row.task_type,
                    "threshold": thr,
                    "artifact": str(art),
                    "model_name": model_name,
                    "emb_tag": emb_tag,
                    "kind": kind,
                }

                print(f"  [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}")

                # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt
                is_regression = row.task_type.lower() == "regression"
                wants_mapie = (
                    (model_name in MAPIE_REGRESSION_MODELS and is_regression)
                    or (kind == "xgb" and is_regression)
                    or (kind == "torch_ckpt" and is_regression)
                )
                if wants_mapie:
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f"    MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            print(f"    MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f"    No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)")

                # Seed ensembles: DNN only, used when MAPIE not available
                if kind == "torch_ckpt":
                    arch = self._base_arch(model_name)
                    ens = load_seed_ensemble(model_dir, arch, self.device)
                    if ens:
                        self.ensembles[(prop_key, col)] = ens
                        if (prop_key, col) in self.mapie:
                            print(f"    Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)")
                        else:
                            unc_type = "ensemble_predictive_entropy" if row.task_type.lower() == "classifier" else "ensemble_std"
                            print(f"    Seed ensemble: {len(ens)} seeds loaded uncertainty method: {unc_type}")
                    else:
                        if (prop_key, col) in self.mapie:
                            print(f"    No seed ensemble (MAPIE covers uncertainty)")
                        else:
                            print(f"    No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable")

                # XGBoost/SVM classifiers: binary entropy
                if kind in ("xgb", "joblib") and row.task_type.lower() == "classifier":
                    print(f"    Uncertainty method: binary_predictive_entropy (computed at inference)")
 
 
 
 
 
 
 
837
 
838
+ @staticmethod
839
+ def _base_arch(model_name: str) -> str:
840
+ if model_name.startswith("transformer"): return "transformer"
841
+ if model_name.startswith("mlp"): return "mlp"
842
+ if model_name.startswith("cnn"): return "cnn"
843
+ return model_name
844
+
845
+ # Feature extraction
846
+ def _get_features(self, prop_key: str, col: str, input_str: str):
847
+ meta = self.meta[(prop_key, col)]
848
+ emb_tag = meta["emb_tag"]
849
+ kind = meta["kind"]
850
  if kind == "torch_ckpt":
851
+ return self._embed_unpooled(emb_tag, input_str)
852
+ return self._embed_pooled(emb_tag, input_str)
853
+
854
+ # Uncertainty
855
    def _compute_uncertainty(self, prop_key: str, col: str, input_str: str,
                             score: float) -> Tuple[Any, str]:
        """Pick and compute the uncertainty estimate for a prediction.

        Priority: conformal MAPIE interval (regression) > seed-ensemble
        statistics (DNNs) > binary predictive entropy (classifiers).
        Returns ``(value, method_name)``; value is None when no method
        applies, with the reason encoded in the method string.
        """
        meta = self.meta[(prop_key, col)]
        kind = meta["kind"]
        model_name = meta["model_name"]
        task_type = meta["task_type"].lower()
        emb_tag = meta["emb_tag"]

        # Pooled embedding for adaptive MAPIE sigma model
        def get_pooled_emb():
            return self._embed_pooled(emb_tag, input_str) if emb_tag else None

        # DNN
        if kind == "torch_ckpt":
            # Regression: prefer MAPIE if available
            if task_type == "regression":
                mapie_bundle = self.mapie.get((prop_key, col))
                if mapie_bundle:
                    # Adaptive bundles scale the interval by a learned
                    # sigma model that takes the pooled embedding.
                    emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                    lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                    return (lo, hi), "conformal_prediction_interval"
                # Fall back to seed ensemble std
                ens = self.ensembles.get((prop_key, col))
                if ens:
                    X, M = self._embed_unpooled(emb_tag, input_str)
                    return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std"
                return None, "unavailable (no MAPIE bundle and no seed ensemble)"
            # Classifier: ensemble predictive entropy
            ens = self.ensembles.get((prop_key, col))
            if not ens:
                return None, "unavailable (no seed ensemble found)"
            X, M = self._embed_unpooled(emb_tag, input_str)
            return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy"

        # XGBoost
        if kind == "xgb":
            if task_type == "classifier":
                return _binary_entropy(score), "binary_predictive_entropy"
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                return (lo, hi), "conformal_prediction_interval"
            return None, "unavailable (no MAPIE bundle for XGBoost regression)"

        # SVR / ElasticNet regression: MAPIE
        if kind == "joblib" and model_name in MAPIE_REGRESSION_MODELS and task_type == "regression":
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                return (lo, hi), "conformal_prediction_interval"
            return None, "unavailable (MAPIE bundle not found)"

        # joblib classifiers (SVM, ElasticNet used as classifier)
        if kind == "joblib" and task_type == "classifier":
            return _binary_entropy(score), "binary_predictive_entropy_single_model"

        return None, "unavailable"
914
+
915
    def predict_property(self, prop_key: str, col: str, input_str: str,
                         uncertainty: bool = False) -> Dict[str, Any]:
        """Run one property prediction for a sequence ('wt') or SMILES ('smiles') input.

        Returns a dict with at least ``property``, ``col``, ``score`` and
        ``emb_tag``; classifiers with a manifest threshold also get
        ``label``/``threshold``, and ``uncertainty=True`` adds
        ``uncertainty``/``uncertainty_type``. Raises KeyError if no model
        was loaded for the pair, RuntimeError for binding affinity (which
        has its own entry point).
        """
        if (prop_key, col) not in self.models:
            raise KeyError(f"No model loaded for ({prop_key}, {col}).")

        meta = self.meta[(prop_key, col)]
        model = self.models[(prop_key, col)]
        task_type = meta["task_type"].lower()
        thr = meta.get("threshold")
        kind = meta["kind"]
        model_name = meta["model_name"]

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # DNN
        if kind == "torch_ckpt":
            X, M = self._get_features(prop_key, col, input_str)
            with torch.no_grad():
                raw = model(X, M).squeeze().float().cpu().item()

            # WT halflife checkpoints trained on log1p targets are inverted here.
            if prop_key == "halflife" and col == "wt" and "log" in model_name:
                raw = float(np.expm1(raw))

            if task_type == "classifier":
                # Model emits a logit; convert to probability.
                score = float(1.0 / (1.0 + np.exp(-raw)))
                out = {"property": prop_key, "col": col, "score": score,
                       "emb_tag": meta["emb_tag"]}
                if thr is not None:
                    out["label"] = int(score >= float(thr)); out["threshold"] = float(thr)
            else:
                out = {"property": prop_key, "col": col, "score": float(raw),
                       "emb_tag": meta["emb_tag"]}

        # XGBoost
        elif kind == "xgb":
            feats = self._get_features(prop_key, col, input_str)
            pred = float(model.predict(xgb.DMatrix(feats))[0])
            if prop_key == "halflife" and col == "wt" and "log" in model_name:
                pred = float(np.expm1(pred))
            out = {"property": prop_key, "col": col, "score": pred,
                   "emb_tag": meta["emb_tag"]}
            if task_type == "classifier" and thr is not None:
                out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)

        # joblib (SVM / ElasticNet / SVR)
        elif kind == "joblib":
            feats = self._get_features(prop_key, col, input_str)
            if task_type == "classifier":
                if hasattr(model, "predict_proba"):
                    pred = float(model.predict_proba(feats)[:, 1][0])
                elif hasattr(model, "decision_function"):
                    # Margin-only classifiers: squash the margin to (0, 1).
                    pred = float(1.0 / (1.0 + np.exp(-model.decision_function(feats)[0])))
                else:
                    pred = float(model.predict(feats)[0])
                out = {"property": prop_key, "col": col, "score": pred,
                       "emb_tag": meta["emb_tag"]}
                if thr is not None:
                    out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)
            else:
                pred = float(model.predict(feats)[0])
                out = {"property": prop_key, "col": col, "score": pred,
                       "emb_tag": meta["emb_tag"]}
        else:
            raise RuntimeError(f"Unknown kind={kind}")

        if uncertainty:
            u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"])
            out["uncertainty"] = u_val
            out["uncertainty_type"] = u_type

        return out
 
987
 
988
    def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str,
                                 uncertainty: bool = False) -> Dict[str, Any]:
        """Predict target/binder affinity with the paired binding model.

        The target is always embedded with the WT (sequence) embedder;
        the binder embedder is chosen by the stored emb_tag, falling back
        to the column ('wt' -> sequence, otherwise SMILES). Returns the
        regression value plus two class labels: one thresholded from the
        affinity and one from the model's classification head.
        """
        prop_key = "binding_affinity"
        if (prop_key, col) not in self.models:
            raise KeyError(f"No binding model loaded for ({prop_key}, {col}).")

        model = self.models[(prop_key, col)]
        meta = self.meta[(prop_key, col)]
        arch = meta["model_name"]
        emb_tag = meta.get("emb_tag")

        if arch == "pooled":
            t_vec = self.wt_embedder.pooled(target_seq)
            b_vec = self._get_embedder(emb_tag or col).pooled(binder_str) if emb_tag else \
                (self.wt_embedder.pooled(binder_str) if col == "wt" else self.smiles_embedder.pooled(binder_str))
            with torch.no_grad():
                reg, logits = model(t_vec, b_vec)
        else:
            T, Mt = self.wt_embedder.unpooled(target_seq)
            binder_emb = self._get_embedder(emb_tag or col) if emb_tag else \
                (self.wt_embedder if col == "wt" else self.smiles_embedder)
            B, Mb = binder_emb.unpooled(binder_str)
            with torch.no_grad():
                reg, logits = model(T, Mt, B, Mb)

        affinity = float(reg.squeeze().cpu().item())
        cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
        cls_thr = affinity_to_class(affinity)
        names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}

        out = {
            "property": "binding_affinity",
            "col": col,
            "affinity": affinity,
            "class_by_threshold": names[cls_thr],
            "class_by_logits": names[cls_logit],
            "binding_model": arch,
        }

        if uncertainty:
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
                    # Concatenate target + binder pooled embeddings for sigma model
                    binder_emb_tag = mapie_bundle.get("emb_tag") or col
                    # NOTE(review): target_emb_tag is read but never used —
                    # the target is always embedded with wt_embedder below.
                    target_emb_tag = mapie_bundle.get("target_emb_tag", "wt")
                    t_vec = self.wt_embedder.pooled(target_seq).cpu().float().numpy()
                    b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy()
                    emb = np.concatenate([t_vec, b_vec], axis=1)
                else:
                    emb = None
                lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb)
                out["uncertainty"] = (lo, hi)
                out["uncertainty_type"] = "conformal_prediction_interval"
            else:
                out["uncertainty"] = None
                out["uncertainty_type"] = "unavailable (no MAPIE bundle found)"

        return out
1047
 
1048
if __name__ == "__main__":
    # Smoke test: load every best model from the manifest next to this
    # script, then run a few example predictions.
    root = Path(__file__).resolve().parent  # current script folder

    predictor = PeptiVersePredictor(
        manifest_path=root / "best_models.txt",
        classifier_weight_root=root
    )
    print(predictor.training_root)
    print("MAPIE keys:", list(predictor.mapie.keys()))
    print("Ensemble keys:", list(predictor.ensembles.keys()))

    # Example amino-acid sequence and cyclic-peptide SMILES inputs.
    seq = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"

    print(predictor.predict_property("hemolysis", "wt", seq))
    print(predictor.predict_property("hemolysis", "smiles", smiles, uncertainty=True))
    print(predictor.predict_property("nf", "wt", seq, uncertainty=True))
    print(predictor.predict_property("nf", "smiles", smiles, uncertainty=True))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT"))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT", uncertainty=True))
    seq1 = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    seq2 = "ACDEFGHIKLMNPQRSTVWY"

    r1 = predictor.predict_binding_affinity("wt", target_seq=seq2, binder_str="GIGAVLKVLT", uncertainty=True)
    r2 = predictor.predict_property("nf", "wt", seq1, uncertainty=True)
    r3 = predictor.predict_property("nf", "wt", seq2, uncertainty=True)
    print(r1)
    print(r2)
    print(r3)
training_classifiers/binding_training.py CHANGED
@@ -51,8 +51,9 @@ def load_split_paired(path: str):
51
  # Collate: pooled paired
52
  # -----------------------------
53
  def collate_pair_pooled(batch):
54
- Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
55
- Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
 
56
  y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
57
  return Pt, Pb, y
58
 
@@ -147,7 +148,7 @@ class CrossAttnUnpooled(nn.Module):
147
  self.layers = nn.ModuleList([])
148
  for _ in range(n_layers):
149
  self.layers.append(nn.ModuleDict({
150
- "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
151
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
152
  "n1t": nn.LayerNorm(hidden),
153
  "n2t": nn.LayerNorm(hidden),
@@ -272,7 +273,8 @@ def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> flo
272
  # infer dims from first row
273
  if mode == "pooled":
274
  Ht = len(train_ds[0]["target_embedding"])
275
- Hb = len(train_ds[0]["binder_embedding"])
 
276
  collate = collate_pair_pooled
277
  model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
278
  train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
@@ -349,7 +351,8 @@ def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
349
 
350
  if mode == "pooled":
351
  Ht = len(train_ds[0]["target_embedding"])
352
- Hb = len(train_ds[0]["binder_embedding"])
 
353
  model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
354
  collate = collate_pair_pooled
355
  train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
 
51
  # Collate: pooled paired
52
  # -----------------------------
53
def collate_pair_pooled(batch):
    """Collate pooled target/binder embedding rows into batch tensors.

    Accepts rows with either a 'binder_embedding' or a legacy 'embedding'
    key for the binder. Returns (Pt, Pb, y) float32 tensors.
    """
    binder_key = "binder_embedding" if "binder_embedding" in batch[0] else "embedding"
    targets = [row["target_embedding"] for row in batch]
    binders = [row[binder_key] for row in batch]
    labels = [float(row["label"]) for row in batch]
    return (torch.tensor(targets, dtype=torch.float32),
            torch.tensor(binders, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.float32))
59
 
 
148
  self.layers = nn.ModuleList([])
149
  for _ in range(n_layers):
150
  self.layers.append(nn.ModuleDict({
151
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), # (B, L, H) for embeddings now
152
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
153
  "n1t": nn.LayerNorm(hidden),
154
  "n2t": nn.LayerNorm(hidden),
 
273
  # infer dims from first row
274
  if mode == "pooled":
275
  Ht = len(train_ds[0]["target_embedding"])
276
+ binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
277
+ Hb = len(train_ds[0][binder_key])
278
  collate = collate_pair_pooled
279
  model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
280
  train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
 
351
 
352
  if mode == "pooled":
353
  Ht = len(train_ds[0]["target_embedding"])
354
+ binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"
355
+ Hb = len(train_ds[0][binder_key])
356
  model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
357
  collate = collate_pair_pooled
358
  train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
training_classifiers/long_aggregated.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2830a2099262d9e7ffdff70bc789b74a69bb44bb3dd380d8d05b91c9d01d065a
3
- size 45506
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:513cd88f97ef4b04ef92baaec85f2a5fe255a7dd50664025b2628a4ab6d94a99
3
+ size 45539
training_classifiers/ml_uncertainty.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ import xgboost as xgb
7
+ from scipy import stats
8
+ from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
9
+ from datasets import load_from_disk, DatasetDict
10
+
11
def best_f1_threshold(y_true, y_prob):
    """Return (threshold, f1) maximizing F1 along the precision-recall curve."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # The final PR point has no associated threshold, so drop it; the
    # epsilon keeps the division defined when precision + recall == 0.
    f1 = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
    best = int(np.nanargmax(f1))
    return float(thresholds[best]), float(f1[best])
16
+
17
+
18
def bootstrap_ci(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Non-parametric bootstrap CI for F1 (at val-optimal threshold) and AUC.
    Resamples (y_true, y_prob) pairs

    Returns a dict with per-metric mean/std/ci_low/ci_high/report entries
    plus the threshold used and the sample count.
    """
    rng = np.random.default_rng(seed=seed)
    n = len(y_true)

    # Threshold picked on the full val set
    thr, _ = best_f1_threshold(y_true, y_prob)

    f1_scores, auc_scores = [], []

    for _ in range(n_bootstrap):
        # Resample indices with replacement (standard bootstrap).
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_prob[idx]

        # Skip degenerate bootstraps (only one class)
        if len(np.unique(yt)) < 2:
            continue

        f1_scores.append(f1_score(yt, (yp >= thr).astype(int), zero_division=0))
        auc_scores.append(roc_auc_score(yt, yp))

    # Percentile bounds for the requested two-sided CI.
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    results = {}
    for name, arr in [("f1", f1_scores), ("auc", auc_scores)]:
        arr = np.array(arr)
        results[name] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "ci_low": float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report": f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            # May be < n_bootstrap because single-class resamples are skipped.
            "n_bootstrap": len(arr),
        }

    results["threshold_used"] = float(thr)
    results["n_samples"] = int(n)
    return results
66
+
67
+ def prob_margin_uncertainty(val_preds_df: pd.DataFrame) -> pd.DataFrame:
68
+ """
69
+ Uncertainty = distance from the decision boundary in probability space.
70
+
71
+ |prob - 0.5| if = 0.0 means maximally uncertain, 0.5 means maximally confident.
72
+ Normalized to [0, 1]: confidence = 2 * |prob - 0.5|
73
+ This reflecting how far the model is from a coin-flip on given sequence.
74
+ """
75
+ df = val_preds_df.copy()
76
+ df["uncertainty"] = 1 - 2 * (df["y_prob"] - 0.5).abs() # 0=confident, 1=uncertain
77
+ df["confidence"] = 1 - df["uncertainty"] # 0=uncertain, 1=confident
78
+ return df
79
+
80
def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Write bootstrap-CI results to <out_dir>/bootstrap_ci.json and print a summary."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci.json")
    with open(path, "w") as fh:
        json.dump(ci_results, fh, indent=2)

    print(f"\n=== Bootstrap 95% CI ({model_name}) ===")
    print(f"  F1  : {ci_results['f1']['report']}")
    print(f"  AUC : {ci_results['auc']['report']}")
    print(f"  (threshold={ci_results['threshold_used']:.4f}, "
          f"n_bootstrap={ci_results['f1']['n_bootstrap']}, "
          f"n_val={ci_results['n_samples']})")
    print(f"Saved to {path}")
93
+
94
+
95
def save_uncertainty_csv(df: pd.DataFrame, out_dir: str, fname: str = "val_uncertainty.csv"):
    """Write the per-molecule uncertainty frame to CSV and print summary stats."""
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, fname)
    df.to_csv(target, index=False)
    print("\n=== Per-molecule uncertainty ===")
    print(f"  Mean uncertainty : {df['uncertainty'].mean():.4f}")
    print(f"  Mean confidence : {df['confidence'].mean():.4f}")
    print(f"  Saved to {target}")
103
+
104
if __name__ == "__main__":
    # CLI: compute bootstrap CIs and/or probability-margin uncertainty
    # from a saved val_predictions.csv (expects y_true / y_prob columns).
    parser = argparse.ArgumentParser()
    # NOTE(review): the "uncertainty_xgb" choice is accepted but has no
    # handler below — selecting it is a silent no-op; confirm intent.
    parser.add_argument("--mode", choices=["ci", "uncertainty_xgb", "uncertainty_prob"],
                        required=True,
                        help=(
                            "ci : bootstrap CI from val_predictions.csv (all models)\n"
                            "uncertainty_prob : margin uncertainty for SVM/ElasticNet/XGB"
                        ))
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--model_path", type=str, help="Path to best_model.json (XGB only)")
    parser.add_argument("--dataset_path", type=str, help="HuggingFace dataset path (XGB uncertainty only)")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="", help="Label for report (xgb_smiles)")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args()

    if args.mode == "ci":
        assert args.val_preds, "--val_preds required for ci mode"
        df = pd.read_csv(args.val_preds)
        ci = bootstrap_ci(df["y_true"].values, df["y_prob"].values,
                          n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
    elif args.mode == "uncertainty_prob":
        assert args.val_preds, "--val_preds required for uncertainty_prob"
        df_preds = pd.read_csv(args.val_preds)
        # CI
        ci = bootstrap_ci(df_preds["y_true"].values, df_preds["y_prob"].values,
                          n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
        # Uncertainty from margin
        df_unc = prob_margin_uncertainty(df_preds)
        save_uncertainty_csv(df_unc, args.out_dir, "val_uncertainty_prob.csv")
training_classifiers/ml_uncertainty_reg.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ import xgboost as xgb
7
+ from scipy.stats import spearmanr
8
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
9
+ from datasets import load_from_disk, DatasetDict
10
+
11
def safe_spearmanr(y_true, y_pred):
    """Spearman correlation that maps undefined results (None/NaN) to 0.0."""
    corr = spearmanr(y_true, y_pred).correlation
    if corr is None or np.isnan(corr):
        return 0.0
    return float(corr)
14
+
15
def eval_regression(y_true, y_pred):
    """Return spearman_rho, rmse, mae and r2 for one (y_true, y_pred) pair."""
    try:
        from sklearn.metrics import root_mean_squared_error  # sklearn >= 1.4
        root_mse = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        # older sklearn versions: derive RMSE from the squared error
        root_mse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    rho = safe_spearmanr(y_true, y_pred)
    mae = float(mean_absolute_error(y_true, y_pred))
    rsq = float(r2_score(y_true, y_pred))
    return {"spearman_rho": rho, "rmse": root_mse, "mae": mae, "r2": rsq}
27
+
28
+ # ======================== Bootstrap CI =========================================
29
+
30
def bootstrap_ci_reg(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Percentile bootstrap CI for regression metrics.

    Uses percentile method (not t-CI) because:
    - Spearman rho is bounded [-1, 1] - t-CI can produce impossible values near extremes
    - RMSE is strictly positive - symmetric t-CI is inappropriate near 0
    - Percentile bootstrap makes no distributional assumptions

    A Fisher z-transform CI for rho is also computed as a cross-check.

    Args:
        y_true, y_pred: validation targets/predictions (same length).
        n_bootstrap: number of bootstrap resamples.
        ci: confidence level (e.g. 0.95).
        seed: RNG seed for reproducible resampling.

    Returns:
        dict keyed by metric name with mean/std/ci_low/ci_high/report entries,
        plus "n_samples".

    Raises:
        ValueError: if every resample is degenerate (constant y_true).
    """
    rng = np.random.default_rng(seed=seed)
    n = len(y_true)
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    boot_metrics = {k: [] for k in ["spearman_rho", "rmse", "mae", "r2"]}

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_pred[idx]
        if len(np.unique(yt)) < 2:
            # rho and R^2 are undefined on a constant resample; skip it
            continue
        m = eval_regression(yt, yp)
        for k in boot_metrics:
            boot_metrics[k].append(m[k])

    # Robustness fix: an all-degenerate input previously crashed inside
    # np.quantile with an opaque error; fail fast with a clear message.
    if not boot_metrics["spearman_rho"]:
        raise ValueError(
            "All bootstrap resamples were degenerate (constant y_true); "
            "cannot compute confidence intervals."
        )

    results = {}
    for name, vals in boot_metrics.items():
        arr = np.array(vals)
        results[name] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "ci_low": float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report": f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            "n_bootstrap": len(arr),
        }

    # Fisher z-transform CI for Spearman rho (cross-check, more accurate near ±1).
    # BUG FIX: the critical value was hardcoded to 1.96 (95%) even when the
    # caller requested a different `ci` level; derive it from the normal quantile.
    from scipy.stats import norm
    z_crit = float(norm.ppf(1 - alpha / 2))
    rho_obs = safe_spearmanr(y_true, y_pred)
    # z-transform: arctanh(rho), SE = 1/sqrt(n-3)
    z = np.arctanh(np.clip(rho_obs, -0.9999, 0.9999))
    se_z = 1.0 / np.sqrt(max(n - 3, 1))
    z_lo = z - z_crit * se_z
    z_hi = z + z_crit * se_z
    results["spearman_rho"]["fisher_z_ci"] = {
        "ci_low": float(np.tanh(z_lo)),
        "ci_high": float(np.tanh(z_hi)),
        "report": f"[{np.tanh(z_lo):.4f}, {np.tanh(z_hi):.4f}]",
        "note": "Fisher z-transform CI - more accurate when rho > 0.9",
    }

    results["n_samples"] = int(n)
    return results
91
+
92
+
93
def residual_uncertainty(val_preds_df: pd.DataFrame, coverage: float = 0.95) -> tuple:
    """
    Gaussian residual-based prediction intervals for a fitted regressor.

    - Assume residuals ~ N(0, sigma) where sigma = std(residuals)
    - Prediction interval for molecule i: y_pred_i ± z * sigma
    - Uncertainty score = sigma (constant across all molecules for linear models)
    - Also reports the empirical (dataset-level) coverage of the interval

    Args:
        val_preds_df: DataFrame with "y_true" and "y_pred" columns.
        coverage: nominal coverage; z is looked up for 0.90/0.95/0.99 and
            anything else falls back to the 95% value.

    Returns:
        (df, meta): df is a copy of val_preds_df with interval and error columns
        added; meta is a JSON-serializable summary dict.
        (Fix: the return annotation previously claimed a bare DataFrame.)
    """
    df = val_preds_df.copy()

    residuals = df["y_true"] - df["y_pred"]
    sigma = float(residuals.std(ddof=1))
    z = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(coverage, 1.960)
    half_width = z * sigma

    df["pred_interval_low"] = df["y_pred"] - half_width
    df["pred_interval_high"] = df["y_pred"] + half_width
    df["pred_interval_width"] = 2 * half_width  # constant for linear models
    df["abs_error"] = residuals.abs()

    # what fraction of y_true actually falls inside the interval
    empirical_coverage = float(
        ((df["y_true"] >= df["pred_interval_low"]) &
         (df["y_true"] <= df["pred_interval_high"])).mean()
    )

    meta = {
        "residual_std": round(sigma, 6),
        "interval_halfwidth": round(half_width, 6),
        # fix: key previously carried a stray (no-op) f-string prefix
        "nominal_coverage": coverage,
        "empirical_coverage": round(empirical_coverage, 4),
        "note": (
            # fix: missing space between the two concatenated sentences
            "Prediction interval assumes N(0, sigma) residuals. "
            "Interval width is constant across molecules for linear models. "
        ),
    }
    return df, meta
129
+
130
def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Persist regression bootstrap-CI results to JSON and echo a summary."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, "bootstrap_ci_reg.json")
    with open(path, "w") as fh:
        json.dump(ci_results, fh, indent=2)

    print(f"\n=== Bootstrap 95% CI - Regression ({model_name}) ===")
    for metric in ("spearman_rho", "rmse", "mae", "r2"):
        entry = ci_results[metric]
        print(f" {metric:15s}: {entry['report']}")
        # the Fisher cross-check only exists for the correlation entry
        if metric == "spearman_rho" and "fisher_z_ci" in entry:
            fisher = entry["fisher_z_ci"]
            print(f" Fisher z CI : {fisher['report']} ← use this if rho > 0.9")
    print(f" n_val={ci_results['n_samples']}, n_bootstrap={ci_results['spearman_rho']['n_bootstrap']}")
    print(f"Saved to {path}")
146
if __name__ == "__main__":
    # CLI entry point: bootstrap CIs and residual-based uncertainty for regressors.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True,
                        choices=["ci", "uncertainty_residual"],
                        help=(
                            "ci : bootstrap CI from val_predictions.csv\n"
                            "uncertainty_residual: residual interval for ElasticNet/SVR"
                        ))
    parser.add_argument("--val_preds", type=str, help="Path to val_predictions.csv")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--n_bootstrap", type=int, default=2000)
    args = parser.parse_args()

    if args.mode == "ci":
        assert args.val_preds, "--val_preds required"
        preds = pd.read_csv(args.val_preds)
        ci = bootstrap_ci_reg(preds["y_true"].values, preds["y_pred"].values,
                              n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
    elif args.mode == "uncertainty_residual":
        assert args.val_preds
        preds = pd.read_csv(args.val_preds)
        # CI report first, then per-molecule residual intervals
        ci = bootstrap_ci_reg(preds["y_true"].values, preds["y_pred"].values,
                              n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
        df_unc, meta = residual_uncertainty(preds)
        unc_path = os.path.join(args.out_dir, "val_uncertainty_residual.csv")
        df_unc.to_csv(unc_path, index=False)
        meta_path = os.path.join(args.out_dir, "residual_interval_meta.json")
        with open(meta_path, "w") as fh:
            json.dump(meta, fh, indent=2)
        print(f"\nResidual interval summary:")
        print(f" Residual std : {meta['residual_std']:.4f}")
        print(f" 95% interval ± {meta['interval_halfwidth']:.4f}")
        print(f" Empirical coverage : {meta['empirical_coverage']:.4f} (nominal={meta['nominal_coverage']})")
        print(f" Saved to {unc_path}")
training_classifiers/refit_binding_affinity_seed.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from datasets import load_from_disk, DatasetDict
10
+ from scipy.stats import spearmanr
11
+ from scipy import stats as scipy_stats
12
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
13
+ from lightning.pytorch import seed_everything
14
+ import sys
15
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
16
+ from binding_training import (
17
+ CrossAttnPooled,
18
+ CrossAttnUnpooled,
19
+ collate_pair_pooled,
20
+ collate_pair_unpooled,
21
+ eval_spearman_pooled,
22
+ eval_spearman_unpooled,
23
+ train_one_epoch_pooled,
24
+ train_one_epoch_unpooled,
25
+ affinity_to_class_tensor,
26
+ safe_spearmanr,
27
+ )
28
+
29
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
+
31
def load_split_paired(path: str):
    """Load a paired DatasetDict from disk and return its (train, val) splits."""
    data = load_from_disk(path)
    if isinstance(data, DatasetDict):
        return data["train"], data["val"]
    raise ValueError(f"Expected DatasetDict at {path}")
36
+
37
+
38
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Spearman rho, RMSE, MAE and R^2 for validation predictions."""
    metrics = {"spearman_rho": safe_spearmanr(y_true, y_pred)}
    try:
        from sklearn.metrics import root_mean_squared_error  # sklearn >= 1.4
        metrics["rmse"] = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        # older sklearn: fall back to sqrt(MSE)
        metrics["rmse"] = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    metrics["mae"] = float(mean_absolute_error(y_true, y_pred))
    metrics["r2"] = float(r2_score(y_true, y_pred))
    return metrics
50
+
51
@torch.no_grad()
def predict_all_pooled(model, loader):
    """Run the pooled model over a loader; return (y_true, y_pred) numpy arrays."""
    model.eval()
    targets, preds = [], []
    for tgt, bnd, y in loader:
        tgt = tgt.to(DEVICE, non_blocking=True)
        bnd = bnd.to(DEVICE, non_blocking=True)
        out, _ = model(tgt, bnd)  # model also returns a class logit we ignore here
        targets.append(y.numpy())
        preds.append(out.detach().cpu().numpy())
    return np.concatenate(targets), np.concatenate(preds)
62
+
63
+
64
@torch.no_grad()
def predict_all_unpooled(model, loader):
    """Run the unpooled (token-level) model over a loader; return (y_true, y_pred)."""
    model.eval()
    targets, preds = [], []
    for T, Mt, B, Mb, y in loader:
        T = T.to(DEVICE, non_blocking=True)
        Mt = Mt.to(DEVICE, non_blocking=True)
        B = B.to(DEVICE, non_blocking=True)
        Mb = Mb.to(DEVICE, non_blocking=True)
        out, _ = model(T, Mt, B, Mb)
        targets.append(y.numpy())
        preds.append(out.detach().cpu().numpy())
    return np.concatenate(targets), np.concatenate(preds)
77
+
78
+
79
def build_model(mode: str, params: dict, train_ds) -> nn.Module:
    """Instantiate the pooled/unpooled cross-attention model from Optuna params."""
    hidden = int(params["hidden_dim"])
    heads = int(params["n_heads"])
    depth = int(params["n_layers"])
    drop = float(params["dropout"])

    # some pooled datasets store the binder under "embedding" instead of "binder_embedding"
    binder_key = "binder_embedding" if "binder_embedding" in train_ds.column_names else "embedding"

    if mode == "pooled":
        dim_t = len(train_ds[0]["target_embedding"])
        dim_b = len(train_ds[0][binder_key])
        net = CrossAttnPooled(dim_t, dim_b, hidden=hidden, n_heads=heads,
                              n_layers=depth, dropout=drop)
    else:
        dim_t = len(train_ds[0]["target_embedding"][0])
        dim_b = len(train_ds[0]["binder_embedding"][0])
        net = CrossAttnUnpooled(dim_t, dim_b, hidden=hidden, n_heads=heads,
                                n_layers=depth, dropout=drop)
    return net.to(DEVICE)
97
+
98
+
99
+ # Refit
100
def refit_with_seed(dataset_path: str, base_out_dir: str, mode: str,
                    seed: int, patience: int = 20) -> dict:
    """Retrain the Optuna-selected binding-affinity model under a new seed.

    Loads best_params from <base_out_dir>/best_model.pt, retrains from scratch
    with early stopping on validation Spearman rho, and writes predictions,
    weights and metrics to <base_out_dir>/seed_<seed>/.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"No best_model.pt found at {model_path}. Run Optuna (binding_training.py) first."
        )

    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split_paired(dataset_path)
    print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} mode={mode}")

    batch_size = int(best_params["batch_size"])
    cls_weight = float(best_params["cls_weight"])

    # pooled/unpooled share the loop; only the collate/eval/train/predict fns differ
    if mode == "pooled":
        collate, eval_fn = collate_pair_pooled, eval_spearman_pooled
        train_fn, predict = train_one_epoch_pooled, predict_all_pooled
    else:
        collate, eval_fn = collate_pair_unpooled, eval_spearman_unpooled
        train_fn, predict = train_one_epoch_unpooled, predict_all_unpooled

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=True, collate_fn=collate)

    model = build_model(mode, best_params, train_ds)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=float(best_params["lr"]),
                                  weight_decay=float(best_params["weight_decay"]))
    loss_reg = nn.MSELoss()
    loss_cls = nn.CrossEntropyLoss()

    # early stopping on validation rho
    best_rho, stale, best_state = -1e9, 0, None

    for epoch in range(1, 201):
        train_fn(model, train_loader, optimizer, loss_reg, loss_cls, cls_w=cls_weight)
        rho = eval_fn(model, val_loader)

        if rho > best_rho + 1e-6:
            best_rho, stale = rho, 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            stale += 1
            if stale >= patience:
                print(f" Early stopping at epoch {epoch} (best rho={best_rho:.4f})")
                break

    if best_state:
        model.load_state_dict(best_state)

    y_true, y_pred = predict(model, val_loader)
    metrics = eval_regression(y_true, y_pred)

    # Per-molecule validation predictions (with identifier columns when available)
    df_val = pd.DataFrame({
        "y_true": y_true.astype(float),
        "y_pred": y_pred.astype(float),
        "residual": (y_true - y_pred).astype(float),
        "abs_error": np.abs(y_true - y_pred).astype(float),
    })
    for col in ("target_sequence", "sequence", "affinity_class"):
        if col in val_ds.column_names:
            df_val.insert(0, col, np.asarray(val_ds[col]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(),
                "best_params": best_params,
                "mode": mode,
                "seed": seed},
               os.path.join(out_dir, "model.pt"))

    summary = {"mode": mode, "seed": seed,
               **{k: round(v, 6) for k, v in metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={metrics['spearman_rho']:.4f} "
          f"RMSE={metrics['rmse']:.4f} R2={metrics['r2']:.4f}")
    return summary
194
+
195
+
196
+ # CI aggregation
197
+
198
def aggregate_seed_results(base_out_dir: str, seeds: list) -> pd.DataFrame:
    """
    Collect per-seed metrics.json files and report mean ± 95% t-CI per metric.

    Args:
        base_out_dir: directory containing seed_<s>/metrics.json sub-directories.
        seeds: seed values to look for; missing ones are warned about and skipped.

    Returns:
        DataFrame with one row per metric (mean/std/ci_95/report/n_seeds),
        also written to <base_out_dir>/seed_aggregated_metrics.csv.

    Raises:
        ValueError: if no seed results are found at all.
    """
    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            # fix: json.load(open(p)) leaked the file handle
            with open(p) as fh:
                records.append(json.load(fh))
        else:
            print(f"[WARN] Missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found — did the refit jobs complete?")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        std = vals.std(ddof=1)  # sample std across seeds
        se = std / np.sqrt(n)
        t_crit = scipy_stats.t.ppf(0.975, df=n - 1)
        ci = t_crit * se
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        # t-CIs can exceed the [-1, 1] bound for correlations near ±1
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        note = f" ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f" {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved → {out_path}")
    return summary_df
245
+
246
+
247
if __name__ == "__main__":
    # Entry point: either refit one seed or aggregate all finished seed runs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Paired DatasetDict path")
    parser.add_argument("--base_out_dir", type=str, required=True,
                        help="Directory containing best_model.pt from the Optuna run")
    parser.add_argument("--mode", type=str, required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--aggregate", action="store_true",
                        help="Aggregate across seed runs instead of training")
    parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
    args = parser.parse_args()

    if args.aggregate:
        aggregate_seed_results(args.base_out_dir, args.all_seeds)
    else:
        refit_with_seed(dataset_path=args.dataset_path,
                        base_out_dir=args.base_out_dir,
                        mode=args.mode,
                        seed=args.seed,
                        patience=args.patience)
training_classifiers/refit_ml_walltime.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loads best params from optimization_summary.txt, refits the model once on the
3
+ train split, and appends a wall-time record to wall_clock_ml.jsonl.
4
+ """
5
+ import json
6
+ import time
7
+ import joblib
8
+ import argparse
9
+ import re
10
+ import numpy as np
11
+ from pathlib import Path
12
+ from datetime import datetime
13
+ # Classification trainers
14
+ from train_ml import (
15
+ load_split_data as load_split_cls,
16
+ train_cuml_svc,
17
+ train_cuml_elastic_net,
18
+ train_xgb,
19
+ train_svm,
20
+ )
21
+ # Regression trainers
22
+ from train_ml_regression import (
23
+ load_split_data as load_split_reg,
24
+ train_cuml_elasticnet_reg,
25
+ train_svr_reg,
26
+ train_xgb_reg,
27
+ )
28
+
29
# Ordered probe list: (artifact filename, model type, task).
# Order matters: the specific filenames must be checked before the generic
# "best_model.joblib" / "best_model.json" fallbacks.
MODEL_FILE_MAP = [
    ("best_model_cuml_svc.joblib", "svm_gpu", "classification"),
    ("best_model_cuml_enet.joblib", "enet_gpu", "auto"),
    ("best_model_svr.joblib", "svr", "regression"),
    ("best_model.joblib", "svm", "classification"),
    ("best_model.json", "xgb", "auto"),
]

def detect_model_type(model_dir: Path) -> tuple:
    """Identify (model_type, task) from the artifact files present in model_dir."""
    for fname, model_type, task in MODEL_FILE_MAP:
        if not (model_dir / fname).exists():
            continue
        if task == "auto":
            # a saved scaler is only produced by the regression pipelines
            if (model_dir / "scaler.joblib").exists():
                task = "regression"
                if model_type == "xgb":
                    model_type = "xgb_reg"
            else:
                task = "classification"
        return model_type, task
    expected = [f for f, _, _ in MODEL_FILE_MAP]
    raise FileNotFoundError(
        f"No recognised model file in {model_dir}. "
        f"Expected one of: {expected}"
    )
53
+
54
+
55
def parse_best_params(model_dir: Path) -> dict:
    """Extract the JSON block following 'Best params:' from optimization_summary.txt."""
    summary_path = model_dir / "optimization_summary.txt"
    if not summary_path.exists():
        raise FileNotFoundError(f"optimization_summary.txt not found in {model_dir}")

    text = summary_path.read_text()
    # non-greedy JSON body terminated by a ===== divider (10+ equals signs)
    match = re.search(r"Best params:\s*(\{.*?\})\s*={10,}", text, re.DOTALL)
    if match is None:
        raise ValueError(
            f"Could not find 'Best params:' JSON block in {summary_path}.\n"
            f"File contents:\n{text}"
        )
    return json.loads(match.group(1))
71
+
72
def parse_objective_and_wt(model_dir: Path) -> tuple:
    """
    Expects layout: .../training_classifiers/<objective>/<model>_<wt>/
    Example: hemolysis/svm_gpu_smiles -> objective=hemolysis, wt=smiles
    """
    parts = model_dir.parts
    folder = parts[-1].lower()
    objective = parts[-2]

    for suffix, wt in (("_chemberta", "chemberta"), ("_smiles", "smiles"), ("_wt", "wt")):
        if folder.endswith(suffix):
            return objective, wt
    # no recognised suffix: default to the wild-type embedding
    return objective, "wt"
85
+
86
def refit_and_time(model_dir: Path, dataset_path: str) -> tuple:
    """Refit the tuned model once on the train split and time the fit.

    Returns (wall_seconds, model_type)."""
    model_type, task = detect_model_type(model_dir)
    best_params = parse_best_params(model_dir)

    print(f" Model type : {model_type} ({task})")
    print(f" Best params: {best_params}")

    # Regression pipelines persist a feature scaler alongside the model
    scaler_path = model_dir / "scaler.joblib"
    scaler = joblib.load(scaler_path) if scaler_path.exists() else None

    load_fn = load_split_reg if task == "regression" else load_split_cls
    data = load_fn(dataset_path)
    print(f" Train: {data.X_train.shape} Val: {data.X_val.shape}")

    def _xgb_params(objective: str, eval_metric: str) -> dict:
        """XGB booster config: tuned hyper-params plus fixed GPU settings."""
        hp = {k: best_params[k] for k in (
            "lambda", "alpha", "gamma", "max_depth", "min_child_weight",
            "subsample", "colsample_bytree", "learning_rate",
            "num_boost_round", "early_stopping_rounds")}
        hp.update(objective=objective, eval_metric=eval_metric,
                  tree_method="hist", device="cuda")
        return hp

    # Dispatch to the matching trainer; cuML models consume best_params as-is.
    if model_type == "xgb":
        params, train_fn = _xgb_params("binary:logistic", "logloss"), train_xgb
    elif model_type == "xgb_reg":
        params, train_fn = _xgb_params("reg:squarederror", "rmse"), train_xgb_reg
    elif model_type == "svm_gpu":
        params, train_fn = best_params, train_cuml_svc
    elif model_type == "enet_gpu" and task == "classification":
        params, train_fn = best_params, train_cuml_elastic_net
    elif model_type == "enet_gpu" and task == "regression":
        params, train_fn = best_params, train_cuml_elasticnet_reg
    elif model_type == "svm":
        params, train_fn = best_params, train_svm
    elif model_type == "svr":
        params, train_fn = best_params, train_svr_reg
    else:
        raise ValueError(f"Unhandled model_type={model_type}, task={task}")

    # ---- timed block: scaling + fit, matching the original training path ----
    t0 = time.perf_counter()

    X_train, X_val = data.X_train, data.X_val
    if scaler is not None:
        X_train = scaler.transform(X_train).astype(np.float32)
        X_val = scaler.transform(X_val).astype(np.float32)

    train_fn(X_train, data.y_train, X_val, data.y_val, params)

    wall_s = time.perf_counter() - t0
    print(f" Wall time: {wall_s:.1f}s")
    return wall_s, model_type
177
+
178
def write_wall_time(logs_dir: Path, objective: str, wt: str,
                    model_type: str, wall_s: float):
    """Append one wall-clock record to <logs_dir>/<MM_DD>_wall_clock_ml.jsonl."""
    logs_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%m_%d")
    jsonl_path = logs_dir / f"{stamp}_wall_clock_ml.jsonl"

    # NOTE: wall_s is rounded to whole seconds in the log record
    record = {"model": model_type, "objective": objective,
              "wt": wt, "wall_s": round(wall_s)}
    with open(jsonl_path, "a") as fh:
        fh.write(json.dumps(record) + "\n")
    print(f" Appended to {jsonl_path}: {record}")
193
+
194
if __name__ == "__main__":
    # Entry point: refit one tuned model and log its wall-clock training time.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, required=True,
                        help="e.g. .../hemolysis/svm_gpu_smiles")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="HuggingFace dataset path for this objective/embedding")
    parser.add_argument("--logs_dir", type=str, required=True,
                        help="Directory to write *_wall_clock_ml.jsonl")
    args = parser.parse_args()

    target_dir = Path(args.model_dir)
    objective, wt = parse_objective_and_wt(target_dir)
    print(f"\nObjective: {objective} Embedding: {wt}")

    elapsed, model_type = refit_and_time(target_dir, args.dataset_path)
    write_wall_time(Path(args.logs_dir), objective, wt, model_type, elapsed)
training_classifiers/refit_nn_seed.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from datasets import load_from_disk, DatasetDict
5
+ from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
6
+ import torch.nn as nn
7
+ import os
8
+ import json
9
+ import pandas as pd
10
+ import argparse
11
+ from typing import Optional
12
+ from lightning.pytorch import seed_everything
13
+
14
def infer_in_dim_from_unpooled_ds(ds) -> int:
    """Embedding width (H) read from the first example's per-token embedding matrix."""
    first = ds[0]
    return int(len(first["embedding"][0]))
17
+
18
def load_split(dataset_path):
    """Return the (train, val) splits of a DatasetDict saved at dataset_path."""
    ds = load_from_disk(dataset_path)
    if not isinstance(ds, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
    return ds["train"], ds["val"]
23
+
24
def collate_unpooled(batch):
    """Pad variable-length token embeddings to (B, Lmax, H); also return bool mask and labels."""
    lengths = [int(item["length"]) for item in batch]
    max_len = max(lengths)
    emb_dim = len(batch[0]["embedding"][0])
    n = len(batch)

    X = torch.zeros(n, max_len, emb_dim, dtype=torch.float32)
    M = torch.zeros(n, max_len, dtype=torch.bool)
    y = torch.tensor([item["label"] for item in batch], dtype=torch.float32)

    for i, item in enumerate(batch):
        emb = torch.tensor(item["embedding"], dtype=torch.float32)
        seq_len = emb.shape[0]
        X[i, :seq_len] = emb
        if "attention_mask" in item:
            mask = torch.tensor(item["attention_mask"], dtype=torch.bool)
            M[i, :seq_len] = mask[:seq_len]
        else:
            # no explicit mask: all real positions are valid
            M[i, :seq_len] = True

    return X, M, y
44
+
45
+ # ======================== Models =========================================
46
+
47
class MaskedMeanPool(nn.Module):
    """Mean over the sequence axis, counting only positions where the mask is True."""
    def forward(self, X, M):
        weights = M.unsqueeze(-1).float()
        count = weights.sum(dim=1).clamp(min=1.0)  # avoid 0-division on empty masks
        return (X * weights).sum(dim=1) / count


class MLPClassifier(nn.Module):
    """Masked mean pool followed by a one-hidden-layer MLP head (binary logit)."""
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        pooled = self.pool(X, M)
        return self.net(pooled).squeeze(-1)
65
+
66
class CNNClassifier(nn.Module):
    """1-D conv stack over the token axis, masked mean pool, then a linear head."""
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        stages = []
        prev = in_ch
        for _ in range(layers):
            # 'same' padding so the sequence length is preserved
            stages.append(nn.Conv1d(prev, c, kernel_size=k, padding=k // 2))
            stages.append(nn.GELU())
            stages.append(nn.Dropout(dropout))
            prev = c
        self.conv = nn.Sequential(*stages)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d wants (B, C, L); transpose in and back out
        feats = self.conv(X.transpose(1, 2)).transpose(1, 2)
        weights = M.unsqueeze(-1).float()
        pooled = (feats * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
81
+
82
class TransformerClassifier(nn.Module):
    """Linear projection -> TransformerEncoder -> masked mean pool -> linear head."""
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask expects True at PAD positions, hence the inversion
        hidden = self.enc(self.proj(X), src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
98
+
99
+ # ======================== Training utils =========================================
100
+
101
def best_f1_threshold(y_true, y_prob):
    """Return (threshold, f1) maximizing F1 along the precision-recall curve."""
    prec, rec, thresholds = precision_recall_curve(y_true, y_prob)
    # the last PR point has no associated threshold; epsilon guards 0/0
    f1_scores = (2 * prec[:-1] * rec[:-1]) / (prec[:-1] + rec[:-1] + 1e-12)
    best = int(np.nanargmax(f1_scores))
    return float(thresholds[best]), float(f1_scores[best])
106
+
107
@torch.no_grad()
def eval_probs(model, loader, device):
    """Collect sigmoid probabilities and labels over a loader (no gradients)."""
    model.eval()
    labels, probs = [], []
    for X, M, y in loader:
        X = X.to(device)
        M = M.to(device)
        logits = model(X, M)
        probs.append(torch.sigmoid(logits).cpu().numpy())
        labels.append(y.numpy())
    return np.concatenate(labels), np.concatenate(probs)
116
+
117
def train_one_epoch(model, loader, optim, criterion, device):
    """One optimization pass over `loader` (gradients reset per batch)."""
    model.train()
    for X, M, y in loader:
        X = X.to(device)
        M = M.to(device)
        y = y.to(device)
        optim.zero_grad(set_to_none=True)
        loss = criterion(model(X, M), y)
        loss.backward()
        optim.step()
124
+
125
def build_model(model_name, in_dim, params):
    """Construct the classifier named by `model_name` from its Optuna params."""
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPClassifier(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    if model_name == "cnn":
        return CNNClassifier(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                             layers=int(params["layers"]), dropout=dropout)
    if model_name == "transformer":
        return TransformerClassifier(in_dim=in_dim, d_model=int(params["d_model"]),
                                     nhead=int(params["nhead"]), layers=int(params["layers"]),
                                     ff=int(params["ff"]), dropout=dropout)
    raise ValueError(model_name)
137
+
138
+ # ======================== Main refit =========================================
139
+
140
def refit_with_seed(dataset_path, base_out_dir, model_name, seed, device="cuda:0"):
    """
    Loads best_params from base_out_dir/best_model.pt (saved by original Optuna run),
    retrains with the given seed, saves results to base_out_dir/seed_{seed}/.
    """
    # Load best params from completed Optuna run
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt found at {model_path}. Run Optuna first.")

    # Only the hyper-parameters are reused here; the checkpoint weights are not loaded.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    # Seed
    seed_everything(seed)

    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    # Data import
    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim_from_unpooled_ds(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss: weight the positive class by the neg/pos ratio to counter imbalance.
    ytr = np.asarray(train_ds["label"], dtype=np.int64)
    pos, neg = ytr.sum(), len(ytr) - ytr.sum()
    pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    # Training loop with early stopping on validation F1 (best threshold per epoch).
    best_f1, best_thr, bad, patience = -1.0, 0.5, 0, 12
    best_state = None

    for epoch in range(1, 151):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_prob = eval_probs(model, val_loader, device)
        thr, f1 = best_f1_threshold(y_true, y_prob)

        # Require a minimal improvement (1e-4) before resetting patience.
        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            best_thr = thr
            bad = 0
            # Keep a CPU copy of the best weights so the final eval uses them.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final eval: re-tune the decision threshold on the restored best weights.
    y_true_val, y_prob_val = eval_probs(model, val_loader, device)
    best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
    auc_final = roc_auc_score(y_true_val, y_prob_val)

    # Save per-sample predictions (with sequences when available) for later CI/uncertainty runs.
    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(int),
        "y_prob": y_prob_val.astype(float),
        "y_pred": (y_prob_val >= best_thr_final).astype(int),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    # metrics.json is the unit read back by aggregate_seed_results().
    summary = {
        "model": model_name,
        "seed": seed,
        "val_f1": round(best_f1_final, 6),
        "val_auc": round(auc_final, 6),
        "val_thr": round(best_thr_final, 6),
    }
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] F1={best_f1_final:.4f} AUC={auc_final:.4f} thr={best_thr_final:.4f}")
    print(f"Saved to {out_dir}")
    return summary
238
+
239
+
240
+ # ======================== CI aggregation =========================================
241
+
242
def aggregate_seed_results(base_out_dir, seeds):
    """
    Aggregate per-seed validation metrics into mean ± 95% CI across seeds.

    Reads ``{base_out_dir}/seed_{seed}/metrics.json`` for each seed (missing
    seeds are warned about and skipped) and writes a summary CSV to
    ``{base_out_dir}/seed_aggregated_metrics.csv``.

    Returns:
        pandas.DataFrame with one row per metric (val_f1, val_auc).

    Raises:
        ValueError: if no seed produced a metrics.json file.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            # Fix: the original `json.load(open(p))` never closed the file handle.
            with open(p) as fh:
                records.append(json.load(fh))
        else:
            print(f"Warning: missing seed {seed} at {p}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["val_f1", "val_auc"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            # Two-sided 95% CI half-width from the t-distribution (n-1 dof).
            ci = stats.t.ppf(0.975, df=n - 1) * se
        else:
            # Fix: with one seed, t.ppf(df=0) is NaN; report zero spread instead.
            std, ci = 0.0, 0.0
        summary_rows.append({
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        })

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI) ===")
    for _, row in summary_df.iterrows():
        print(f"  {row['metric']:12s}: {row['report']} (n={row['n_seeds']})")
    print(f"\nSaved to {out_path}")
    return summary_df
291
+
292
+
293
+ if __name__ == "__main__":
294
+ parser = argparse.ArgumentParser()
295
+ parser.add_argument("--dataset_path", type=str, required=True)
296
+ parser.add_argument("--base_out_dir", type=str, required=True,
297
+ help="Directory containing best_model.pt from Optuna run")
298
+ parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
299
+ parser.add_argument("--seed", type=int, required=True,
300
+ help="Training seed for this run (1986, 42, 0, 123, 12345)")
301
+ parser.add_argument("--aggregate", action="store_true",
302
+ help="After all seeds done: aggregate results into CI summary")
303
+ parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345],
304
+ help="All seeds to aggregate (used with --aggregate)")
305
+ args = parser.parse_args()
306
+
307
+ if args.aggregate:
308
+ aggregate_seed_results(args.base_out_dir, args.all_seeds)
309
+ else:
310
+ refit_with_seed(
311
+ dataset_path=args.dataset_path,
312
+ base_out_dir=args.base_out_dir,
313
+ model_name=args.model,
314
+ seed=args.seed,
315
+ )
training_classifiers/refit_regression_seed.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from torch.cuda.amp import autocast, GradScaler
10
+ from datasets import load_from_disk, DatasetDict
11
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
12
+ from scipy.stats import spearmanr
13
+ from lightning.pytorch import seed_everything
14
+ from typing import Dict, Optional
15
+
16
# Module-level AMP gradient scaler shared by train_one_epoch; scaling is a
# no-op on CPU-only machines (enabled=False).
scaler_amp = GradScaler(enabled=torch.cuda.is_available())
17
+
18
def load_split(dataset_path):
    """Load a saved HF dataset and return its (train, val) splits."""
    dataset = load_from_disk(dataset_path)
    if not isinstance(dataset, DatasetDict):
        raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
    return dataset["train"], dataset["val"]
23
+
24
def infer_in_dim(ds) -> int:
    """Embedding width: length of the first token vector of the first row."""
    first_token_vec = ds[0]["embedding"][0]
    return int(len(first_token_vec))
26
+
27
def collate_unpooled_reg(batch):
    """Pad variable-length per-residue embeddings into dense tensors.

    Returns (X, M, y): float features (B, Lmax, H), boolean validity mask
    (B, Lmax), and float regression targets (B,).
    """
    seq_lens = [int(item["length"]) for item in batch]
    max_len = max(seq_lens)
    hidden = len(batch[0]["embedding"][0])

    X = torch.zeros(len(batch), max_len, hidden, dtype=torch.float32)
    M = torch.zeros(len(batch), max_len, dtype=torch.bool)
    y = torch.tensor([float(item["label"]) for item in batch], dtype=torch.float32)

    for row, item in enumerate(batch):
        emb = torch.tensor(item["embedding"], dtype=torch.float32)
        n_tok = emb.shape[0]
        X[row, :n_tok] = emb
        if "attention_mask" in item:
            # Respect the dataset's own mask, truncated to the embedding length.
            provided = torch.tensor(item["attention_mask"], dtype=torch.bool)
            M[row, :n_tok] = provided[:n_tok]
        else:
            M[row, :n_tok] = True
    return X, M, y
46
+
47
+ # ======================== Models =========================================
48
+
49
class MaskedMeanPool(nn.Module):
    """Mean over the sequence axis, counting only positions where mask is True."""

    def forward(self, X, M):
        weights = M.unsqueeze(-1).float()
        total = (X * weights).sum(dim=1)
        # clamp avoids division by zero for fully-masked rows.
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count
53
+
54
class MLPRegressor(nn.Module):
    """Masked-mean pooling followed by a single-hidden-layer MLP head."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, X, M):
        pooled = self.pool(X, M)
        return self.net(pooled).squeeze(-1)
63
+
64
class CNNRegressor(nn.Module):
    """1-D conv stack over residue embeddings, masked-mean pooled into a linear head."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        stages = []
        prev = in_ch
        for _ in range(layers):
            stages.append(nn.Conv1d(prev, c, kernel_size=k, padding=k // 2))
            stages.append(nn.GELU())
            stages.append(nn.Dropout(dropout))
            prev = c
        self.conv = nn.Sequential(*stages)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d wants (B, C, L); transpose in and back out.
        feats = self.conv(X.transpose(1, 2)).transpose(1, 2)
        weights = M.unsqueeze(-1).float()
        pooled = (feats * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
77
+
78
class TransformerRegressor(nn.Module):
    """Projected transformer encoder with masked-mean pooling and a linear head."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask marks PAD positions, hence the inversion of M.
        Z = self.enc(self.proj(X), src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        pooled = (Z * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
92
+
93
+ # ======================== utils =========================================
94
+
95
def safe_spearmanr(y_true, y_pred):
    """Spearman rho, mapped to 0.0 when undefined (e.g. constant input)."""
    rho = spearmanr(y_true, y_pred).correlation
    if rho is None or np.isnan(rho):
        return 0.0
    return float(rho)
98
+
99
def eval_regression(y_true, y_pred) -> Dict[str, float]:
    """Compute Spearman rho, RMSE, MAE and R^2 as a dict of plain floats."""
    # sklearn >= 1.4 ships a dedicated RMSE helper; older versions fall back
    # to sqrt(MSE).
    try:
        from sklearn.metrics import root_mean_squared_error
        rmse_value = float(root_mean_squared_error(y_true, y_pred))
    except Exception:
        rmse_value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse_value,
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "r2": float(r2_score(y_true, y_pred)),
    }
111
+
112
def score_from_metrics(metrics, objective):
    """Map an objective name to a maximisable scalar (RMSE is negated)."""
    candidates = {
        "spearman": metrics["spearman_rho"],
        "neg_rmse": -metrics["rmse"],
        "r2": metrics["r2"],
    }
    return candidates[objective]
116
+
117
@torch.no_grad()
def eval_preds(model, loader, device):
    """Collect (labels, raw predictions) over an entire loader as numpy arrays."""
    model.eval()
    labels, preds = [], []
    for feats, mask, target in loader:
        feats = feats.to(device)
        mask = mask.to(device)
        preds.append(model(feats, mask).cpu().numpy())
        labels.append(target.numpy())
    return np.concatenate(labels), np.concatenate(preds)
126
+
127
def train_one_epoch(model, loader, optim, criterion, device):
    # One full optimisation pass over `loader` using mixed precision.
    # Relies on the module-level GradScaler `scaler_amp`; both autocast and the
    # scaler are effectively disabled on CPU-only machines (enabled=False).
    # NOTE(review): torch.cuda.amp.autocast/GradScaler are the legacy AMP API;
    # newer torch prefers torch.amp.autocast("cuda") — confirm torch version.
    model.train()
    for X, M, y in loader:
        X, M, y = X.to(device), M.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        with autocast(enabled=torch.cuda.is_available()):
            loss = criterion(model(X, M), y)
        # Scaled backward + step + update is the standard GradScaler sequence.
        scaler_amp.scale(loss).backward()
        scaler_amp.step(optim)
        scaler_amp.update()
137
+
138
def build_model(model_name, in_dim, params):
    """Instantiate a regressor head from an Optuna hyper-parameter dict.

    ``model_name`` is one of "mlp" / "cnn" / "transformer"; anything else
    raises ValueError.
    """
    p_drop = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        return MLPRegressor(in_dim=in_dim, hidden=int(params["hidden"]), dropout=p_drop)
    if model_name == "cnn":
        return CNNRegressor(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=p_drop,
        )
    if model_name == "transformer":
        return TransformerRegressor(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=p_drop,
        )
    raise ValueError(model_name)
150
+
151
+ # ======================== Refit Loop =========================================
152
+
153
def refit_with_seed(dataset_path, base_out_dir, model_name, seed,
                    objective="spearman", device="cuda:0"):
    """Retrain the best Optuna regression config with a fresh seed.

    Loads best_params from base_out_dir/best_model.pt, retrains from scratch
    with `seed`, and writes predictions/weights/metrics to
    base_out_dir/seed_{seed}/.
    """
    model_path = os.path.join(base_out_dir, "best_model.pt")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No best_model.pt at {model_path}. Run Optuna first.")

    # Only the hyper-parameters are reused; the checkpoint weights are discarded.
    checkpoint = torch.load(model_path, map_location="cpu")
    best_params = checkpoint["best_params"]
    print(f"Loaded best_params from {model_path}")
    print(json.dumps(best_params, indent=2))

    seed_everything(seed)
    out_dir = os.path.join(base_out_dir, f"seed_{seed}")
    os.makedirs(out_dir, exist_ok=True)

    train_ds, val_ds = load_split(dataset_path)
    print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")

    batch_size = int(best_params.get("batch_size", 32))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
                            collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)

    in_dim = infer_in_dim(train_ds)
    model = build_model(model_name, in_dim, best_params).to(device)

    # Loss: MSE by default, Huber when the tuned params asked for it.
    loss_name = best_params.get("loss", "mse")
    if loss_name == "mse":
        criterion = nn.MSELoss()
    else:
        criterion = nn.HuberLoss(delta=float(best_params.get("huber_delta", 1.0)))

    optim = torch.optim.AdamW(model.parameters(),
                              lr=float(best_params["lr"]),
                              weight_decay=float(best_params["weight_decay"]))

    # Early stopping on the chosen validation objective (spearman/neg_rmse/r2).
    best_score, bad, patience = -1e18, 0, 15
    best_state, best_metrics = None, {}

    for epoch in range(1, 201):
        train_one_epoch(model, train_loader, optim, criterion, device)
        y_true, y_pred = eval_preds(model, val_loader, device)
        metrics = eval_regression(y_true, y_pred)
        score = score_from_metrics(metrics, objective)

        # Require a minimal improvement (1e-6) before resetting patience.
        if score > best_score + 1e-6:
            best_score = score
            best_metrics = metrics
            bad = 0
            # Keep a CPU copy of the best weights so final eval uses them.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

    if best_state:
        model.load_state_dict(best_state)

    # Final metrics from the restored best weights.
    y_true_val, y_pred_val = eval_preds(model, val_loader, device)
    final_metrics = eval_regression(y_true_val, y_pred_val)

    # Per-sample predictions (with residuals) for downstream CI/uncertainty work.
    df_val = pd.DataFrame({
        "y_true": y_true_val.astype(float),
        "y_pred": y_pred_val.astype(float),
        "residual": (y_true_val - y_pred_val).astype(float),
        "abs_error": np.abs(y_true_val - y_pred_val).astype(float),
    })
    if "sequence" in val_ds.column_names:
        df_val.insert(0, "sequence", np.asarray(val_ds["sequence"]))
    df_val.to_csv(os.path.join(out_dir, "val_predictions.csv"), index=False)

    torch.save({"state_dict": model.state_dict(), "best_params": best_params, "seed": seed},
               os.path.join(out_dir, "model.pt"))

    # metrics.json is the unit read back by aggregate_seed_results().
    summary = {"model": model_name, "seed": seed, **{k: round(v, 6) for k, v in final_metrics.items()}}
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n[Seed {seed}] rho={final_metrics['spearman_rho']:.4f} "
          f"RMSE={final_metrics['rmse']:.4f} R2={final_metrics['r2']:.4f}")
    return summary
237
+
238
+ # ======================== CI aggregation =========================================
239
+
240
def aggregate_seed_results(base_out_dir, seeds):
    """
    Aggregate per-seed regression metrics into mean ± 95% CI.

    Uses a t-distribution CI for Spearman rho, RMSE, MAE and R2; when the rho
    CI approaches the ±1 boundary, a note suggests a Fisher z-transform CI.

    Reads ``{base_out_dir}/seed_{seed}/metrics.json`` per seed and writes
    ``{base_out_dir}/seed_aggregated_metrics.csv``.

    Returns:
        pandas.DataFrame with one row per metric.

    Raises:
        ValueError: if no seed produced a metrics.json file.
    """
    from scipy import stats

    records = []
    for seed in seeds:
        p = os.path.join(base_out_dir, f"seed_{seed}", "metrics.json")
        if os.path.exists(p):
            # Fix: the original `json.load(open(p))` never closed the file handle.
            with open(p) as fh:
                records.append(json.load(fh))
        else:
            print(f"Warning: missing seed {seed}")

    if not records:
        raise ValueError("No seed results found.")

    df = pd.DataFrame(records)
    print("\nPer-seed results:")
    print(df.to_string(index=False))

    summary_rows = []
    for metric in ["spearman_rho", "rmse", "mae", "r2"]:
        vals = df[metric].values
        n = len(vals)
        mean = vals.mean()
        if n > 1:
            std = vals.std(ddof=1)
            se = std / np.sqrt(n)
            # Two-sided 95% CI half-width from the t-distribution (n-1 dof).
            ci = stats.t.ppf(0.975, df=n - 1) * se
        else:
            # Fix: with one seed, t.ppf(df=0) is NaN; report zero spread instead.
            std, ci = 0.0, 0.0
        row = {
            "metric": metric,
            "mean": round(mean, 4),
            "std": round(std, 4),
            "ci_95": round(ci, 4),
            "report": f"{mean:.4f} ± {ci:.4f}",
            "n_seeds": n,
        }
        # Flag if rho is high enough that the t-CI boundary might exceed 1.0
        if metric == "spearman_rho" and (mean + ci > 0.95 or mean - ci < -0.95):
            row["note"] = "rho near boundary — consider Fisher z-transform CI"
        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    out_path = os.path.join(base_out_dir, "seed_aggregated_metrics.csv")
    summary_df.to_csv(out_path, index=False)

    print("\n=== Aggregated Metrics (95% CI, t-distribution) ===")
    for _, row in summary_df.iterrows():
        note = f"  ← {row['note']}" if "note" in row and pd.notna(row.get("note")) else ""
        print(f"  {row['metric']:15s}: {row['report']}{note}")
    print(f"\nSaved to {out_path}")
    return summary_df
295
+
296
+
297
+ if __name__ == "__main__":
298
+ parser = argparse.ArgumentParser()
299
+ parser.add_argument("--dataset_path", type=str, required=True)
300
+ parser.add_argument("--base_out_dir", type=str, required=True)
301
+ parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
302
+ parser.add_argument("--seed", type=int, required=True)
303
+ parser.add_argument("--objective", type=str, default="spearman",
304
+ choices=["spearman", "neg_rmse", "r2"])
305
+ parser.add_argument("--aggregate", action="store_true")
306
+ parser.add_argument("--all_seeds", type=int, nargs="+", default=[1986, 42, 0, 123, 12345])
307
+ args = parser.parse_args()
308
+
309
+ if args.aggregate:
310
+ aggregate_seed_results(args.base_out_dir, args.all_seeds)
311
+ else:
312
+ refit_with_seed(
313
+ dataset_path=args.dataset_path,
314
+ base_out_dir=args.base_out_dir,
315
+ model_name=args.model,
316
+ seed=args.seed,
317
+ objective=args.objective,
318
+ )
training_classifiers/src_bash/binding_refit.bash ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# SLURM array job: refit the binding-affinity model once per seed (indices 0..4).
#SBATCH --job-name=ba-refit-seed
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=200G
#SBATCH --time=24:00:00
#SBATCH --output=%x_%A_%a.out
#SBATCH --array=0-4  # 5 seeds → indices 0..4

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
# NOTE(review): other scripts in this commit use "training_data_cleaned"
# (with trailing "ed"); confirm "training_data_clean" is the intended path.
ALT_EMB_LOC=$HOME_LOC/PeptiVerse/training_data_clean

# ── Configure per submission ──────────────────────────────────────────
BINDER_MODEL='wt'            # chemberta / peptideclm / wt
MODE='pooled'                # pooled / unpooled

# wt-wt
DATA_PATH="${SCRIPT_LOC}/binding_affinity/pair_wt_wt_${MODE}"
BASE_OUT_DIR="${SCRIPT_LOC}/binding_affinity/wt_wt_${MODE}"

# wt-smiles (chemberta or peptideclm)
#DATA_PATH="${ALT_EMB_LOC}/binding_affinity/${BINDER_MODEL}/pair_wt_smiles_${MODE}"
#BASE_OUT_DIR="${SCRIPT_LOC}/binding_affinity/${BINDER_MODEL}_smiles_${MODE}"
# ────────────────────────────────────────────────────────────────────────────

# Seed is selected by the SLURM array index.
SEEDS=(1986 42 0 123 12345)
SEED=${SEEDS[$SLURM_ARRAY_TASK_ID]}

LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC

echo "Running: binder=${BINDER_MODEL} mode=${MODE} seed=${SEED}"
echo "  data : ${DATA_PATH}"
echo "  out  : ${BASE_OUT_DIR}"

# Nanosecond wall-clock timing around the training run.
START_TIME=$(date +%s%N)

python -u refit_binding_affinity_seed.py \
    --dataset_path "${DATA_PATH}" \
    --base_out_dir "${BASE_OUT_DIR}" \
    --mode "${MODE}" \
    --seed "${SEED}" \
    > "${LOG_LOC}/${DATE}_ba_refit_${BINDER_MODEL}_${MODE}_seed${SEED}.log" 2>&1

END_TIME=$(date +%s%N)
ELAPSED_S=$(( (END_TIME - START_TIME) / 1000000000 ))
echo "Seed ${SEED} done at $(date) — wall clock: ${ELAPSED_S}s"
# Append one JSON record per run for later wall-clock aggregation.
echo "{\"binder\": \"${BINDER_MODEL}\", \"mode\": \"${MODE}\", \"seed\": ${SEED}, \"wall_s\": ${ELAPSED_S}}" \
    >> "${LOG_LOC}/${DATE}_wall_clock_ba_refit.jsonl"

# NOTE(review): no matching `conda activate` appears in this script — confirm
# the environment is activated by the shell profile or a submission wrapper.
conda deactivate
training_classifiers/src_bash/ml_uncertainty.bash ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#SBATCH --job-name=ml-walltime
#SBATCH --partition=b200-mig45
#SBATCH --gpus=1
#SBATCH --cpus-per-task=5
#SBATCH --mem=50G
#SBATCH --time=6:00:00
#SBATCH --output=%x_%j.out

# =============================================================================
# Unified Bootstrap CI + Uncertainty + Wall-time Refit
# wt, smiles, chemberta embeddings
# Runs sequentially: bootstrap/uncertainty first, then wall-time refit
# =============================================================================

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
ALT_EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC
# =============================================================================
# Helper functions
# =============================================================================

# Bootstrap CI + uncertainty
# $1=OBJECTIVE $2=WT $3=UNCERTAINTY_SCRIPT $4=MODEL_TYPE $5=UNC_MODE
run_bootstrap() {
    local OBJECTIVE=$1
    local WT=$2
    local SCRIPT=$3
    local MODEL_TYPE=$4
    local UNC_MODE=$5

    local VAL_PREDS="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}/val_predictions.csv"
    local OUT_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
    local LOG_FILE="${LOG_LOC}/${DATE}_ci_${MODEL_TYPE}_${OBJECTIVE}_${WT}.log"

    # Skip (not fail) when the corresponding training run never produced predictions.
    if [ ! -f "$VAL_PREDS" ]; then
        echo "  [SKIP bootstrap] val_predictions.csv not found: $VAL_PREDS"
        return
    fi

    echo "  [bootstrap ci] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT}"
    python -u "$SCRIPT" \
        --mode ci \
        --val_preds "$VAL_PREDS" \
        --out_dir "$OUT_DIR" \
        --model_name "${MODEL_TYPE}_${WT}" \
        >> "$LOG_FILE" 2>&1

    # Second pass: uncertainty mode (probability- or residual-based per task type).
    echo "  [bootstrap unc] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT} (${UNC_MODE})"
    python -u "$SCRIPT" \
        --mode "$UNC_MODE" \
        --val_preds "$VAL_PREDS" \
        --out_dir "$OUT_DIR" \
        --model_name "${MODEL_TYPE}_${WT}" \
        >> "$LOG_FILE" 2>&1

    echo "     ${OUT_DIR}/"
}

# Wall-time refit
# $1=OBJECTIVE $2=WT $3=MODEL_TYPE $4=DATASET_PATH
run_walltime() {
    local OBJECTIVE=$1
    local WT=$2
    local MODEL_TYPE=$3
    local DATASET_PATH=$4

    local MODEL_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
    local LOG_FILE="${LOG_LOC}/${DATE}_walltime_${MODEL_TYPE}_${OBJECTIVE}_${WT}.log"

    if [ ! -d "$MODEL_DIR" ]; then
        echo "  [SKIP walltime] model_dir not found: $MODEL_DIR"
        return
    fi
    if [ ! -d "$DATASET_PATH" ]; then
        echo "  [SKIP walltime] dataset not found: $DATASET_PATH"
        return
    fi

    echo "  [walltime] ${MODEL_TYPE} / ${OBJECTIVE} / ${WT}"
    python -u refit_ml_walltime.py \
        --model_dir "$MODEL_DIR" \
        --dataset_path "$DATASET_PATH" \
        --logs_dir "$LOG_LOC" \
        >> "$LOG_FILE" 2>&1

    echo "     logged to ${LOG_LOC}/${DATE}_wall_clock_ml.jsonl"
}

# =============================================================================
# Dataset path lookup
# $1=OBJECTIVE $2=WT
# =============================================================================
get_dataset_path() {
    local OBJECTIVE=$1
    local WT=$2

    # NOTE(review): this root differs from ALT_EMB_LOC above
    # (~/projects/Classifier_Weight vs ~/PeptiVerse) — confirm both exist
    # on the target machine.
    local DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned

    # Unknown (OBJECTIVE, WT) pairs fall through to an empty string, which
    # run_walltime then skips via its directory check.
    case "${OBJECTIVE}|${WT}" in
        # -- wt embeddings (ESM2 / original) ------------------------------
        "hemolysis|wt")                       echo "${DATA_LOC}/hemolysis/hemo_wt_with_embeddings" ;;
        "nf|wt")                              echo "${DATA_LOC}/nf/nf_wt_with_embeddings" ;;
        "solubility|wt")                      echo "${DATA_LOC}/solubility/sol_wt_with_embeddings" ;;
        "permeability_penetrance|wt")         echo "${DATA_LOC}/permeability_penetrance/perm_wt_with_embeddings_pooled" ;;
        # -- smiles embeddings (PeptideCLM) -------------------------------
        "hemolysis|smiles")                   echo "${ALT_EMB_LOC}/hemolysis_peptideclm/hemo_smiles_with_embeddings" ;;
        "nf|smiles")                          echo "${ALT_EMB_LOC}/nf_peptideclm/nf_smiles_with_embeddings" ;;
        "permeability_pampa|smiles")          echo "${ALT_EMB_LOC}/permeability_pampa_peptideclm/pampa_smiles_with_embeddings" ;;
        "permeability_caco2|smiles")          echo "${ALT_EMB_LOC}/permeability_caco2_peptideclm/caco2_smiles_with_embeddings" ;;
        # -- chemberta embeddings -----------------------------------------
        "hemolysis|chemberta")                echo "${ALT_EMB_LOC}/hemolysis_chemberta/hemo_smiles_with_embeddings" ;;
        "nf|chemberta")                       echo "${ALT_EMB_LOC}/nf_chemberta/nf_smiles_with_embeddings" ;;
        "permeability_penetrance|chemberta")  echo "${ALT_EMB_LOC}/permeability_chemberta/perm_smiles_with_embeddings" ;;
        "permeability_penetrance|peptideclm") echo "${ALT_EMB_LOC}/permeability_peptideclm/perm_smiles_with_embeddings" ;;
        "permeability_pampa|chemberta")       echo "${ALT_EMB_LOC}/permeability_pampa_chemberta/pampa_smiles_with_embeddings" ;;
        "permeability_caco2|chemberta")       echo "${ALT_EMB_LOC}/permeability_caco2_chemberta/caco2_smiles_with_embeddings" ;;
        *)
            echo ""
            ;;
    esac
}

# =============================================================================
# SECTION 1 - Classification tasks
# =============================================================================
echo ""
echo "============================================================"
echo " SECTION 1: Classification bootstrap + walltime"
echo "============================================================"

CLS_MODEL_TYPES=("svm_gpu" "enet_gpu" "xgb")

# hemolysis, nf - wt + smiles + chemberta
for OBJECTIVE in "hemolysis" "nf"; do
    for WT in "wt" "smiles" "chemberta"; do
        for MODEL_TYPE in "${CLS_MODEL_TYPES[@]}"; do
            echo ""
            echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
            run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty.py" "$MODEL_TYPE" "uncertainty_prob"
            DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
            run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
        done
    done
done

# solubility, permeability_penetrance - wt + chemberta (no smiles embeddings)
for OBJECTIVE in "solubility" "permeability_penetrance"; do
    for WT in "wt" "chemberta"; do
        for MODEL_TYPE in "${CLS_MODEL_TYPES[@]}"; do
            echo ""
            echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
            run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty.py" "$MODEL_TYPE" "uncertainty_prob"
            DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
            run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
        done
    done
done

# =============================================================================
# SECTION 2 - Regression tasks (PAMPA, Caco-2)
# =============================================================================
echo ""
echo "============================================================"
echo " SECTION 2: Regression bootstrap + walltime"
echo "============================================================"

REG_MODEL_TYPES=("svr" "enet_gpu" "xgb")

for OBJECTIVE in "permeability_pampa" "permeability_caco2"; do
    for WT in "smiles" "chemberta"; do
        for MODEL_TYPE in "${REG_MODEL_TYPES[@]}"; do
            echo ""
            echo "-- ${OBJECTIVE} / ${WT} / ${MODEL_TYPE} --"
            run_bootstrap "$OBJECTIVE" "$WT" "ml_uncertainty_reg.py" "$MODEL_TYPE" "uncertainty_residual"
            DPATH=$(get_dataset_path "$OBJECTIVE" "$WT")
            run_walltime "$OBJECTIVE" "$WT" "$MODEL_TYPE" "$DPATH"
        done
    done
done

echo ""
echo "============================================================"
echo "All runs completed at $(date)"
echo "============================================================"

# NOTE(review): no matching `conda activate` appears in this script — confirm
# the environment is activated elsewhere before submission.
conda deactivate
training_classifiers/src_bash/nn_uncertainty.bash ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# SLURM array job: refit the best NN config once per seed (indices 0..4).
#SBATCH --job-name=refit-seed-array
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=100G
#SBATCH --time=12:00:00
#SBATCH --output=%x_%A_%a.out
#SBATCH --array=0-4  # 5 seeds → indices 0..4

HOME_LOC=~/
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
DATA_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
# ── Configure per submission ──────────────────────────────────────────
OBJECTIVE='permeability_pampa'  # nf / solubility / hemolysis / permeability_penetrance/ permeability_pampa / permeability_caco2
WT='chemberta'                  # wt / smiles / chemberta / peptideclm
MODEL_TYPE='mlp'                # mlp / cnn / transformer
# NOTE(review): DATA_FILE hard-codes the "hemo_" prefix while OBJECTIVE is
# permeability_pampa and DATASET_PATH below uses "permeability_${WT}" —
# confirm these three settings are mutually consistent before submitting.
DATA_FILE="hemo_${WT}_with_embeddings_unpooled"  # nf / sol/ hemo / perm / pampa/ caco2
# Points to the directory where Optuna already saved best_model.pt
BASE_OUT_DIR="${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}"
DATASET_PATH="${DATA_LOC}/permeability_${WT}/${DATA_FILE}"
# ────────────────────────────────────────────────────────────────────────────

# Seed is selected by the SLURM array index.
SEEDS=(1986 42 0 123 12345)
SEED=${SEEDS[$SLURM_ARRAY_TASK_ID]}

LOG_LOC=$SCRIPT_LOC/src_bash/logs
mkdir -p $LOG_LOC
DATE=$(date +%m_%d)

cd $SCRIPT_LOC

echo "Running seed=$SEED model=$MODEL_TYPE objective=$OBJECTIVE wt=$WT"

# Nanosecond wall-clock timing around the training run.
START_TIME=$(date +%s%N)

python -u refit_nn_seed.py \
    --dataset_path "${DATASET_PATH}" \
    --base_out_dir "${BASE_OUT_DIR}" \
    --model "${MODEL_TYPE}" \
    --seed "${SEED}" \
    > "${LOG_LOC}/${DATE}_refit_${MODEL_TYPE}_${OBJECTIVE}_${WT}_seed${SEED}.log" 2>&1

END_TIME=$(date +%s%N)
ELAPSED_S=$(( (END_TIME - START_TIME) / 1000000000 ))

echo "Seed $SEED done at $(date) — wall clock: ${ELAPSED_S}s"
# Append one JSON record per run for later wall-clock aggregation.
echo "{\"model\": \"${MODEL_TYPE}\", \"objective\": \"${OBJECTIVE}\", \"wt\": \"${WT}\", \"seed\": ${SEED}, \"wall_s\": ${ELAPSED_S}}" \
    >> "${LOG_LOC}/${DATE}_wall_clock_refit.jsonl"
training_classifiers/train_ml.py CHANGED
@@ -55,11 +55,9 @@ def _stack_embeddings(col) -> np.ndarray:
55
  def load_split_data(dataset_path: str) -> SplitData:
56
  ds = load_from_disk(dataset_path)
57
 
58
- # Case A: DatasetDict with train/val
59
  if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
60
  train_ds, val_ds = ds["train"], ds["val"]
61
  else:
62
- # Case B: Single dataset with "split" column
63
  if "split" not in ds.column_names:
64
  raise ValueError(
65
  "Dataset must be a DatasetDict(train/val) or have a 'split' column."
@@ -201,7 +199,6 @@ def train_svm(X_train, y_train, X_val, y_val, params):
201
  def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
202
  """
203
  Fast linear SVM (LinearSVC) + probability calibration.
204
- Usually much faster than SVC on large datasets.
205
  """
206
  base = LinearSVC(
207
  C=float(params["C"]),
 
55
  def load_split_data(dataset_path: str) -> SplitData:
56
  ds = load_from_disk(dataset_path)
57
 
 
58
  if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
59
  train_ds, val_ds = ds["train"], ds["val"]
60
  else:
 
61
  if "split" not in ds.column_names:
62
  raise ValueError(
63
  "Dataset must be a DatasetDict(train/val) or have a 'split' column."
 
199
  def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
200
  """
201
  Fast linear SVM (LinearSVC) + probability calibration.
 
202
  """
203
  base = LinearSVC(
204
  C=float(params["C"]),
training_data_cleaned/binding_affinity/binding_affinity_smiles_meta_with_split.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:351f3a76e9dcd50191d8408d6b15a8133eb519d2f463c83c1e7934c0514c6d78
3
- size 4454310
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3aee738ef2b17343ae69723a75473821b4188a196a55dacd0286ec47d065d531
3
+ size 4436974
training_data_cleaned/binding_affinity/binding_affinity_wt_meta_with_split.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8bd2ec03e42b503e502bcfb88b567c64da77daaf6f2b79ce1142d187cc79bd0
3
- size 3714505
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7abc47729fa52a9f0aa68bffc6dd8c6562d0e4621d437a3a939c4ab27f46d80
3
+ size 3704486
training_data_cleaned/binding_affinity_split.py CHANGED
@@ -1,62 +1,77 @@
1
- import os
2
  import math
3
- from pathlib import Path
4
  import sys
5
  from contextlib import contextmanager
 
 
6
  import numpy as np
7
  import pandas as pd
8
  import torch
 
9
  from tqdm import tqdm
10
- from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
11
- from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
12
- from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
13
- from lightning.pytorch import seed_everything
14
- seed_everything(1986)
15
 
16
- CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")
 
 
17
 
18
- OUT_ROOT = Path(
19
- "./Classifier_Weight/training_data_cleaned/binding_affinity"
20
- )
21
 
22
- # WT embedding model
23
- WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
24
- WT_MAX_LEN = 1022
25
- WT_BATCH = 32
26
 
27
- # SMILES embedding model + tokenizer
28
- SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
29
- TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
30
- TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
31
- SMI_MAX_LEN = 768
32
- SMI_BATCH = 128
33
 
34
- # Split config
35
- TRAIN_FRAC = 0.80
36
- RANDOM_SEED = 1986
37
- AFFINITY_Q_BINS = 30
38
-
39
- COL_SEQ1 = "seq1"
40
- COL_SEQ2 = "seq2"
41
- COL_AFF = "affinity"
42
- COL_F2S = "Fasta2SMILES"
43
- COL_REACT = "REACT_SMILES"
44
- COL_WT_IPTM = "wt_iptm_score"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  COL_SMI_IPTM = "smiles_iptm_score"
46
 
47
- # Device
 
 
 
 
 
 
 
 
48
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
 
50
 
51
- QUIET = True
52
- USE_TQDM = False
53
- LOG_FILE = None
54
 
55
  def log(msg: str):
56
- if LOG_FILE is not None:
57
- Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
58
- with open(LOG_FILE, "a") as f:
59
- f.write(msg.rstrip() + "\n")
60
  if not QUIET:
61
  print(msg)
62
 
@@ -70,14 +85,22 @@ def section(title: str):
70
  log(f"=== done: {title} ===")
71
 
72
 
73
- # -------------------------
74
- # Helpers
75
- # -------------------------
 
76
  def has_uaa(seq: str) -> bool:
77
  return "X" in str(seq).upper()
78
 
 
 
 
 
 
 
 
 
79
  def affinity_to_class(a: float) -> str:
80
- # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
81
  if a >= 9.0:
82
  return "High"
83
  elif a >= 7.0:
@@ -87,10 +110,8 @@ def affinity_to_class(a: float) -> str:
87
 
88
  def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
89
  df = df.copy()
90
-
91
  df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
92
  df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
93
-
94
  df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
95
 
96
  try:
@@ -101,717 +122,446 @@ def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
101
  strat_col = "aff_bin"
102
 
103
  rng = np.random.RandomState(RANDOM_SEED)
104
-
105
  df["split"] = None
106
  for _, g in df.groupby(strat_col, observed=True):
107
  idx = g.index.to_numpy()
108
  rng.shuffle(idx)
109
  n_train = int(math.floor(len(idx) * TRAIN_FRAC))
110
  df.loc[idx[:n_train], "split"] = "train"
111
- df.loc[idx[n_train:], "split"] = "val"
112
-
113
  df["split"] = df["split"].fillna("train")
114
  return df
115
 
116
- def _summ(x):
117
- x = np.asarray(x, dtype=float)
118
- x = x[~np.isnan(x)]
119
- if len(x) == 0:
120
- return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
121
- return {
122
- "n": int(len(x)),
123
- "mean": float(np.mean(x)),
124
- "std": float(np.std(x)),
125
- "p50": float(np.quantile(x, 0.50)),
126
- "p95": float(np.quantile(x, 0.95)),
127
- }
128
-
129
- def _len_stats(seqs):
130
- lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
131
- if len(lens) == 0:
132
- return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
133
- return {
134
- "n": int(len(lens)),
135
- "mean": float(lens.mean()),
136
- "std": float(lens.std()),
137
- "p50": float(np.quantile(lens, 0.50)),
138
- "p95": float(np.quantile(lens, 0.95)),
139
- }
140
-
141
- def verify_split_before_embedding(
142
- df2: pd.DataFrame,
143
- affinity_col: str,
144
- split_col: str,
145
- seq_col: str,
146
- iptm_col: str,
147
- aff_class_col: str = "affinity_class",
148
- aff_bins: int = 30,
149
- save_report_prefix: str | None = None,
150
- verbose: bool = False,
151
- ):
152
- df2 = df2.copy()
153
- df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
154
- df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
155
-
156
- assert split_col in df2.columns, f"Missing split col: {split_col}"
157
- assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
158
- assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
159
 
160
- try:
161
- df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
162
- except Exception:
163
- df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
164
-
165
- tr = df2[df2[split_col] == "train"].reset_index(drop=True)
166
- va = df2[df2[split_col] == "val"].reset_index(drop=True)
167
-
168
- tr_aff = _summ(tr[affinity_col].to_numpy())
169
- va_aff = _summ(va[affinity_col].to_numpy())
170
- tr_len = _len_stats(tr[seq_col].tolist())
171
- va_len = _len_stats(va[seq_col].tolist())
172
-
173
- # bin drift
174
- bin_ct = (
175
- df2.groupby([split_col, "_aff_bin_dbg"])
176
- .size()
177
- .groupby(level=0)
178
- .apply(lambda s: s / s.sum())
179
- )
180
- tr_bins = bin_ct.loc["train"]
181
- va_bins = bin_ct.loc["val"]
182
- all_bins = tr_bins.index.union(va_bins.index)
183
- tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
184
- va_bins = va_bins.reindex(all_bins, fill_value=0.0)
185
- max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
186
-
187
- msg = (
188
- f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
189
- f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
190
- f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
191
- f"max_bin_diff={max_bin_diff:.4f}"
192
- )
193
- log(msg)
194
-
195
- if verbose and (not QUIET):
196
- class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
197
- class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
198
- print("\n[verbose] affinity_class counts:\n", class_ct)
199
- print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
200
-
201
- if save_report_prefix is not None:
202
- out = Path(save_report_prefix)
203
- out.parent.mkdir(parents=True, exist_ok=True)
204
-
205
- stats_df = pd.DataFrame([
206
- {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
207
- {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
208
- ])
209
- class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
210
- class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
211
-
212
- stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
213
- class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
214
-
215
-
216
- # -------------------------
217
- # WT pooled (ESM2)
218
- # -------------------------
219
- @torch.no_grad()
220
- def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
221
- embs = []
222
- for i in pbar(range(0, len(seqs), batch_size)):
223
- batch = seqs[i:i + batch_size]
224
- inputs = tokenizer(
225
- batch,
226
- padding=True,
227
- truncation=True,
228
- max_length=max_length,
229
- return_tensors="pt",
230
- )
231
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
232
- out = model(**inputs)
233
- h = out.last_hidden_state # (B, L, H)
234
 
235
- attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
236
- summed = (h * attn).sum(dim=1) # (B, H)
237
- denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
238
- pooled = (summed / denom).detach().cpu().numpy()
239
- embs.append(pooled)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- return np.vstack(embs)
 
 
 
 
242
 
243
 
244
- # -------------------------
245
- # WT unpooled (ESM2)
246
- # -------------------------
247
  @torch.no_grad()
248
- def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
249
- tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
250
- tok = {k: v.to(DEVICE) for k, v in tok.items()}
251
- out = model(**tok)
252
- h = out.last_hidden_state[0] # (L, H)
253
- attn = tok["attention_mask"][0].bool() # (L,)
254
- ids = tok["input_ids"][0]
255
-
256
- keep = attn.clone()
257
- if cls_id is not None:
258
- keep &= (ids != cls_id)
259
- if eos_id is not None:
260
- keep &= (ids != eos_id)
261
-
262
- return h[keep].detach().cpu().to(torch.float16).numpy()
263
-
264
- def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
265
- """
266
- Expects df_split to have:
267
- - target_sequence (seq1)
268
- - sequence (binder seq2; WT binder)
269
- - label, affinity_class, COL_AFF, COL_WT_IPTM
270
- Saves a dataset where each row contains BOTH:
271
- - target_embedding (Lt,H), target_attention_mask, target_length
272
- - binder_embedding (Lb,H), binder_attention_mask, binder_length
273
- """
274
- cls_id = tokenizer.cls_token_id
275
- eos_id = tokenizer.eos_token_id
276
- H = model.config.hidden_size
277
-
278
- features = Features({
279
- "target_sequence": Value("string"),
280
- "sequence": Value("string"),
281
- "label": Value("float32"),
282
- "affinity": Value("float32"),
283
- "affinity_class": Value("string"),
284
-
285
- "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
286
- "target_attention_mask": HFSequence(Value("int8")),
287
- "target_length": Value("int64"),
288
-
289
- "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
290
- "binder_attention_mask": HFSequence(Value("int8")),
291
- "binder_length": Value("int64"),
292
-
293
- COL_WT_IPTM: Value("float32"),
294
- COL_AFF: Value("float32"),
295
- })
296
 
297
- def gen_rows(df: pd.DataFrame):
298
- for r in pbar(df.itertuples(index=False), total=len(df)):
299
- tgt = str(getattr(r, "target_sequence")).strip()
300
- bnd = str(getattr(r, "sequence")).strip()
301
-
302
- y = float(getattr(r, "label"))
303
- aff = float(getattr(r, COL_AFF))
304
- acls = str(getattr(r, "affinity_class"))
305
-
306
- iptm = getattr(r, COL_WT_IPTM)
307
- iptm = float(iptm) if pd.notna(iptm) else np.nan
308
-
309
- # token embeddings for target + binder (both ESM)
310
- t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
311
- b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
312
-
313
- t_list = t_emb.tolist()
314
- b_list = b_emb.tolist()
315
- Lt = len(t_list)
316
- Lb = len(b_list)
317
-
318
- yield {
319
- "target_sequence": tgt,
320
- "sequence": bnd,
321
- "label": np.float32(y),
322
- "affinity": np.float32(aff),
323
- "affinity_class": acls,
324
-
325
- "target_embedding": t_list,
326
- "target_attention_mask": [1] * Lt,
327
- "target_length": int(Lt),
328
-
329
- "binder_embedding": b_list,
330
- "binder_attention_mask": [1] * Lb,
331
- "binder_length": int(Lb),
332
-
333
- COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
334
- COL_AFF: np.float32(aff),
335
- }
336
-
337
- out_dir.mkdir(parents=True, exist_ok=True)
338
- ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
339
- ds.save_to_disk(str(out_dir), max_shard_size="1GB")
340
- return ds
341
-
342
- def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
343
- smi_tok, smi_roformer):
344
  """
345
- df_split must have:
346
- - target_sequence (seq1)
347
- - sequence (binder smiles string)
348
- - label, affinity_class, COL_AFF, COL_SMI_IPTM
349
- Saves rows with:
350
- target_embedding (Lt,Ht) from ESM
351
- binder_embedding (Lb,Hb) from PeptideCLM
352
  """
353
- cls_id = wt_tokenizer.cls_token_id
354
- eos_id = wt_tokenizer.eos_token_id
355
- Ht = wt_model_unpooled.config.hidden_size
356
-
357
- Hb = getattr(smi_roformer.config, "hidden_size", None)
358
- if Hb is None:
359
- Hb = getattr(smi_roformer.config, "dim", None)
360
- if Hb is None:
361
- raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
362
-
363
- features = Features({
364
- "target_sequence": Value("string"),
365
- "sequence": Value("string"),
366
- "label": Value("float32"),
367
- "affinity": Value("float32"),
368
- "affinity_class": Value("string"),
369
-
370
- "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
371
- "target_attention_mask": HFSequence(Value("int8")),
372
- "target_length": Value("int64"),
373
-
374
- "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
375
- "binder_attention_mask": HFSequence(Value("int8")),
376
- "binder_length": Value("int64"),
377
-
378
- COL_SMI_IPTM: Value("float32"),
379
- COL_AFF: Value("float32"),
 
 
 
 
 
 
 
 
 
380
  })
 
 
381
 
382
- def gen_rows(df: pd.DataFrame):
383
- for r in pbar(df.itertuples(index=False), total=len(df)):
384
- tgt = str(getattr(r, "target_sequence")).strip()
385
- bnd = str(getattr(r, "sequence")).strip()
386
-
387
- y = float(getattr(r, "label"))
388
- aff = float(getattr(r, COL_AFF))
389
- acls = str(getattr(r, "affinity_class"))
390
-
391
- iptm = getattr(r, COL_SMI_IPTM)
392
- iptm = float(iptm) if pd.notna(iptm) else np.nan
393
-
394
- # target token embeddings (ESM)
395
- t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
396
- t_list = t_emb.tolist()
397
- Lt = len(t_list)
398
-
399
- # binder token embeddings (PeptideCLM)
400
- _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
401
- [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
402
- )
403
- b_emb = tok_list[0]
404
- b_list = b_emb.tolist()
405
- Lb = int(lengths[0])
406
- b_mask = mask_list[0].astype(np.int8).tolist()
407
-
408
- yield {
409
- "target_sequence": tgt,
410
- "sequence": bnd,
411
- "label": np.float32(y),
412
- "affinity": np.float32(aff),
413
- "affinity_class": acls,
414
-
415
- "target_embedding": t_list,
416
- "target_attention_mask": [1] * Lt,
417
- "target_length": int(Lt),
418
-
419
- "binder_embedding": b_list,
420
- "binder_attention_mask": [int(x) for x in b_mask],
421
- "binder_length": int(Lb),
422
-
423
- COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
424
- COL_AFF: np.float32(aff),
425
- }
426
-
427
- out_dir.mkdir(parents=True, exist_ok=True)
428
- ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
429
- ds.save_to_disk(str(out_dir), max_shard_size="1GB")
430
- return ds
431
-
432
-
433
- # -------------------------
434
- # SMILES pooled + unpooled (PeptideCLM)
435
- # -------------------------
436
- def get_special_ids(tokenizer_obj):
437
- cand = [
438
- getattr(tokenizer_obj, "pad_token_id", None),
439
- getattr(tokenizer_obj, "cls_token_id", None),
440
- getattr(tokenizer_obj, "sep_token_id", None),
441
- getattr(tokenizer_obj, "bos_token_id", None),
442
- getattr(tokenizer_obj, "eos_token_id", None),
443
- getattr(tokenizer_obj, "mask_token_id", None),
444
- ]
445
- return sorted({x for x in cand if x is not None})
446
 
447
- @torch.no_grad()
448
- def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
449
- tok = tokenizer_obj(
450
- batch_sequences,
451
- return_tensors="pt",
452
- padding=True,
453
- truncation=True,
454
- max_length=max_length,
455
- )
456
- input_ids = tok["input_ids"].to(DEVICE)
457
- attention_mask = tok["attention_mask"].to(DEVICE)
458
-
459
- outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
460
- last_hidden = outputs.last_hidden_state # (B, L, H)
461
-
462
- special_ids = get_special_ids(tokenizer_obj)
463
  valid = attention_mask.bool()
464
- if len(special_ids) > 0:
465
- sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
466
- if hasattr(torch, "isin"):
467
- valid = valid & (~torch.isin(input_ids, sid))
468
- else:
469
- m = torch.zeros_like(valid)
470
- for s in special_ids:
471
- m |= (input_ids == s)
472
- valid = valid & (~m)
473
 
474
  valid_f = valid.unsqueeze(-1).float()
475
- summed = torch.sum(last_hidden * valid_f, dim=1)
476
- denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
477
- pooled = (summed / denom).detach().cpu().numpy()
 
478
 
479
- token_emb_list, mask_list, lengths = [], [], []
480
  for b in range(last_hidden.shape[0]):
481
- emb = last_hidden[b, valid[b]] # (Li, H)
482
- token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
483
- li = emb.shape[0]
484
- lengths.append(int(li))
485
- mask_list.append(np.ones((li,), dtype=np.int8))
486
-
487
- return pooled, token_emb_list, mask_list, lengths
488
-
489
- def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
490
- pooled_all = []
491
- token_emb_all = []
492
- mask_all = []
493
- lengths_all = []
494
-
495
- for i in pbar(range(0, len(seqs), batch_size)):
496
- batch = seqs[i:i + batch_size]
497
- pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
498
- batch, tokenizer_obj, model_roformer, max_length
499
- )
500
- pooled_all.append(pooled)
501
- token_emb_all.extend(tok_list)
502
- mask_all.extend(m_list)
503
- lengths_all.extend(lens)
504
-
505
- return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
506
-
507
- def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
508
- wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
509
- wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
510
-
511
- # compute target pooled embeddings once
512
- tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
513
- tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
514
-
515
- wt_train_tgt_emb = wt_pooled_embeddings(
516
- tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
517
- )
518
- wt_val_tgt_emb = wt_pooled_embeddings(
519
- tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
520
- )
521
-
522
- # build dict: target_sequence -> embedding
523
- train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
524
- val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
525
- return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
526
- # -------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  # Main
528
- # -------------------------
 
529
  def main():
530
- log(f"[INFO] DEVICE: {DEVICE}")
 
 
531
  OUT_ROOT.mkdir(parents=True, exist_ok=True)
532
 
 
 
 
533
  with section("load csv + dedup"):
534
  df = pd.read_csv(CSV_PATH)
535
- for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
 
 
 
536
  if c in df.columns:
537
  df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
538
-
539
- # Dedup
540
- DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
541
- df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
542
-
543
- print("Rows after dedup on", DEDUP_COLS, ":", len(df))
544
-
545
- need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
546
- missing = [c for c in need if c not in df.columns]
547
- if missing:
548
- raise ValueError(f"Missing required columns: {missing}")
549
-
550
- # numeric affinity for both branches
551
  df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
552
 
553
- # WT subset + SMILES subset separately
554
- with section("prepare wt/smiles subsets"):
555
- # WT: requires a canonical peptide sequence (no X) + affinity
 
 
 
556
  df_wt = df.copy()
557
  df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
558
- df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
559
- df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
560
- df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
561
-
562
- # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
 
 
 
 
 
563
  df_smi = df.copy()
564
- df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
565
  df_smi = df_smi[
566
  pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
567
- ].reset_index(drop=True) # empty iptm means sth wrong with their smiles sequence
568
-
569
- is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
570
- df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
571
- df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
572
- df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
573
- df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
574
-
575
- log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
576
-
577
- # Split separately
578
- with section("split wt and smiles separately"):
579
- df_wt2 = make_distribution_matched_split(df_wt)
 
 
 
 
 
 
 
 
 
 
 
 
580
  df_smi2 = make_distribution_matched_split(df_smi)
581
 
582
- # save split tables
583
- wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
584
- smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
585
- df_wt2.to_csv(wt_split_csv, index=False)
586
- df_smi2.to_csv(smi_split_csv, index=False)
587
- log(f"Saved WT split meta: {wt_split_csv}")
588
- log(f"Saved SMILES split meta: {smi_split_csv}")
589
-
590
- verify_split_before_embedding(
591
- df2=df_wt2,
592
- affinity_col=COL_AFF,
593
- split_col="split",
594
- seq_col="wt_sequence",
595
- iptm_col=COL_WT_IPTM,
596
- aff_class_col="affinity_class",
597
- aff_bins=AFFINITY_Q_BINS,
598
- save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
599
- verbose=False,
600
- )
601
- verify_split_before_embedding(
602
- df2=df_smi2,
603
- affinity_col=COL_AFF,
604
- split_col="split",
605
- seq_col="smiles_sequence",
606
- iptm_col=COL_SMI_IPTM,
607
- aff_class_col="affinity_class",
608
- aff_bins=AFFINITY_Q_BINS,
609
- save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
610
- verbose=False,
611
- )
612
 
613
- # Prepare split views
614
- def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
615
- out = df_in.copy()
616
- out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
617
- out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
618
- out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
619
- out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
620
- out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
621
- out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
622
- return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
623
-
624
- wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
625
- smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
626
-
627
- # -------------------------
628
- # Split views
629
- # -------------------------
630
- wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
631
- wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
632
  smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
633
  smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
634
-
635
-
636
- # =========================
637
- # TARGET pooled embeddings (ESM) — SEPARATE per branch
638
- # =========================
639
- with section("TARGET pooled embeddings (ESM) WT + SMILES separately"):
640
- wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
641
- wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
642
-
643
- # ---- WT targets ----
644
- wt_train_tgt_emb = wt_pooled_embeddings(
645
- wt_train["target_sequence"].astype(str).str.strip().tolist(),
646
- wt_tok, wt_esm,
647
- batch_size=WT_BATCH,
648
- max_length=WT_MAX_LEN,
649
- ).astype(np.float32)
650
-
651
- wt_val_tgt_emb = wt_pooled_embeddings(
652
- wt_val["target_sequence"].astype(str).str.strip().tolist(),
653
- wt_tok, wt_esm,
654
- batch_size=WT_BATCH,
655
- max_length=WT_MAX_LEN,
656
- ).astype(np.float32)
657
-
658
- # ---- SMILES targets ----
659
- smi_train_tgt_emb = wt_pooled_embeddings(
660
- smi_train["target_sequence"].astype(str).str.strip().tolist(),
661
- wt_tok, wt_esm,
662
- batch_size=WT_BATCH,
663
- max_length=WT_MAX_LEN,
664
- ).astype(np.float32)
665
-
666
- smi_val_tgt_emb = wt_pooled_embeddings(
667
- smi_val["target_sequence"].astype(str).str.strip().tolist(),
668
- wt_tok, wt_esm,
669
- batch_size=WT_BATCH,
670
- max_length=WT_MAX_LEN,
671
- ).astype(np.float32)
672
-
673
-
674
- # =========================
675
- # WT pooled binder embeddings (binder = WT peptide)
676
- # =========================
677
- with section("WT pooled binder embeddings + save"):
678
- wt_train_emb = wt_pooled_embeddings(
679
- wt_train["sequence"].astype(str).str.strip().tolist(),
680
- wt_tok, wt_esm,
681
- batch_size=WT_BATCH,
682
- max_length=WT_MAX_LEN,
683
- ).astype(np.float32)
684
-
685
- wt_val_emb = wt_pooled_embeddings(
686
- wt_val["sequence"].astype(str).str.strip().tolist(),
687
- wt_tok, wt_esm,
688
- batch_size=WT_BATCH,
689
- max_length=WT_MAX_LEN,
690
- ).astype(np.float32)
691
-
692
- wt_train_ds = Dataset.from_dict({
693
- "target_sequence": wt_train["target_sequence"].tolist(),
694
- "sequence": wt_train["sequence"].tolist(),
695
- "label": wt_train["label"].astype(float).tolist(),
696
- "target_embedding": wt_train_tgt_emb,
697
- "embedding": wt_train_emb,
698
- COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
699
- COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
700
- "affinity_class": wt_train["affinity_class"].tolist(),
701
- })
702
-
703
- wt_val_ds = Dataset.from_dict({
704
- "target_sequence": wt_val["target_sequence"].tolist(),
705
- "sequence": wt_val["sequence"].tolist(),
706
- "label": wt_val["label"].astype(float).tolist(),
707
- "target_embedding": wt_val_tgt_emb,
708
- "embedding": wt_val_emb,
709
- COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
710
- COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
711
- "affinity_class": wt_val["affinity_class"].tolist(),
712
- })
713
-
714
- wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
715
- wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
716
- wt_pooled_dd.save_to_disk(str(wt_pooled_out))
717
- log(f"Saved WT pooled -> {wt_pooled_out}")
718
-
719
-
720
- # =========================
721
- # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
722
- # =========================
723
- with section("SMILES pooled binder embeddings + save"):
724
- smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
725
- smi_roformer = (
726
- AutoModelForMaskedLM
727
- .from_pretrained(SMI_MODEL_NAME)
728
- .roformer
729
- .to(DEVICE)
730
- .eval()
731
- )
732
-
733
- smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
734
- smi_train["sequence"].astype(str).str.strip().tolist(),
735
- smi_tok, smi_roformer,
736
- batch_size=SMI_BATCH,
737
- max_length=SMI_MAX_LEN,
738
  )
739
-
740
- smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
741
- smi_val["sequence"].astype(str).str.strip().tolist(),
742
- smi_tok, smi_roformer,
743
- batch_size=SMI_BATCH,
744
- max_length=SMI_MAX_LEN,
 
745
  )
746
-
747
- smi_train_ds = Dataset.from_dict({
748
- "target_sequence": smi_train["target_sequence"].tolist(),
749
- "sequence": smi_train["sequence"].tolist(),
750
- "label": smi_train["label"].astype(float).tolist(),
751
- "target_embedding": smi_train_tgt_emb,
752
- "embedding": smi_train_pooled.astype(np.float32),
753
- COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
754
- COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
755
- "affinity_class": smi_train["affinity_class"].tolist(),
756
- })
757
-
758
- smi_val_ds = Dataset.from_dict({
759
- "target_sequence": smi_val["target_sequence"].tolist(),
760
- "sequence": smi_val["sequence"].tolist(),
761
- "label": smi_val["label"].astype(float).tolist(),
762
- "target_embedding": smi_val_tgt_emb,
763
- "embedding": smi_val_pooled.astype(np.float32),
764
- COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
765
- COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
766
- "affinity_class": smi_val["affinity_class"].tolist(),
767
- })
768
-
769
- smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
770
- smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
771
- smi_pooled_dd.save_to_disk(str(smi_pooled_out))
772
- log(f"Saved SMILES pooled -> {smi_pooled_out}")
773
-
774
-
775
- # =========================
776
- # WT unpooled paired (ESM target + ESM binder) + save
777
- # =========================
778
- with section("WT unpooled paired embeddings + save"):
779
- wt_tok_unpooled = wt_tok # reuse tokenizer
780
- wt_esm_unpooled = wt_esm # reuse model
781
-
782
- wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
783
- wt_unpooled_dd = DatasetDict({
784
- "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
785
- wt_tok_unpooled, wt_esm_unpooled),
786
- "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
787
- wt_tok_unpooled, wt_esm_unpooled),
788
- })
789
- wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
790
- log(f"Saved WT unpooled -> {wt_unpooled_out}")
791
-
792
-
793
- # =========================
794
- # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
795
- # =========================
796
- with section("SMILES unpooled paired embeddings + save"):
797
- smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
798
- smi_unpooled_dd = DatasetDict({
799
- "train": build_smiles_unpooled_paired_dataset(
800
- smi_train, smi_unpooled_out / "train",
801
- wt_tok, wt_esm,
802
- smi_tok, smi_roformer
803
- ),
804
- "val": build_smiles_unpooled_paired_dataset(
805
- smi_val, smi_unpooled_out / "val",
806
- wt_tok, wt_esm,
807
- smi_tok, smi_roformer
808
- ),
809
- })
810
- smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
811
- log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
812
-
813
- log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
814
 
815
 
816
  if __name__ == "__main__":
817
- main()
 
 
1
  import math
 
2
  import sys
3
  from contextlib import contextmanager
4
+ from pathlib import Path
5
+
6
  import numpy as np
7
  import pandas as pd
8
  import torch
9
+ from datasets import Dataset, DatasetDict
10
  from tqdm import tqdm
11
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmModel
 
 
 
 
12
 
13
+ # ======================
14
+ # CONFIG
15
+ # ======================
16
 
17
+ ROOT = Path("<>") # CHANGE HERE
18
+ PROJ_ROOT = ROOT / "PeptiVerse"
 
19
 
20
+ CSV_PATH = PROJ_ROOT / "training_data" / "c-binding.csv"
 
 
 
21
 
22
+ OUT_ROOT = PROJ_ROOT / "training_data_cleaned" / "binding_affinity"
 
 
 
 
 
23
 
24
+ # ESM2 - target encoder (shared across all branches)
25
+ ESM_MODEL = "facebook/esm2_t33_650M_UR50D"
26
+ ESM_MAX_LEN = 1022
27
+ ESM_BATCH = 32
28
+
29
+ # PeptideCLM - SMILES binder encoder
30
+ sys.path.append(str(PROJ_ROOT))
31
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
32
+
33
+ PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
34
+ TOKENIZER_VOCAB = str(PROJ_ROOT / "tokenizer" / "new_vocab.txt")
35
+ TOKENIZER_SPLITS = str(PROJ_ROOT / "tokenizer" / "new_splits.txt")
36
+ PEPTIDECLM_MAX_LEN = 768
37
+ PEPTIDECLM_BATCH = 128
38
+
39
+ # ChemBERTa - SMILES binder encoder
40
+ CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
41
+ CHEMBERTA_MAX_LEN = 512
42
+ CHEMBERTA_BATCH = 128
43
+
44
+ # Which SMILES binder models to run
45
+ RUN_PEPTIDECLM = True
46
+ RUN_CHEMBERTA = True
47
+
48
+ # CSV column names
49
+ COL_SEQ1 = "seq1"
50
+ COL_SEQ2 = "seq2"
51
+ COL_AFF = "affinity"
52
+ COL_F2S = "Fasta2SMILES"
53
+ COL_REACT = "REACT_SMILES"
54
+ COL_MERGE = "Merge_SMILES"
55
+ COL_WT_IPTM = "wt_iptm_score"
56
  COL_SMI_IPTM = "smiles_iptm_score"
57
 
58
+ # Split config
59
+ TRAIN_FRAC = 0.80
60
+ RANDOM_SEED = 1986
61
+ AFFINITY_Q_BINS = 30
62
+
63
+ # Logging
64
+ QUIET = True
65
+ USE_TQDM = False
66
+
67
  DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68
 
69
 
70
+ # ======================
71
+ # Logging / progress
72
+ # ======================
73
 
74
  def log(msg: str):
 
 
 
 
75
  if not QUIET:
76
  print(msg)
77
 
 
85
  log(f"=== done: {title} ===")
86
 
87
 
88
+ # ======================
89
+ # Data Handling
90
+ # ======================
91
+
92
  def has_uaa(seq: str) -> bool:
93
  return "X" in str(seq).upper()
94
 
95
+ def pick_smiles(row) -> str | None:
96
+ """Column Priority: Fasta2SMILES > REACT_SMILES > Merge_SMILES."""
97
+ for col in [COL_F2S, COL_REACT, COL_MERGE]:
98
+ val = row.get(col, None)
99
+ if val is not None and str(val).strip() not in ("", "nan", "None"):
100
+ return str(val).strip()
101
+ return None
102
+
103
  def affinity_to_class(a: float) -> str:
 
104
  if a >= 9.0:
105
  return "High"
106
  elif a >= 7.0:
 
110
 
111
  def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
112
  df = df.copy()
 
113
  df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
114
  df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
 
115
  df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
116
 
117
  try:
 
122
  strat_col = "aff_bin"
123
 
124
  rng = np.random.RandomState(RANDOM_SEED)
 
125
  df["split"] = None
126
  for _, g in df.groupby(strat_col, observed=True):
127
  idx = g.index.to_numpy()
128
  rng.shuffle(idx)
129
  n_train = int(math.floor(len(idx) * TRAIN_FRAC))
130
  df.loc[idx[:n_train], "split"] = "train"
131
+ df.loc[idx[n_train:], "split"] = "val"
 
132
  df["split"] = df["split"].fillna("train")
133
  return df
134
 
135
+ def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
136
+ out = df_in.copy()
137
+ out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
138
+ out["sequence"] = out[binder_seq_col].astype(str).str.strip()
139
+ out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
140
+ out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
141
+ out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
142
+ out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
143
+ return out[["target_sequence", "sequence", "label", "split",
144
+ iptm_col, COL_AFF, "affinity_class"]]
145
+
146
+
147
+ # ======================
148
+ # Dataset builders
149
+ # ======================
150
+
151
+ def build_pooled_ds(view: pd.DataFrame, iptm_col: str,
152
+ tgt_embs: np.ndarray, bnd_embs: np.ndarray) -> Dataset:
153
+ """Both target and binder are (N, H) pooled float32 arrays."""
154
+ return Dataset.from_dict({
155
+ "target_sequence": view["target_sequence"].tolist(),
156
+ "sequence": view["sequence"].tolist(),
157
+ "label": view["label"].astype(float).tolist(),
158
+ "target_embedding": tgt_embs, # (N, H_esm) float32
159
+ "binder_embedding": bnd_embs, # (N, H_binder) float32
160
+ "affinity": view[COL_AFF].astype(float).tolist(),
161
+ "affinity_class": view["affinity_class"].tolist(),
162
+ iptm_col: view[iptm_col].astype(float).tolist(),
163
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def build_unpooled_ds(view: pd.DataFrame, iptm_col: str,
167
+ tgt_tok_embs, tgt_masks, tgt_lengths,
168
+ bnd_tok_embs, bnd_masks, bnd_lengths) -> Dataset:
169
+ """
170
+ Per-token lists for both sides.
171
+ target_embedding[i] : (Lt_i, H_esm) float16 ndarray
172
+ binder_embedding[i] : (Lb_i, H_binder) float16 ndarray
173
+ """
174
+ return Dataset.from_dict({
175
+ "target_sequence": view["target_sequence"].tolist(),
176
+ "sequence": view["sequence"].tolist(),
177
+ "label": view["label"].astype(float).tolist(),
178
+
179
+ "target_embedding": tgt_tok_embs,
180
+ "target_attention_mask": tgt_masks,
181
+ "target_length": tgt_lengths,
182
+
183
+ "binder_embedding": bnd_tok_embs,
184
+ "binder_attention_mask": bnd_masks,
185
+ "binder_length": bnd_lengths,
186
+
187
+ "affinity": view[COL_AFF].astype(float).tolist(),
188
+ "affinity_class": view["affinity_class"].tolist(),
189
+ iptm_col: view[iptm_col].astype(float).tolist(),
190
+ })
191
+
192
+
193
+ # ======================
194
+ # ESM2 - shared target encoder
195
+ # ======================
196
 
197
+ def load_esm():
198
+ print(f" Loading ESM2: {ESM_MODEL}")
199
+ tok = AutoTokenizer.from_pretrained(ESM_MODEL)
200
+ model = EsmModel.from_pretrained(ESM_MODEL).to(DEVICE).eval()
201
+ return tok, model
202
 
203
 
 
 
 
204
  @torch.no_grad()
205
+ def embed_esm_pooled(seqs, tok, model) -> np.ndarray:
206
+ """Returns (N, H) float32 - mean-pooled over non-pad tokens."""
207
+ all_embs = []
208
+ for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 pooled"):
209
+ batch = seqs[i:i + ESM_BATCH]
210
+ enc = tok(batch, return_tensors="pt", padding=True,
211
+ truncation=True, max_length=ESM_MAX_LEN)
212
+ ids = enc["input_ids"].to(DEVICE)
213
+ mask = enc["attention_mask"].to(DEVICE)
214
+ h = model(input_ids=ids, attention_mask=mask).last_hidden_state
215
+ attn_f = mask.unsqueeze(-1).float()
216
+ pooled = ((h * attn_f).sum(dim=1) /
217
+ attn_f.sum(dim=1).clamp(min=1e-9)).cpu().numpy().astype(np.float32)
218
+ all_embs.append(pooled)
219
+ return np.vstack(all_embs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
+
222
+ @torch.no_grad()
223
+ def embed_esm_unpooled(seqs, tok, model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  """
225
+ Returns per-token lists (CLS/EOS/pad excluded).
226
+ tok_embs : list of (Lt_i, H) float16 arrays
227
+ masks : list of (Lt_i,) int8 arrays (all-ones)
228
+ lengths : list of int
 
 
 
229
  """
230
+ cls_id = tok.cls_token_id
231
+ eos_id = tok.eos_token_id
232
+
233
+ tok_embs, masks, lengths = [], [], []
234
+ for i in pbar(range(0, len(seqs), ESM_BATCH), desc=" ESM2 unpooled"):
235
+ batch = seqs[i:i + ESM_BATCH]
236
+ enc = tok(batch, return_tensors="pt", padding=True,
237
+ truncation=True, max_length=ESM_MAX_LEN)
238
+ ids = enc["input_ids"].to(DEVICE)
239
+ mask = enc["attention_mask"].to(DEVICE)
240
+ h = model(input_ids=ids, attention_mask=mask).last_hidden_state
241
+
242
+ for b in range(h.shape[0]):
243
+ keep = mask[b].bool()
244
+ if cls_id is not None:
245
+ keep = keep & (ids[b] != cls_id)
246
+ if eos_id is not None:
247
+ keep = keep & (ids[b] != eos_id)
248
+ emb = h[b, keep].cpu().to(torch.float16).numpy()
249
+ tok_embs.append(emb)
250
+ masks.append(np.ones(emb.shape[0], dtype=np.int8))
251
+ lengths.append(emb.shape[0])
252
+ return tok_embs, masks, lengths
253
+
254
+
255
+ # ======================
256
+ # Generic binder embedding helpers
257
+ # ======================
258
+
259
+ def _get_special_ids_t(tokenizer):
260
+ special_ids = sorted({
261
+ x for x in [
262
+ getattr(tokenizer, attr, None)
263
+ for attr in ("pad_token_id", "cls_token_id", "sep_token_id",
264
+ "bos_token_id", "eos_token_id", "mask_token_id")
265
+ ] if x is not None
266
  })
267
+ return (torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
268
+ if special_ids else None)
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ def _pool_and_unpool(last_hidden, input_ids, attention_mask, special_ids_t):
272
+ """Mean-pool over non-special valid tokens; also return per-token arrays."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  valid = attention_mask.bool()
274
+ if special_ids_t is not None:
275
+ valid = valid & (~torch.isin(input_ids, special_ids_t))
 
 
 
 
 
 
 
276
 
277
  valid_f = valid.unsqueeze(-1).float()
278
+ pooled = (
279
+ torch.sum(last_hidden * valid_f, dim=1) /
280
+ torch.clamp(valid_f.sum(dim=1), min=1e-9)
281
+ ).cpu().numpy().astype(np.float32)
282
 
283
+ tok_embs, masks, lengths = [], [], []
284
  for b in range(last_hidden.shape[0]):
285
+ emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
286
+ tok_embs.append(emb)
287
+ masks.append(np.ones(emb.shape[0], dtype=np.int8))
288
+ lengths.append(emb.shape[0])
289
+ return pooled, tok_embs, masks, lengths
290
+
291
+
292
+ # ======================
293
+ # PeptideCLM - SMILES binder encoder
294
+ # ======================
295
+
296
+ def load_peptideclm():
297
+ print(f" Loading PeptideCLM: {PEPTIDECLM_MODEL}")
298
+ tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
299
+ model = (AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
300
+ .roformer.to(DEVICE).eval())
301
+ return tok, model, _get_special_ids_t(tok)
302
+
303
+
304
+ @torch.no_grad()
305
+ def embed_peptideclm(seqs, tok, model, sid_t):
306
+ pooled_all, tok_all, mask_all, len_all = [], [], [], []
307
+ for i in pbar(range(0, len(seqs), PEPTIDECLM_BATCH), desc=" PeptideCLM binder"):
308
+ batch = seqs[i:i + PEPTIDECLM_BATCH]
309
+ enc = tok(batch, return_tensors="pt", padding=True,
310
+ truncation=True, max_length=PEPTIDECLM_MAX_LEN)
311
+ ids = enc["input_ids"].to(DEVICE)
312
+ mask = enc["attention_mask"].to(DEVICE)
313
+ h = model(input_ids=ids, attention_mask=mask).last_hidden_state
314
+ p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
315
+ pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
316
+ return np.vstack(pooled_all), tok_all, mask_all, len_all
317
+
318
+
319
+ # ======================
320
+ # ChemBERTa - SMILES binder encoder
321
+ # ======================
322
+
323
+ def load_chemberta():
324
+ print(f" Loading ChemBERTa: {CHEMBERTA_MODEL}")
325
+ tok = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
326
+ model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(DEVICE).eval()
327
+ return tok, model, _get_special_ids_t(tok)
328
+
329
+
330
+ @torch.no_grad()
331
+ def embed_chemberta(seqs, tok, model, sid_t):
332
+ pooled_all, tok_all, mask_all, len_all = [], [], [], []
333
+ for i in pbar(range(0, len(seqs), CHEMBERTA_BATCH), desc=" ChemBERTa binder"):
334
+ batch = seqs[i:i + CHEMBERTA_BATCH]
335
+ enc = tok(batch, return_tensors="pt", padding=True,
336
+ truncation=True, max_length=CHEMBERTA_MAX_LEN)
337
+ ids = enc["input_ids"].to(DEVICE)
338
+ mask = enc["attention_mask"].to(DEVICE)
339
+ h = model(input_ids=ids, attention_mask=mask).last_hidden_state
340
+ p, t, m, l = _pool_and_unpool(h, ids, mask, sid_t)
341
+ pooled_all.append(p); tok_all.extend(t); mask_all.extend(m); len_all.extend(l)
342
+ return np.vstack(pooled_all), tok_all, mask_all, len_all
343
+
344
+
345
+ # ======================
346
+ # WT branch (ESM2 × ESM2)
347
+ # ======================
348
+
349
+ def run_wt_branch(wt_train: pd.DataFrame, wt_val: pd.DataFrame,
350
+ esm_tok, esm_model):
351
+ print("\n" + "="*55)
352
+ print(" Branch : WT (ESM2 target × ESM2 binder)")
353
+ print("="*55)
354
+
355
+ pooled_splits, unpooled_splits = {}, {}
356
+
357
+ for split_name, view in [("train", wt_train), ("val", wt_val)]:
358
+ print(f"\n [{split_name}] {len(view)} rows")
359
+ targets = view["target_sequence"].tolist()
360
+ binders = view["sequence"].tolist()
361
+
362
+ tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)
363
+ bnd_pooled = embed_esm_pooled(binders, esm_tok, esm_model)
364
+
365
+ tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(targets, esm_tok, esm_model)
366
+ bnd_tok_embs, bnd_masks, bnd_lengths = embed_esm_unpooled(binders, esm_tok, esm_model)
367
+
368
+ pooled_splits[split_name] = build_pooled_ds(
369
+ view, COL_WT_IPTM, tgt_pooled, bnd_pooled)
370
+ unpooled_splits[split_name] = build_unpooled_ds(
371
+ view, COL_WT_IPTM,
372
+ tgt_tok_embs, tgt_masks, tgt_lengths,
373
+ bnd_tok_embs, bnd_masks, bnd_lengths)
374
+
375
+ pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
376
+ unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
377
+ DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
378
+ DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
379
+ print(f"\n WT pooled to {pooled_out}")
380
+ print(f" WT unpooled to {unpooled_out}")
381
+
382
+
383
+ # ======================
384
+ # SMILES branch (ESM2 × {PeptideCLM | ChemBERTa})
385
+ # ======================
386
+
387
+ def run_smiles_binder_model(name: str,
388
+ smi_train: pd.DataFrame, smi_val: pd.DataFrame,
389
+ esm_tok, esm_model,
390
+ load_fn, embed_fn):
391
+ print("\n" + "="*55)
392
+ print(f" Branch : SMILES (ESM2 target × {name} binder)")
393
+ print("="*55)
394
+
395
+ binder_tok, binder_model, sid_t = load_fn()
396
+ pooled_splits, unpooled_splits = {}, {}
397
+
398
+ for split_name, view in [("train", smi_train), ("val", smi_val)]:
399
+ print(f"\n [{split_name}] {len(view)} rows")
400
+ targets = view["target_sequence"].tolist()
401
+ binders = view["sequence"].tolist()
402
+
403
+ print(" ESM2 target - pooled ...")
404
+ tgt_pooled = embed_esm_pooled(targets, esm_tok, esm_model)
405
+
406
+ print(" ESM2 target - unpooled ...")
407
+ tgt_tok_embs, tgt_masks, tgt_lengths = embed_esm_unpooled(
408
+ targets, esm_tok, esm_model)
409
+
410
+ print(f" {name} binder - pooled + unpooled ...")
411
+ bnd_pooled, bnd_tok_embs, bnd_masks, bnd_lengths = embed_fn(
412
+ binders, binder_tok, binder_model, sid_t)
413
+
414
+ pooled_splits[split_name] = build_pooled_ds(
415
+ view, COL_SMI_IPTM, tgt_pooled, bnd_pooled)
416
+ unpooled_splits[split_name] = build_unpooled_ds(
417
+ view, COL_SMI_IPTM,
418
+ tgt_tok_embs, tgt_masks, tgt_lengths,
419
+ bnd_tok_embs, bnd_masks, bnd_lengths)
420
+
421
+ suffix = "" if name.lower() == "peptideclm" else f"_{name.lower()}"
422
+ pooled_out = OUT_ROOT / f"pair_wt_smiles_pooled{suffix}"
423
+ unpooled_out = OUT_ROOT / f"pair_wt_smiles_unpooled{suffix}"
424
+ DatasetDict(pooled_splits).save_to_disk(str(pooled_out))
425
+ DatasetDict(unpooled_splits).save_to_disk(str(unpooled_out))
426
+ print(f"\n {name} pooled to {pooled_out}")
427
+ print(f" {name} unpooled to {unpooled_out}")
428
+
429
+ del binder_model
430
+ torch.cuda.empty_cache()
431
+
432
+
433
+ # ======================
434
  # Main
435
+ # ======================
436
+
437
  def main():
438
+ print(f"Device : {DEVICE}")
439
+ print(f"CSV : {CSV_PATH}")
440
+ print(f"Out : {OUT_ROOT}\n")
441
  OUT_ROOT.mkdir(parents=True, exist_ok=True)
442
 
443
+ # ------------------------------------------------------------------
444
+ # 1. Load + dedup
445
+ # ------------------------------------------------------------------
446
  with section("load csv + dedup"):
447
  df = pd.read_csv(CSV_PATH)
448
+ print(f"Raw rows: {len(df)}")
449
+ df["orig_idx"] = df.index # traceability only
450
+
451
+ for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]:
452
  if c in df.columns:
453
  df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
454
+
455
+ for col in [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]:
456
+ if col not in df.columns:
457
+ raise ValueError(f"Missing required column: '{col}'")
458
+
459
+ dedup_cols = [c for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT, COL_MERGE]
460
+ if c in df.columns]
461
+ before = len(df)
462
+ df = df.drop_duplicates(subset=dedup_cols, keep="first").reset_index(drop=True)
463
+ print(f"After dedup pass 1 (raw columns) : {len(df)} (-{before - len(df)})")
464
+
 
 
465
  df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
466
 
467
+ # ------------------------------------------------------------------
468
+ # 2. Prepare per-branch subsets
469
+ # ------------------------------------------------------------------
470
+ with section("prepare WT / SMILES subsets"):
471
+ # ── WT branch ──────────────────────────────────────────────────
472
+ # Both seq1 and seq2 must be canonical (no X) for ESM2
473
  df_wt = df.copy()
474
  df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
475
+ df_wt = df_wt.dropna(subset=[COL_AFF])
476
+ df_wt = df_wt[~df_wt[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
477
+ df_wt = df_wt[df_wt["wt_sequence"] != ""]
478
+ df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)]
479
+ df_wt = df_wt.reset_index(drop=True)
480
+
481
+ # ── SMILES branch ──────────────────────────────────────────────
482
+ # seq1 must be canonical (no X) for ESM2; binder SMILES picked
483
+ # by priority (Fasta2SMILES > REACT_SMILES > Merge_SMILES), then
484
+ # dedup pass 2 on (seq1, picked smiles_sequence)
485
  df_smi = df.copy()
486
+ df_smi = df_smi.dropna(subset=[COL_AFF])
487
  df_smi = df_smi[
488
  pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
489
+ ]
490
+ df_smi = df_smi[~df_smi[COL_SEQ1].astype(str).str.contains("X", case=False, na=False)]
491
+ df_smi = df_smi.reset_index(drop=True)
492
+
493
+ df_smi["smiles_sequence"] = df_smi.apply(pick_smiles, axis=1)
494
+ df_smi = df_smi[df_smi["smiles_sequence"].notna()].reset_index(drop=True)
495
+ print(f"After requiring ≥1 valid SMILES : {len(df_smi)}")
496
+
497
+ # Dedup pass 2: (seq1, picked smiles_sequence)
498
+ before = len(df_smi)
499
+ df_smi = df_smi.drop_duplicates(
500
+ subset=[COL_SEQ1, "smiles_sequence"], keep="first"
501
+ ).reset_index(drop=True)
502
+ print(f"After dedup pass 2 (seq1, smiles_sequence): {len(df_smi)} (-{before - len(df_smi)})")
503
+
504
+ assert not df_smi.duplicated(subset=[COL_SEQ1, "smiles_sequence"]).any(), \
505
+ "BUG: duplicate (seq1, smiles_sequence) pairs remain!"
506
+
507
+ print(f"\n[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)}")
508
+
509
+ # ------------------------------------------------------------------
510
+ # 3. Split
511
+ # ------------------------------------------------------------------
512
+ with section("split WT and SMILES separately"):
513
+ df_wt2 = make_distribution_matched_split(df_wt)
514
  df_smi2 = make_distribution_matched_split(df_smi)
515
 
516
+ df_wt2.to_csv(OUT_ROOT / "binding_affinity_wt_meta_with_split.csv", index=False)
517
+ df_smi2.to_csv(OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv", index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
 
519
+ # ------------------------------------------------------------------
520
+ # 4. Build split views
521
+ # ------------------------------------------------------------------
522
+ wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
523
+ smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
524
+
525
+ wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
526
+ wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
 
 
 
 
 
 
 
 
 
 
 
527
  smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
528
  smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
529
+
530
+ print(f"\nSplit sizes - WT: train={len(wt_train)} val={len(wt_val)}")
531
+ print(f"Split sizes - SMILES: train={len(smi_train)} val={len(smi_val)}")
532
+
533
+ # ------------------------------------------------------------------
534
+ # 5. Load ESM2 once - shared across all branches
535
+ # ------------------------------------------------------------------
536
+ print("\nLoading ESM2 (shared target encoder) ...")
537
+ esm_tok, esm_model = load_esm()
538
+
539
+ # ------------------------------------------------------------------
540
+ # 6. WT branch
541
+ # ------------------------------------------------------------------
542
+ run_wt_branch(wt_train, wt_val, esm_tok, esm_model)
543
+
544
+ # ------------------------------------------------------------------
545
+ # 7. SMILES branches
546
+ # ------------------------------------------------------------------
547
+ if RUN_PEPTIDECLM:
548
+ run_smiles_binder_model(
549
+ "peptideclm", smi_train, smi_val,
550
+ esm_tok, esm_model,
551
+ load_fn=load_peptideclm,
552
+ embed_fn=embed_peptideclm,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  )
554
+
555
+ if RUN_CHEMBERTA:
556
+ run_smiles_binder_model(
557
+ "chemberta", smi_train, smi_val,
558
+ esm_tok, esm_model,
559
+ load_fn=load_chemberta,
560
+ embed_fn=embed_chemberta,
561
  )
562
+
563
+ print(f"\n All done. Datasets saved under: {OUT_ROOT}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
 
566
  if __name__ == "__main__":
567
+ main()
training_data_cleaned/embed_smiles.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pipeline:
3
+ 1. Read *_meta_with_split.csv (sequence, label, id, split)
4
+ 2. Convert wt sequences to SMILES via: fasta2smi -i peptides.fasta -o peptides.p2smi
5
+ 3. Parse .p2smi format: "{seq}-linear: {SMILES}"
6
+ 4. Embed SMILES with ChemBERTa to save pooled + unpooled DatasetDicts
7
+ 5. Embed SMILES with PeptideCLM to save pooled + unpooled DatasetDicts
8
+ """
9
+
10
+ import os
11
+ import subprocess
12
+ import tempfile
13
+ import sys
14
+ import numpy as np
15
+ import torch
16
+ import pandas as pd
17
+ from tqdm import tqdm
18
+ from datasets import Dataset, DatasetDict
19
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
20
+
21
+ PROJECT_ROOT = "<>" # change here
22
+
23
+ # using permeability as example
24
+ META_CSV = (
25
+ f"{PROJECT_ROOT}/training_data_cleaned/"
26
+ "permeability_penetrance/permeability_meta_with_split.csv"
27
+ )
28
+ BASE_OUT = f"{PROJECT_ROOT}/alternative_embeddings"
29
+
30
+ # ChemBERTa
31
+ CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MLM"
32
+ CHEMBERTA_OUT = f"{BASE_OUT}/permeability_chemberta/perm_smiles_with_embeddings"
33
+
34
+ # PeptideCLM
35
+ sys.path.append(PROJECT_ROOT)
36
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
37
+
38
+ PEPTIDECLM_MODEL = "aaronfeller/PeptideCLM-23M-all"
39
+ PEPTIDECLM_TOKENIZER = f"{PROJECT_ROOT}/tokenizer/new_vocab.txt"
40
+ PEPTIDECLM_SPLITS = f"{PROJECT_ROOT}/tokenizer/new_splits.txt"
41
+ PEPTIDECLM_OUT = f"{BASE_OUT}/permeability_peptideclm/perm_smiles_with_embeddings"
42
+
43
+ # Column names in the CSV
44
+ SEQ_COL = "sequence"
45
+ LABEL_COL = "label"
46
+ SPLIT_COL = "split"
47
+ ID_COL = "id" # used as FASTA header; must be unique
48
+
49
+ # fasta2smi settings
50
+ FASTA2SMI_BIN = "fasta2smi" # install via github
51
+
52
+ # Embedding settings
53
+ MAX_LENGTH_CHEMBERTA = 512
54
+ MAX_LENGTH_PEPTIDECLM = 768
55
+ BATCH_SIZE = 128
56
+
57
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+
59
+
60
+ # ===========================================================================================
61
+ # Step 1 — fasta2smi conversion, do not apply to properties that only have SMILES sequences
62
+ # ===========================================================================================
63
+ def sequences_to_smiles(sequences: list[str], ids: list[str]) -> dict[str, str]:
64
+ """
65
+ .p2smi format produced by fasta2smi:
66
+ MIIFAIAASHKK-linear: N[C@@H](CCSC)C(=O)...
67
+ KIAKLKAKIQ...-linear: N[C@@H](CCCCN)C(=O)...
68
+ """
69
+ with tempfile.TemporaryDirectory() as tmpdir:
70
+ fasta_path = os.path.join(tmpdir, "peptides.fasta")
71
+ p2smi_path = os.path.join(tmpdir, "peptides.p2smi")
72
+
73
+ with open(fasta_path, "w") as fh:
74
+ for sid, seq in zip(ids, sequences):
75
+ fh.write(f">{sid}\n{seq}\n")
76
+
77
+ cmd = [FASTA2SMI_BIN, "-i", fasta_path, "-o", p2smi_path]
78
+ print(f" Running: {' '.join(cmd)}")
79
+ result = subprocess.run(cmd, capture_output=True, text=True)
80
+ if result.returncode != 0:
81
+ raise RuntimeError(
82
+ f"fasta2smi failed (exit {result.returncode}):\n"
83
+ f" stdout: {result.stdout}\n stderr: {result.stderr}"
84
+ )
85
+
86
+ seq2smi = _parse_p2smi(p2smi_path)
87
+
88
+ n_ok = len(seq2smi)
89
+ n_fail = len(sequences) - n_ok
90
+ print(f" fasta2smi: {n_ok}/{len(sequences)} converted ({n_fail} failed/skipped)")
91
+ return seq2smi
92
+
93
+
94
+ def _parse_p2smi(path: str) -> dict[str, str]:
95
+ seq2smi: dict[str, str] = {}
96
+ with open(path) as fh:
97
+ for line in fh:
98
+ line = line.strip()
99
+ if not line or line.startswith("#"):
100
+ continue
101
+ # Split on "-linear: " — the separator fasta2smi uses
102
+ if "-linear: " not in line:
103
+ print(f" [WARN] Unexpected p2smi line, skipping: {line[:80]}")
104
+ continue
105
+ aa_seq, smi = line.split("-linear: ", maxsplit=1)
106
+ smi = smi.strip()
107
+ if smi and smi.lower() not in ("none", "null", "n/a"):
108
+ seq2smi[aa_seq] = smi
109
+ return seq2smi
110
+
111
+
112
+ # ============================================================
113
+ # Setups
114
+ # ============================================================
115
+ def _get_special_ids_tensor(tokenizer):
116
+ attrs = [
117
+ "pad_token_id", "cls_token_id", "sep_token_id",
118
+ "bos_token_id", "eos_token_id", "mask_token_id",
119
+ ]
120
+ ids = sorted({getattr(tokenizer, a, None) for a in attrs} - {None})
121
+ return torch.tensor(ids, device=device, dtype=torch.long) if ids else None
122
+
123
+
124
+ @torch.no_grad()
125
+ def _embed_batch(tokenizer, model, special_ids_t, sequences, max_length):
126
+ tok = tokenizer(
127
+ sequences, return_tensors="pt",
128
+ padding=True, max_length=max_length, truncation=True,
129
+ )
130
+ input_ids = tok["input_ids"].to(device)
131
+ attention_mask = tok["attention_mask"].to(device)
132
+
133
+ out = model(input_ids=input_ids, attention_mask=attention_mask)
134
+ last_hidden = out.last_hidden_state # (B, L, H)
135
+
136
+ valid = attention_mask.bool()
137
+ if special_ids_t is not None:
138
+ valid = valid & (~torch.isin(input_ids, special_ids_t))
139
+
140
+ valid_f = valid.unsqueeze(-1).float()
141
+ pooled = (
142
+ torch.sum(last_hidden * valid_f, dim=1)
143
+ / torch.clamp(valid_f.sum(dim=1), min=1e-9)
144
+ ).cpu().numpy() # (B, H) float32
145
+
146
+ token_embs, masks, lengths = [], [], []
147
+ for b in range(last_hidden.shape[0]):
148
+ emb = last_hidden[b, valid[b]].cpu().to(torch.float16).numpy()
149
+ token_embs.append(emb)
150
+ masks.append(np.ones(emb.shape[0], dtype=np.int8))
151
+ lengths.append(emb.shape[0])
152
+
153
+ return pooled, token_embs, masks, lengths
154
+
155
+
156
+ def _embed_all(tokenizer, model, special_ids_t, sequences, max_length):
157
+ pooled_all, token_all, mask_all, len_all = [], [], [], []
158
+ for i in tqdm(range(0, len(sequences), BATCH_SIZE), desc=" batches"):
159
+ p, t, m, l = _embed_batch(
160
+ tokenizer, model, special_ids_t,
161
+ sequences[i:i+BATCH_SIZE], max_length,
162
+ )
163
+ pooled_all.append(p)
164
+ token_all.extend(t)
165
+ mask_all.extend(m)
166
+ len_all.extend(l)
167
+ return np.vstack(pooled_all), token_all, mask_all, len_all
168
+
169
+
170
+ def _build_datasets(wt_seqs, smiles, labels, tokenizer, model, special_ids_t, max_length):
171
+ pooled, tok_embs, masks, lengths = _embed_all(
172
+ tokenizer, model, special_ids_t, smiles, max_length
173
+ )
174
+ pooled_ds = Dataset.from_dict({
175
+ "sequence": wt_seqs,
176
+ "smiles": smiles,
177
+ "label": labels,
178
+ "embedding": pooled,
179
+ })
180
+ full_ds = Dataset.from_dict({
181
+ "sequence": wt_seqs,
182
+ "smiles": smiles,
183
+ "label": labels,
184
+ "embedding": tok_embs,
185
+ "attention_mask": masks,
186
+ "length": lengths,
187
+ })
188
+ return pooled_ds, full_ds
189
+
190
+
191
+ def _save(splits: dict, out_path: str):
192
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
193
+ DatasetDict({k: v[0] for k, v in splits.items()}).save_to_disk(out_path)
194
+ DatasetDict({k: v[1] for k, v in splits.items()}).save_to_disk(out_path + "_unpooled")
195
+ print(f" Saved pooled to {out_path}")
196
+ print(f" Saved unpooled to {out_path}_unpooled")
197
+
198
+
199
+ # ============================================================
200
+ # ChemBERTa
201
+ # ============================================================
202
+ def run_chemberta(meta: pd.DataFrame):
203
+ print(f"\n{'='*60}")
204
+ print(" Encoder: ChemBERTa")
205
+ print(f"{'='*60}")
206
+
207
+ print(f" Loading {CHEMBERTA_MODEL} ...")
208
+ tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_MODEL)
209
+ model = AutoModel.from_pretrained(CHEMBERTA_MODEL).to(device).eval()
210
+ special_ids_t = _get_special_ids_tensor(tokenizer)
211
+
212
+ splits: dict[str, tuple] = {}
213
+ for split_name in ["train", "val"]:
214
+ df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
215
+ print(f"\n [{split_name}] {len(df)} rows")
216
+ if df.empty:
217
+ print(" [WARN] Empty split, skipping.")
218
+ continue
219
+ pooled_ds, full_ds = _build_datasets(
220
+ df[SEQ_COL].tolist(), df["smiles"].tolist(),
221
+ df[LABEL_COL].tolist(),
222
+ tokenizer, model, special_ids_t, MAX_LENGTH_CHEMBERTA,
223
+ )
224
+ splits[split_name] = (pooled_ds, full_ds)
225
+
226
+ _save(splits, CHEMBERTA_OUT)
227
+
228
+ # free GPU memory before loading next model
229
+ del model
230
+ torch.cuda.empty_cache()
231
+
232
+
233
+ # ============================================================
234
+ # PeptideCLM
235
+ # ============================================================
236
+ def run_peptideclm(meta: pd.DataFrame):
237
+ print(f"\n{'='*60}")
238
+ print(" Encoder: PeptideCLM")
239
+ print(f"{'='*60}")
240
+
241
+ print(f" Loading tokenizer from {PEPTIDECLM_TOKENIZER} ...")
242
+ tokenizer = SMILES_SPE_Tokenizer(PEPTIDECLM_TOKENIZER, PEPTIDECLM_SPLITS)
243
+
244
+ print(f" Loading {PEPTIDECLM_MODEL} ...")
245
+ full_model = AutoModelForMaskedLM.from_pretrained(PEPTIDECLM_MODEL)
246
+ model = full_model.roformer.to(device).eval()
247
+ special_ids_t = _get_special_ids_tensor(tokenizer)
248
+
249
+ splits: dict[str, tuple] = {}
250
+ for split_name in ["train", "val"]:
251
+ df = meta[meta[SPLIT_COL] == split_name].reset_index(drop=True)
252
+ print(f"\n [{split_name}] {len(df)} rows")
253
+ if df.empty:
254
+ print(" [WARN] Empty split, skipping.")
255
+ continue
256
+ pooled_ds, full_ds = _build_datasets(
257
+ df[SEQ_COL].tolist(), df["smiles"].tolist(),
258
+ df[LABEL_COL].tolist(),
259
+ tokenizer, model, special_ids_t, MAX_LENGTH_PEPTIDECLM,
260
+ )
261
+ splits[split_name] = (pooled_ds, full_ds)
262
+
263
+ _save(splits, PEPTIDECLM_OUT)
264
+
265
+ del model
266
+ torch.cuda.empty_cache()
267
+
268
+
269
+ # ============================================================
270
+ # Main
271
+ # ============================================================
272
+ def main():
273
+ print(f"\nDevice : {device}")
274
+ print(f"Meta : {META_CSV}")
275
+
276
+ # Load metadata
277
+ meta = pd.read_csv(META_CSV, sep=None, engine="python")
278
+ print(f"Loaded {len(meta)} rows. Columns: {meta.columns.tolist()}")
279
+ for col in [SEQ_COL, LABEL_COL, SPLIT_COL]:
280
+ if col not in meta.columns:
281
+ raise ValueError(f"Expected column '{col}' not found. Available: {meta.columns.tolist()}")
282
+
283
+ # Ensure numeric labels
284
+ meta[LABEL_COL] = pd.to_numeric(meta[LABEL_COL], errors="coerce")
285
+ meta = meta.dropna(subset=[SEQ_COL, LABEL_COL]).reset_index(drop=True)
286
+
287
+ # Build id list for FASTA headers
288
+ if ID_COL in meta.columns:
289
+ ids = meta[ID_COL].astype(str).tolist()
290
+ else:
291
+ ids = [f"seq_{i}" for i in range(len(meta))]
292
+
293
+ # Note that for properties start with SMILES sequences, fasta2smi is not needed
294
+ # Convert wt to SMILES (single fasta2smi call for the whole dataset)
295
+ print("\nConverting peptide sequences to SMILES ...")
296
+ seqs = meta[SEQ_COL].astype(str).tolist()
297
+ seq2smi = sequences_to_smiles(seqs, ids)
298
+
299
+ meta["smiles"] = meta[SEQ_COL].astype(str).map(seq2smi)
300
+ n_missing = meta["smiles"].isna().sum()
301
+ if n_missing:
302
+ print(f" [WARN] {n_missing} sequences had no SMILES — dropping.")
303
+ meta = meta.dropna(subset=["smiles"]).reset_index(drop=True)
304
+ print(f" Retained {len(meta)} rows with valid SMILES.")
305
+ # Save SMILES-enriched meta CSV
306
+ smiles_meta_path = os.path.join(BASE_OUT, "permeability_smiles_meta_with_split.csv")
307
+ os.makedirs(BASE_OUT, exist_ok=True)
308
+ meta.to_csv(smiles_meta_path, index=False)
309
+ print(f" Saved SMILES meta to {smiles_meta_path}")
310
+
311
+ # Run both encoders sequentially (share the same converted SMILES)
312
+ #run_chemberta(meta)
313
+ #run_peptideclm(meta)
314
+
315
+ print("\nAll done.")
316
+
317
+
318
+ if __name__ == "__main__":
319
+ main()
training_data_cleaned/permeability_penetrance/permeability_smiles_meta_with_split.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbece0b3b8345cae1ce6fe2e9a1a10ddd5320bae18c3a7a3f958b97b98979796
3
+ size 947525
training_data_cleaned/smiles_data_split.py CHANGED
@@ -15,6 +15,7 @@ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
15
 
16
  seed_everything(1986)
17
 
 
18
  df = pd.read_csv("caco2.csv")
19
 
20
  mols = []
@@ -87,151 +88,4 @@ df[df["split"] == "train"].to_csv("caco2_train.csv", index=False)
87
  df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
88
  df.to_csv("caco2_meta_with_split.csv", index=False)
89
 
90
- print(df["split"].value_counts())
91
-
92
- # ======================
93
- # Config
94
- # ======================
95
- MAX_LENGTH = 768
96
- BATCH_SIZE = 128
97
-
98
- TRAIN_CSV = "caco2_train.csv"
99
- VAL_CSV = "caco2_val.csv"
100
-
101
- SMILES_COL = "SMILES"
102
- LABEL_COL = "Caco2"
103
-
104
- OUT_PATH = "./Classifier_Weight/training_data_cleaned/permeability_caco2/caco2_smiles_with_embeddings"
105
-
106
- # GPU device
107
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
108
- print(f"Using device: {device}")
109
-
110
- # ======================
111
- # Load tokenizer + model
112
- # ======================
113
- print("Loading tokenizer and model...")
114
- tokenizer = SMILES_SPE_Tokenizer(
115
- "./Classifier_Weight/tokenizer/new_vocab.txt",
116
- "./Classifier_Weight/tokenizer/new_splits.txt",
117
- )
118
-
119
- embedding_model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer
120
- embedding_model.to(device)
121
- embedding_model.eval()
122
-
123
- HIDDEN_KEY = "last_hidden_state"
124
-
125
- def get_special_ids(tokenizer):
126
- cand = [
127
- getattr(tokenizer, "pad_token_id", None),
128
- getattr(tokenizer, "cls_token_id", None),
129
- getattr(tokenizer, "sep_token_id", None),
130
- getattr(tokenizer, "bos_token_id", None),
131
- getattr(tokenizer, "eos_token_id", None),
132
- getattr(tokenizer, "mask_token_id", None),
133
- ]
134
- special_ids = sorted({x for x in cand if x is not None})
135
- if len(special_ids) == 0:
136
- print("[WARN] No special token ids found on tokenizer; pooling will only exclude padding via attention_mask.")
137
- return special_ids
138
-
139
- SPECIAL_IDS = get_special_ids(tokenizer)
140
- SPECIAL_IDS_T = torch.tensor(SPECIAL_IDS, device=device, dtype=torch.long) if len(SPECIAL_IDS) else None
141
-
142
- @torch.no_grad()
143
- def embed_batch_return_both(batch_sequences, max_length, device):
144
- tok = tokenizer(
145
- batch_sequences,
146
- return_tensors="pt",
147
- padding=True,
148
- max_length=max_length,
149
- truncation=True,
150
- )
151
- input_ids = tok["input_ids"].to(device) # (B, L)
152
- attention_mask = tok["attention_mask"].to(device) # (B, L)
153
-
154
- outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
155
- last_hidden = outputs.last_hidden_state # (B, L, H)
156
-
157
- valid = attention_mask.bool()
158
- if SPECIAL_IDS_T is not None and SPECIAL_IDS_T.numel() > 0:
159
- valid = valid & (~torch.isin(input_ids, SPECIAL_IDS_T))
160
-
161
- # --- pooled embeddings (exclude specials) ---
162
- valid_f = valid.unsqueeze(-1).float() # (B, L, 1)
163
- summed = torch.sum(last_hidden * valid_f, dim=1) # (B, H)
164
- denom = torch.clamp(valid_f.sum(dim=1), min=1e-9) # (B, 1)
165
- pooled = (summed / denom).detach().cpu().numpy() # (B, H), float32
166
-
167
- # --- unpooled per-example token embeddings (exclude specials) ---
168
- token_emb_list = []
169
- mask_list = []
170
- lengths = []
171
- for b in range(last_hidden.shape[0]):
172
- emb = last_hidden[b, valid[b]] # (L_i, H)
173
- token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy()) # float16
174
- L_i = emb.shape[0]
175
- lengths.append(int(L_i))
176
- mask_list.append(np.ones((L_i,), dtype=np.int8))
177
-
178
- return pooled, token_emb_list, mask_list, lengths
179
-
180
- def generate_embeddings_batched_both(sequences, batch_size, max_length):
181
- pooled_all = []
182
- token_emb_all = []
183
- mask_all = []
184
- lengths_all = []
185
-
186
- for i in tqdm(range(0, len(sequences), batch_size), desc="Embedding batches"):
187
- batch = sequences[i:i + batch_size]
188
- pooled, token_list, m_list, lens = embed_batch_return_both(batch, max_length, device)
189
- pooled_all.append(pooled)
190
- token_emb_all.extend(token_list)
191
- mask_all.extend(m_list)
192
- lengths_all.extend(lens)
193
-
194
- pooled_all = np.vstack(pooled_all) # (N, H)
195
- return pooled_all, token_emb_all, mask_all, lengths_all
196
-
197
- from datasets import Dataset, DatasetDict
198
-
199
- def make_split_datasets(csv_path, split_name):
200
- df = pd.read_csv(csv_path)
201
- df = df.dropna(subset=[SMILES_COL, LABEL_COL]).reset_index(drop=True)
202
- df["sequence"] = df[SMILES_COL].astype(str)
203
-
204
- labels = pd.to_numeric(df[LABEL_COL], errors="coerce")
205
- df = df.loc[~labels.isna()].reset_index(drop=True)
206
- sequences = df["sequence"].tolist()
207
- labels = pd.to_numeric(df[LABEL_COL], errors="coerce").tolist()
208
-
209
- # (pooled_embs: (N,H), token_emb_list: list of (L_i,H), mask_list: list of (L_i,), lengths: list[int])
210
- pooled_embs, token_emb_list, mask_list, lengths = generate_embeddings_batched_both(
211
- sequences, batch_size=BATCH_SIZE, max_length=MAX_LENGTH
212
- )
213
-
214
- pooled_ds = Dataset.from_dict({
215
- "sequence": sequences,
216
- "label": labels,
217
- "embedding": pooled_embs, # (N,H)
218
- })
219
-
220
- full_ds = Dataset.from_dict({
221
- "sequence": sequences,
222
- "label": labels,
223
- "embedding": token_emb_list, # each (L_i,H) float16
224
- "attention_mask": mask_list, # each (L_i,) int8 ones
225
- "length": lengths,
226
- })
227
-
228
- return pooled_ds, full_ds
229
-
230
- train_pooled, train_full = make_split_datasets(TRAIN_CSV, "train")
231
- val_pooled, val_full = make_split_datasets(VAL_CSV, "val")
232
-
233
- ds_pooled = DatasetDict({"train": train_pooled, "val": val_pooled})
234
- ds_full = DatasetDict({"train": train_full, "val": val_full})
235
-
236
- ds_pooled.save_to_disk(OUT_PATH)
237
- ds_full.save_to_disk(OUT_PATH + "_unpooled")
 
15
 
16
  seed_everything(1986)
17
 
18
+ # Starting with a raw dataframe, using caco2 as example.
19
  df = pd.read_csv("caco2.csv")
20
 
21
  mols = []
 
88
  df[df["split"] == "val"].to_csv("caco2_val.csv", index=False)
89
  df.to_csv("caco2_meta_with_split.csv", index=False)
90
 
91
+ print(df["split"].value_counts())