chq1155 committed on
Commit
ee6da62
·
verified ·
1 Parent(s): e32d6dc

Upload TD3B code (inference, training, baselines)

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,128 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation
2
+
3
+ TD3B is a sequence-based generative framework that designs peptide binders with specified agonist or antagonist behavior. It combines a Direction Oracle, a soft binding-affinity gate, and amortized fine-tuning of a pre-trained discrete diffusion model (MDLM).
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ conda env create -f env.yml
9
+ conda activate td3b
10
+ pip install -e .
11
+ ```
12
+
13
+ ## Data and Checkpoints
14
+
15
+ Download the pretrained checkpoints and data from [Google Drive (TBA)](placeholder_link).
16
+
17
+ Place the files as follows:
18
+
19
+ ```
20
+ TD3B/
21
+ ├── checkpoints/
22
+ │ ├── pretrained.ckpt # Pre-trained MDLM weights
23
+ │ ├── td3b.ckpt # Fine-tuned TD3B model
24
+ │ └── direction_oracle.pt # Direction Oracle weights
25
+ ├── data/
26
+ │ ├── train.csv # Training set (target-binder pairs)
27
+ │ └── test.csv # Test set
28
+ ├── scoring/functions/classifiers/
29
+ │ ├── binding-affinity.pt
30
+ │ ├── hemolysis-xgboost.json
31
+ │ ├── nonfouling-xgboost.json
32
+ │ ├── permeability-xgboost.json
33
+ │ └── solubility-xgboost.json
34
+ └── tokenizer/
35
+ ├── new_vocab.txt
36
+ └── new_splits.txt
37
+ ```
38
+
39
+ ## Code Structure
40
+
41
+ ```
42
+ TD3B/
43
+ ├── inference.py # Generate binders (main inference entry point)
44
+ ├── finetune_multi_target.py # Multi-target TD3B training
45
+ ├── finetune_utils.py # Training utilities
46
+ ├── launch_multi_target.sh # Training launcher script
47
+ ├── diffusion.py # MDLM backbone (TR2-D2)
48
+ ├── roformer.py # RoFormer wrapper
49
+ ├── noise_schedule.py # Noise schedules
50
+ ├── peptide_mcts.py # MCTS tree search
51
+ ├── td3b/
52
+ │ ├── direction_oracle.py # Direction Oracle (f_φ)
53
+ │ ├── td3b_scoring.py # Gated reward R = g_ψ · σ(d*·(f_φ−0.5)/τ)
54
+ │ ├── td3b_losses.py # L_WDCE + λ·L_ctr + β·L_KL
55
+ │ ├── td3b_mcts.py # TD3B-extended MCTS
56
+ │ ├── td3b_finetune.py # Training loop
57
+ │ └── data_utils.py # Data loading utilities
58
+ ├── scoring/ # Affinity predictor (g_ψ) and property classifiers
59
+ ├── baselines/ # CG, SMC, TDS, PepTune, Unguided baselines
60
+ ├── tokenizer/ # SMILES tokenizer (vocab + splits)
61
+ ├── configs/ # Model and training configs
62
+ └── utils/ # Misc utilities
63
+ ```
64
+
65
+ ## Inference
66
+
67
+ Generate agonist/antagonist binders for target proteins:
68
+
69
+ ```bash
70
+ python inference.py \
71
+ --ckpt_path checkpoints/td3b.ckpt \
72
+ --val_csv data/test.csv \
73
+ --save_path results/ \
74
+ --seed 42 \
75
+ --num_pool 32 \
76
+ --val_samples_per_target 8 \
77
+ --resample_alpha 0.1
78
+ ```
79
+
80
+ This generates 32 candidates per (target, direction), scores them with the Direction Oracle and affinity predictor, applies Algorithm 2 weighted resampling, and saves only valid peptide samples.
81
+
82
+ Output: `results/td3b_results_seed42.csv` with columns: target, sequence, direction, affinity, gated_reward, direction_oracle, direction_accuracy.
83
+
84
+ ## Training
85
+
86
+ ### Multi-target TD3B
87
+
88
+ 1. Edit `launch_multi_target.sh` — set paths to checkpoints, data, and oracle:
89
+
90
+ ```bash
91
+ BASE_PATH="/path/to/TD3B"
92
+ PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
93
+ TRAIN_CSV="${BASE_PATH}/data/train.csv"
94
+ ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
95
+ ```
96
+
97
+ 2. Launch training:
98
+
99
+ ```bash
100
+ bash launch_multi_target.sh
101
+ ```
102
+
103
+ Key hyperparameters (in `launch_multi_target.sh`):
104
+ - `CONTRASTIVE_WEIGHT=0.1` — λ for L_ctr
105
+ - `KL_BETA=0.1` — β for L_KL
106
+ - `SIGMOID_TEMPERATURE=0.1` — τ for gated reward
107
+ - `NUM_ITER=20` — MCTS iterations per round
108
+ - `NUM_CHILDREN=16` — Children per MCTS expansion
109
+
110
+ ### Baselines
111
+
112
+ Run baseline methods (CG, SMC, TDS, PepTune, Unguided):
113
+
114
+ ```bash
115
+ cd baselines/
116
+ bash run.sh --baseline cg --device cuda:0
117
+ bash run.sh --baseline smc --device cuda:0
118
+ bash run.sh --baseline tds --device cuda:0
119
+ ```
120
+
121
+ ## Citation
122
+
123
+ ```bibtex
124
+ @article{caotd3b,
125
+ title={TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation},
126
+ author={Cao, Hanqun and Pal, Aastha and Tang, Sophia and Zhang, Yinuo and Zhang, Jingjie and Heng, Pheng-Ann and Chatterjee, Pranam}
127
+ }
128
+ ```
baselines/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from baselines.baselines import (
2
+ RewardInputs,
3
+ RewardWrapper,
4
+ classifier_guidance,
5
+ peptune_mctg_sampling,
6
+ unguided_sampling,
7
+ sequential_monte_carlo,
8
+ twisted_diffusion_sampler,
9
+ )
10
+
11
+ __all__ = [
12
+ "RewardInputs",
13
+ "RewardWrapper",
14
+ "classifier_guidance",
15
+ "peptune_mctg_sampling",
16
+ "unguided_sampling",
17
+ "sequential_monte_carlo",
18
+ "twisted_diffusion_sampler",
19
+ ]
baselines/baselines.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ from dataclasses import dataclass
4
+ from types import SimpleNamespace
5
+ from typing import Callable, Dict, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ DEFAULT_EPS = 1e-5
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def _sample_categorical(categorical_probs: torch.Tensor) -> torch.Tensor:
17
+ gumbel = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
18
+ return (categorical_probs / gumbel).argmax(dim=-1).to(dtype=torch.long)
19
+
20
+
21
+ def _normalize_probs(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
22
+ return probs / probs.sum(dim=dim, keepdim=True).clamp_min(1e-12)
23
+
24
+
25
+ def _safe_resample_weights(weights: torch.Tensor) -> torch.Tensor:
26
+ if weights.numel() == 0:
27
+ return weights
28
+ weights = torch.where(torch.isfinite(weights), weights, torch.zeros_like(weights))
29
+ total = weights.sum()
30
+ if not torch.isfinite(total) or total <= 0:
31
+ return torch.full_like(weights, 1.0 / weights.numel())
32
+ return weights / total
33
+
34
+
35
+ def _sequence_logprob(
36
+ probs: torch.Tensor,
37
+ x_next: torch.Tensor,
38
+ x_current: torch.Tensor,
39
+ mask_idx: int,
40
+ ) -> torch.Tensor:
41
+ gather = probs.gather(-1, x_next.unsqueeze(-1)).squeeze(-1).clamp_min(1e-12)
42
+ mask = (x_current == mask_idx).to(gather.dtype)
43
+ return (gather.log() * mask).sum(dim=-1)
44
+
45
+
46
+ def _transition_probs_from_logits(
47
+ log_probs: torch.Tensor,
48
+ t: torch.Tensor,
49
+ dt: torch.Tensor,
50
+ mask_idx: int,
51
+ ) -> torch.Tensor:
52
+ change_prob_t = t[:, None, None]
53
+ change_prob_s = (t - dt)[:, None, None]
54
+ q_xs = log_probs.exp() * (change_prob_t - change_prob_s)
55
+ q_xs[:, :, mask_idx] = change_prob_s[:, :, 0]
56
+ return q_xs
57
+
58
+
59
+ def _sample_from_q(
60
+ q_probs: torch.Tensor,
61
+ x_current: torch.Tensor,
62
+ mask_idx: int,
63
+ ) -> torch.Tensor:
64
+ x_changed = _sample_categorical(q_probs)
65
+ copy_flag = (x_current != mask_idx)
66
+ return torch.where(copy_flag, x_current, x_changed)
67
+
68
+
69
+ def _protein_tokens_to_device(tokens: torch.Tensor, device: torch.device) -> torch.Tensor:
70
+ if tokens.device != device:
71
+ return tokens.to(device)
72
+ return tokens
73
+
74
+
75
+ def _tokens_to_one_hot(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
76
+ return F.one_hot(tokens, num_classes=vocab_size).float()
77
+
78
+
79
+ def _decode_sequences(tokenizer, token_ids: torch.Tensor) -> list:
80
+ return tokenizer.batch_decode(token_ids)
81
+
82
+
83
+ def _affinity_from_scoring(
84
+ scoring_fn: Callable,
85
+ sequences: list,
86
+ device: torch.device,
87
+ protein_seq: Optional[str] = None,
88
+ ) -> torch.Tensor:
89
+ if protein_seq is not None:
90
+ try:
91
+ scores = scoring_fn(sequences, protein_seq)
92
+ except TypeError:
93
+ try:
94
+ scores = scoring_fn(sequences, prot_seq=protein_seq)
95
+ except TypeError:
96
+ scores = scoring_fn(sequences)
97
+ else:
98
+ scores = scoring_fn(sequences)
99
+ if isinstance(scores, tuple):
100
+ scores = scores[0]
101
+ scores = np.asarray(scores)
102
+ if scores.ndim == 1:
103
+ affinity = scores
104
+ else:
105
+ affinity = scores[:, 0]
106
+ return torch.as_tensor(affinity, device=device, dtype=torch.float32)
107
+
108
+
109
+ def _roformer_hidden_from_inputs(
110
+ base_model,
111
+ input_ids: Optional[torch.Tensor] = None,
112
+ inputs_embeds: Optional[torch.Tensor] = None,
113
+ attn_mask: Optional[torch.Tensor] = None,
114
+ ) -> torch.Tensor:
115
+ outputs = base_model.backbone.model(
116
+ input_ids=input_ids,
117
+ inputs_embeds=inputs_embeds,
118
+ attention_mask=attn_mask,
119
+ output_hidden_states=True,
120
+ return_dict=True,
121
+ )
122
+ return outputs.hidden_states[-1]
123
+
124
+
125
+ def _logits_from_inputs(
126
+ base_model,
127
+ input_ids: Optional[torch.Tensor] = None,
128
+ inputs_embeds: Optional[torch.Tensor] = None,
129
+ attn_mask: Optional[torch.Tensor] = None,
130
+ ) -> torch.Tensor:
131
+ outputs = base_model.backbone.model(
132
+ input_ids=input_ids,
133
+ inputs_embeds=inputs_embeds,
134
+ attention_mask=attn_mask,
135
+ output_hidden_states=False,
136
+ return_dict=True,
137
+ )
138
+ return outputs.logits
139
+
140
+
141
@dataclass
class RewardInputs:
    """Static, per-target context consumed by `RewardWrapper`."""

    # Tokenized target-protein sequence fed to the Direction Oracle.
    protein_tokens: torch.Tensor
    # Desired direction d* used in the gated reward (direction - 0.5) * d_star;
    # presumably +1 for agonist and -1 for antagonist — TODO confirm convention.
    d_star: float
    # Raw target-protein string handed to the affinity scoring function.
    protein_seq: str
146
+
147
+
148
class RewardWrapper:
    """Computes the TD3B gated reward R = affinity * sigmoid(d* * (dir - 0.5) / alpha).

    Bundles the affinity scorer (`scoring_fn`), the Direction Oracle, the
    diffusion base model (used only to embed soft token distributions), and
    the tokenizer, plus the per-target `RewardInputs`. Oracle capabilities
    are probed at construction time via `hasattr`, so the wrapper degrades
    gracefully when the oracle lacks hidden-state or probability interfaces.
    """

    def __init__(
        self,
        scoring_fn: Callable,
        direction_oracle: torch.nn.Module,
        base_model,
        tokenizer,
        reward_inputs: RewardInputs,
        device: torch.device,
        fast_direction: bool = False,
        reward_alpha: float = 0.1,
    ):
        self.scoring_fn = scoring_fn
        self.direction_oracle = direction_oracle
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.reward_inputs = reward_inputs
        self.device = device
        self.fast_direction = fast_direction
        self.reward_alpha = reward_alpha
        # Hidden-direction path needs all three oracle submodules.
        self._supports_hidden_direction = all(
            hasattr(direction_oracle, attr)
            for attr in ("protein_embedder", "fusion", "classifier")
        )
        self._supports_predict = hasattr(direction_oracle, "predict_with_confidence")
        # fast_direction skips the transformer pass, which is only meaningful
        # when the oracle exposes the hidden-direction modules.
        if self.fast_direction and not self._supports_hidden_direction:
            logger.warning("fast_direction requested but oracle lacks hidden-direction modules; disabling fast_direction.")
            self.fast_direction = False
        # Protein embedding is target-constant, so it is computed once lazily.
        self._protein_emb_cache = None
        if self.reward_inputs.protein_seq is None:
            raise ValueError("RewardInputs.protein_seq is required for conditioned sampling.")

    def _protein_emb(self, batch_size: int) -> torch.Tensor:
        """Return the cached protein embedding expanded to `batch_size` rows.

        Assumes the embedder output is (1, D) so `expand` broadcasts row 0 —
        TODO confirm embedder output shape.
        """
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        if self._protein_emb_cache is None:
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            prot_emb = self.direction_oracle.protein_embedder(prot_tokens)
            self._protein_emb_cache = prot_emb
        return self._protein_emb_cache.expand(batch_size, -1)

    def _direction_from_hidden(
        self,
        hidden: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Direction score from backbone hidden states: masked mean-pool over
        the sequence, fuse with the protein embedding, classify."""
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        mask = attn_mask.to(hidden.dtype).unsqueeze(-1)
        # Mean over valid positions only; clamp avoids 0-division on empty masks.
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
        protein_emb = self._protein_emb(pooled.size(0))
        fused = self.direction_oracle.fusion(pooled, protein_emb)
        return self.direction_oracle.classifier(fused).squeeze(-1)

    def _direction_from_probs(
        self,
        y_probs: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Direction score from soft token distributions (keeps gradients).

        Fallback order: oracle-native `predict_from_probs`; hard argmax when
        hidden-direction is unsupported; otherwise embed the soft tokens and
        either use the embeddings directly (`fast_direction`) or run the full
        backbone before pooling.
        """
        if hasattr(self.direction_oracle, "predict_from_probs"):
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            return self.direction_oracle.predict_from_probs(y_probs, prot_tokens, attn_mask)
        if not self._supports_hidden_direction:
            # No differentiable path available: fall back to discrete tokens.
            token_ids = y_probs.argmax(dim=-1)
            return self._direction_from_tokens(token_ids)
        if self.fast_direction:
            # Fast path: treat the soft embeddings as the "hidden" states,
            # skipping the transformer forward entirely.
            emb_weight = self.base_model.backbone.model.roformer.embeddings.word_embeddings.weight
            inputs_embeds = y_probs @ emb_weight
            hidden = inputs_embeds
        else:
            emb_weight = self.base_model.backbone.model.roformer.embeddings.word_embeddings.weight
            inputs_embeds = y_probs @ emb_weight
            hidden = _roformer_hidden_from_inputs(
                self.base_model,
                inputs_embeds=inputs_embeds,
                attn_mask=attn_mask,
            )
        return self._direction_from_hidden(hidden, attn_mask)

    def _direction_from_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Direction score from hard token ids via the oracle's discrete API."""
        prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
        # Broadcast a single protein row across the peptide batch.
        if prot_tokens.dim() == 2 and prot_tokens.size(0) == 1:
            prot_tokens = prot_tokens.expand(token_ids.size(0), -1)
        if self._supports_predict:
            # Confidence is discarded here; only the direction score is used.
            direction, _ = self.direction_oracle.predict_with_confidence(token_ids, prot_tokens)
            return direction
        return self.direction_oracle(token_ids, prot_tokens)

    def _gated_reward(self, affinity: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
        """Gate the affinity with sigmoid(d* * (direction - 0.5) / alpha)."""
        d_star = torch.as_tensor(self.reward_inputs.d_star, device=self.device, dtype=direction.dtype)
        directional_score = (direction - 0.5) * d_star
        gate = torch.sigmoid(directional_score / self.reward_alpha)
        return affinity * gate

    def evaluate_tokens(self, token_ids: torch.Tensor, attn_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Full evaluation of hard samples.

        Returns a dict with "sequences" (list of decoded strings — not a
        tensor despite the annotation), "affinity", "direction", and
        "gated_reward" tensors.
        """
        sequences = _decode_sequences(self.tokenizer, token_ids)
        affinity = _affinity_from_scoring(
            self.scoring_fn,
            sequences,
            self.device,
            protein_seq=self.reward_inputs.protein_seq,
        )
        with torch.no_grad():
            direction = self._direction_from_tokens(token_ids)
        gated_reward = self._gated_reward(affinity, direction)
        return {
            "sequences": sequences,
            "affinity": affinity,
            "direction": direction,
            "gated_reward": gated_reward,
        }

    def reward_from_tokens(
        self,
        token_ids: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Gated reward for hard token samples (no gradients through the oracle)."""
        sequences = _decode_sequences(self.tokenizer, token_ids)
        affinity = _affinity_from_scoring(
            self.scoring_fn,
            sequences,
            self.device,
            protein_seq=self.reward_inputs.protein_seq,
        )
        with torch.no_grad():
            direction = self._direction_from_tokens(token_ids)
        return self._gated_reward(affinity, direction)

    def reward_from_probs(
        self,
        y_probs: torch.Tensor,
        token_ids_for_affinity: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Gated reward for soft samples, differentiable where possible.

        Tries a differentiable affinity via `scoring_fn.forward_from_probs`;
        on any failure falls back to scoring the argmax-decoded sequences
        (non-differentiable affinity, but direction may still carry grads).
        """
        affinity = None
        if hasattr(self.scoring_fn, "forward_from_probs"):
            try:
                affinity = self.scoring_fn.forward_from_probs(
                    y_probs,
                    attn_mask,
                    prot_seq=self.reward_inputs.protein_seq,
                )
            except Exception as exc:
                # Deliberate best-effort: differentiable scoring is optional.
                logger.warning("Differentiable affinity failed; falling back to argmax. Error: %s", exc)
                affinity = None
        if affinity is None:
            sequences = _decode_sequences(self.tokenizer, token_ids_for_affinity)
            affinity = _affinity_from_scoring(
                self.scoring_fn,
                sequences,
                self.device,
                protein_seq=self.reward_inputs.protein_seq,
            )
        direction = self._direction_from_probs(y_probs, attn_mask)
        return self._gated_reward(affinity, direction)
303
+
304
+
305
class PepTuneSampler:
    """PepTune-style MCTS sampler over the masked-diffusion reverse process.

    Builds a search tree whose nodes hold partially unmasked sequences;
    expansion rolls each child out to a full sequence, scores it with the
    TD3B gated reward, and maintains a Pareto front over
    (affinity, direction-score). Final samples are drawn from the front.
    Depends on the project-local `peptide_mcts.Node` / `updateParetoFront`
    and `utils.app.PeptideAnalyzer` APIs (imported lazily in __init__).
    """

    def __init__(
        self,
        base_model,
        reward_fn: RewardWrapper,
        seq_length: int,
        num_steps: int,
        mcts_iterations: int,
        num_children: int,
        sample_prob_weight: float,
        invalid_penalty: float,
        pareto_max_size: Optional[int],
        eps: float,
    ):
        # Lazy imports keep module import light and avoid cycles.
        from peptide_mcts import Node, updateParetoFront
        from utils.app import PeptideAnalyzer

        self.base_model = base_model
        self.reward_fn = reward_fn
        self.seq_length = seq_length
        self.num_steps = num_steps
        self.mcts_iterations = mcts_iterations
        self.num_children = num_children
        self.sample_prob_weight = sample_prob_weight
        self.invalid_penalty = invalid_penalty
        self.pareto_max_size = pareto_max_size
        self.eps = eps

        self.device = base_model.device
        self.mask_idx = base_model.mask_index
        self.tokenizer = base_model.tokenizer
        self.analyzer = PeptideAnalyzer()
        self.Node = Node
        self.updateParetoFront = updateParetoFront

        # Diffusion time grid from t=1 down to t=eps, with constant step dt.
        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        self.dt = torch.as_tensor((1 - eps) / num_steps, device=self.device)
        self.args = SimpleNamespace(
            num_obj=1,
            total_num_steps=num_steps,
            seq_length=seq_length,
            num_children=num_children,
        )

    def _init_root(self):
        """Create the root node: a fully masked sequence at timestep 0."""
        masked_seq = torch.full((self.seq_length,), self.mask_idx, device=self.device, dtype=torch.long)
        attn_mask = torch.ones_like(masked_seq, device=self.device)
        tokens = {"seqs": masked_seq, "attention_mask": attn_mask}
        return self.Node(
            args=self.args,
            tokens=tokens,
            log_rnd=torch.zeros((), device=self.device),
            log_policy_step=torch.zeros((), device=self.device),
            log_pretrained_step=torch.zeros((), device=self.device),
            totalReward=np.zeros(self.args.num_obj),
            timestep=0,
        )

    def _select(self, root):
        """Descend from `root` until selectNode returns a non-3 status.

        Status semantics live in `peptide_mcts.Node`; 3 appears to mean
        "keep descending" — TODO confirm against Node.selectNode.
        """
        node = root
        while True:
            node, status = node.selectNode()
            if status != 3:
                return node, status

    def _update_pareto(self, pareto_front, pareto_tokens, seq, token_ids, score_vector):
        """Insert (seq, score_vector) into the front and keep tokens in sync.

        `pareto_tokens` is pruned to entries still on the front, and the new
        sequence's token tensor is stored (detached clone) if it survived.
        """
        pareto_front = self.updateParetoFront(
            pareto_front,
            seq,
            score_vector,
            totalSize=self.pareto_max_size,
        )
        pareto_tokens = {k: pareto_tokens[k] for k in pareto_front if k in pareto_tokens}
        if seq in pareto_front:
            pareto_tokens[seq] = token_ids.detach().clone()
        return pareto_front, pareto_tokens

    def _expand(self, parent, pareto_front, pareto_tokens):
        """Expand `parent`: sample children, roll out, score, backpropagate.

        Children are one reverse step from the parent; each is then rolled out
        to t=eps (with a final noise-removal pass if masks remain), decoded,
        validity-checked, and scored. Invalid rollouts get -invalid_penalty.
        The mean child reward is backed up to the root.
        """
        parent_tokens = parent.tokens["seqs"].to(self.device)
        attn_mask = parent.tokens["attention_mask"].to(self.device)
        t = self.timesteps[parent.timestep] * torch.ones(1, 1, device=self.device)

        with torch.no_grad():
            _, x_children, log_policy_step, log_pretrained_step = self.base_model.batch_mcts_reverse_step(
                token_array=parent_tokens,
                t=t,
                dt=self.dt,
                batch_size=self.num_children,
                pretrained=self.base_model,
            )

        # Radon–Nikodym log-ratio accumulates pretrained-vs-policy mass.
        child_log_rnd = parent.log_rnd + (log_pretrained_step - log_policy_step)
        log_policy_step = log_policy_step * self.sample_prob_weight

        # Roll each child out to the end of the diffusion schedule.
        x_rollout = x_children
        t_step = self.timesteps[parent.timestep] * torch.ones(self.num_children, 1, device=self.device)
        for i in range(1, self.num_steps - parent.timestep):
            t_step = self.timesteps[parent.timestep + i] * torch.ones(self.num_children, 1, device=self.device)
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_reverse_step(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        # Any residual mask tokens are cleaned up with a dedicated pass.
        if (x_rollout == self.mask_idx).any().item():
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_noise_removal(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        sequences = self.tokenizer.batch_decode(x_rollout)
        valid_mask = [self.analyzer.is_peptide(seq) for seq in sequences]

        # Default reward: penalty for chemically invalid rollouts.
        reward_values = np.full(self.num_children, -float(self.invalid_penalty), dtype=np.float32)
        if any(valid_mask):
            valid_tokens = x_rollout[valid_mask]
            valid_sequences = [seq for seq, keep in zip(sequences, valid_mask) if keep]
            affinity = _affinity_from_scoring(
                self.reward_fn.scoring_fn,
                valid_sequences,
                self.device,
                protein_seq=self.reward_fn.reward_inputs.protein_seq,
            )
            with torch.no_grad():
                direction = self.reward_fn._direction_from_tokens(valid_tokens)
            gated_reward = self.reward_fn._gated_reward(affinity, direction)
            d_star = self.reward_fn.reward_inputs.d_star
            dir_score = (direction - 0.5) * d_star

            # Pareto front tracks (affinity, directional score) per sequence.
            for idx, seq in enumerate(valid_sequences):
                score_vector = np.array(
                    [float(affinity[idx].item()), float(dir_score[idx].item())],
                    dtype=np.float32,
                )
                pareto_front, pareto_tokens = self._update_pareto(
                    pareto_front,
                    pareto_tokens,
                    seq,
                    valid_tokens[idx],
                    score_vector,
                )

            reward_values[np.array(valid_mask)] = gated_reward.detach().cpu().numpy()

        # Attach children to the tree and collect their (1-dim) reward vectors.
        reward_vectors = []
        for i in range(self.num_children):
            child_tokens = {"seqs": x_children[i].to(dtype=torch.long), "attention_mask": attn_mask}
            reward_vec = np.array([float(reward_values[i])], dtype=np.float32)
            parent.addChildNode(
                tokens=child_tokens,
                log_rnd=child_log_rnd[i],
                log_policy_step=log_policy_step[i],
                log_pretrained_step=log_pretrained_step[i],
                totalReward=reward_vec,
            )
            reward_vectors.append(reward_vec)

        # Back up the mean child reward along the path to the root.
        avg_reward = np.mean(np.stack(reward_vectors, axis=0), axis=0)
        node = parent
        while node:
            node.updateNode(avg_reward)
            node = node.parentNode

        return pareto_front, pareto_tokens

    def _select_from_pareto(self, pareto_front, pareto_tokens, batch_size):
        """Pick `batch_size` token tensors from the Pareto front by gated reward.

        Scores are re-gated in NumPy (sigmoid of the stored directional score
        over reward_alpha) and sorted descending; if the front is smaller than
        the batch, entries are drawn with replacement. An empty front falls
        back to fresh prior samples.
        """
        if not pareto_front:
            return self.base_model.sample_prior(batch_size, self.seq_length).to(self.device)

        seqs = list(pareto_front.keys())
        scores = np.stack([pareto_front[seq] for seq in seqs], axis=0)
        affinity = scores[:, 0]
        dir_score = scores[:, 1]
        gate = 1.0 / (1.0 + np.exp(-dir_score / max(self.reward_fn.reward_alpha, 1e-6)))
        gated = affinity * gate
        order = np.argsort(-gated)

        if len(order) >= batch_size:
            selected = [seqs[i] for i in order[:batch_size]]
        else:
            repeats = np.random.choice(order, size=batch_size, replace=True)
            selected = [seqs[i] for i in repeats]

        tokens = [pareto_tokens[seq] for seq in selected]
        return torch.stack(tokens, dim=0).to(self.device)

    def sample(self, batch_size):
        """Run MCTS for `mcts_iterations` rounds and return `batch_size` samples.

        Status 1 from selection skips expansion — presumably a terminal or
        fully expanded node; TODO confirm against Node.selectNode.
        """
        self.base_model.eval()
        root = self._init_root()
        pareto_front = {}
        pareto_tokens = {}

        for _ in range(self.mcts_iterations):
            leaf, status = self._select(root)
            if status == 1:
                continue
            pareto_front, pareto_tokens = self._expand(leaf, pareto_front, pareto_tokens)

        return self._select_from_pareto(pareto_front, pareto_tokens, batch_size)
511
+
512
+
513
+ def _logits_and_probs_from_tokens(
514
+ base_model,
515
+ token_ids: torch.Tensor,
516
+ attn_mask: torch.Tensor,
517
+ ) -> torch.Tensor:
518
+ logits = _logits_from_inputs(base_model, input_ids=token_ids, attn_mask=attn_mask)
519
+ log_probs = base_model.subs_parameterization(logits, token_ids)
520
+ return log_probs
521
+
522
+
523
def _logits_and_probs_from_one_hot(
    base_model,
    y_one_hot: torch.Tensor,
    token_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Differentiable forward pass: relax (one-hot or soft) token vectors into
    embedding space so gradients can flow through the backbone, then return
    the SUBS-parameterized log-probabilities."""
    embedding_matrix = base_model.backbone.model.roformer.embeddings.word_embeddings.weight
    soft_embeds = y_one_hot @ embedding_matrix
    logits = _logits_from_inputs(base_model, inputs_embeds=soft_embeds, attn_mask=attn_mask)
    return base_model.subs_parameterization(logits, token_ids)
534
+
535
+
536
def classifier_guidance(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Classifier-guidance baseline: tilt each reverse step by the reward gradient.

    At guided steps the current tokens are one-hot relaxed, pushed through the
    backbone differentiably, and the gated reward's gradient w.r.t. the one-hot
    input reweights the base transition probabilities. If the reward ever comes
    back without gradients, guidance is disabled for the remainder of sampling.

    Args:
        guidance_scale: multiplier on the reward gradient.
        guidance_steps: if set, only the last `guidance_steps` steps are guided.

    Returns:
        {"tokens": final token-id tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    # Time grid 1 -> eps with constant step dt.
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)
        if not use_guidance:
            # Plain (unguided) reverse step.
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue

        # Differentiable path: gradient flows into the one-hot relaxation.
        y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
        y_one_hot.requires_grad_(True)
        token_ids = x.detach()
        log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
        y_probs = log_probs.exp()
        token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
        reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
        if not reward.requires_grad:
            # Reward path turned out non-differentiable: warn once and fall
            # back to unguided sampling for all remaining steps.
            if guidance_enabled:
                logger.warning(
                    "Reward does not require grad; disabling gradient guidance for classifier_guidance."
                )
            guidance_enabled = False
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue
        reward.sum().backward()
        grad = y_one_hot.grad
        q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
        # Center the gradient on the mask column so "stay masked" is neutral,
        # then clamp to keep exp() finite.
        guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
        guidance = guidance.clamp(min=-50.0, max=50.0)
        q_guided = q_base * torch.exp(guidance)
        q_guided = _normalize_probs(q_guided)
        x = _sample_from_q(q_guided, x, mask_idx)

    return {"tokens": x}
590
+
591
+
592
def unguided_sampling(
    base_model,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """Sample sequences from the pre-trained diffusion model with no guidance.

    Runs the plain reverse process from t=1 down to t=eps and returns the
    final token ids as {"tokens": tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_token = base_model.mask_index
    tokens = base_model.sample_prior(batch_size, seq_length).to(device)
    attention_mask = torch.ones_like(tokens, device=device)
    schedule = torch.linspace(1, eps, num_steps + 1, device=device)
    step_size = torch.as_tensor((1 - eps) / num_steps, device=device)

    for idx in range(num_steps):
        t_batch = schedule[idx].repeat(batch_size)
        log_probs = _logits_and_probs_from_tokens(base_model, tokens, attention_mask)
        transition = _transition_probs_from_logits(log_probs, t_batch, step_size, mask_token)
        tokens = _sample_from_q(transition, tokens, mask_token)

    return {"tokens": tokens}
613
+
614
+
615
def sequential_monte_carlo(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    alpha: float,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """SMC baseline: unguided reverse steps with reward-based resampling.

    Each particle takes a plain reverse step; particles are then resampled in
    proportion to exp((r_next - r_current) / alpha), so `alpha` acts as a
    temperature (smaller alpha -> greedier selection).

    Returns:
        {"tokens": final token-id tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
        q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
        x_next = _sample_from_q(q_base, x, mask_idx)

        with torch.no_grad():
            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()
        # Incremental importance weights; clamp before normalizing to avoid inf.
        weights = torch.exp((r_next - r_current) / alpha).clamp_max(1e6)
        weights = _safe_resample_weights(weights)
        # Multinomial resampling with replacement; rewards follow particles.
        indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
        x = x_next[indices]
        r_current = r_next[indices]

    return {"tokens": x}
648
+
649
+
650
def twisted_diffusion_sampler(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """TDS baseline: gradient-twisted proposals plus SMC importance correction.

    Like `classifier_guidance`, the reward gradient tilts the transition
    probabilities; unlike it, each step also resamples particles with weights
    exp((r_next - r_current)/alpha + (log p_base - log p_guided)), correcting
    for the mismatch between the twisted proposal and the base kernel.
    Guidance is disabled permanently if the reward carries no gradients.

    Returns:
        {"tokens": final token-id tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)

        if use_guidance:
            # Differentiable one-hot relaxation for the reward gradient.
            y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
            y_one_hot.requires_grad_(True)
            token_ids = x.detach()
            log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
            y_probs = log_probs.exp()
            token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
            reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            if not reward.requires_grad:
                # Non-differentiable reward: warn once, then run untwisted.
                if guidance_enabled:
                    logger.warning(
                        "Reward does not require grad; disabling gradient guidance for twisted_diffusion_sampler."
                    )
                guidance_enabled = False
                q_guided = q_base
            else:
                reward.sum().backward()
                grad = y_one_hot.grad
                # Center on the mask column and clamp to keep exp() finite.
                guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
                guidance = guidance.clamp(min=-50.0, max=50.0)
                q_guided = q_base * torch.exp(guidance)
                q_guided = _normalize_probs(q_guided)
        else:
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            q_guided = q_base

        x_next = _sample_from_q(q_guided, x, mask_idx)
        with torch.no_grad():
            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()

        # Importance correction: reward increment plus proposal/base log-ratio.
        logp_guided = _sequence_logprob(q_guided, x_next, x, mask_idx)
        logp_base = _sequence_logprob(q_base, x_next, x, mask_idx)
        weights = torch.exp((r_next - r_current) / alpha + (logp_base - logp_guided)).clamp_max(1e6)
        weights = _safe_resample_weights(weights)
        indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
        x = x_next[indices]
        r_current = r_next[indices]

    return {"tokens": x}
717
+
718
+
719
def peptune_mctg_sampling(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    mcts_iterations: int,
    num_children: int,
    alpha: float,
    sample_prob_weight: float,
    invalid_penalty: float = 1.0,
    pareto_max_size: Optional[int] = None,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """PepTune baseline entry point: configure a `PepTuneSampler` and sample.

    Note: `alpha` is accepted for signature parity with the other baselines
    but is not consumed by the sampler itself.

    Returns:
        {"tokens": token-id tensor of shape (batch_size, seq_length)}.
    """
    sampler_config = dict(
        base_model=base_model,
        reward_fn=reward_fn,
        seq_length=seq_length,
        num_steps=num_steps,
        mcts_iterations=mcts_iterations,
        num_children=num_children,
        sample_prob_weight=sample_prob_weight,
        invalid_penalty=invalid_penalty,
        pareto_max_size=pareto_max_size,
        eps=eps,
    )
    mcts_sampler = PepTuneSampler(**sampler_config)
    return {"tokens": mcts_sampler.sample(batch_size=batch_size)}
baselines/run.sh ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch baseline peptide sampling (single- or multi-GPU) via
# sampling_setup.py, then (multi-GPU only) merge per-rank CSV shards.
set -euo pipefail

# Resolve this script's directory and the repository root relative to it,
# so the script works regardless of the caller's CWD.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"

# Positional arguments (all optional, with defaults):
#   $1 targets CSV, $2 baseline name, $3 device string, $4 output dir,
#   $5 number of GPUs, $6 torch.distributed master port.
CSV_PATH="${1:-To Be Added}"
BASELINE="${2:-unguided}"
DEVICE="${3:-cuda:4}"
OUTPUT_DIR="${4:-${SCRIPT_DIR}/outputs}"
NGPUS="${5:-1}"
MASTER_PORT="${6:-29500}"

if [ "$NGPUS" -gt 1 ]; then
    echo "Running multi-GPU inference with $NGPUS GPUs (master port: $MASTER_PORT)"
    # In distributed mode each rank selects its own CUDA device, so pass the
    # generic "cuda" string instead of a fixed index like cuda:4.
    LAUNCH_DEVICE="cuda"
    python -m torch.distributed.run \
        --nproc_per_node="$NGPUS" \
        --master_port="$MASTER_PORT" \
        "${SCRIPT_DIR}/sampling_setup.py" \
        --ckpt_path "${ROOT_DIR}/pretrained/peptune-pretrained.ckpt" \
        --device "${LAUNCH_DEVICE}" \
        --baseline "${BASELINE}" \
        --targets_csv "${CSV_PATH}" \
        --batch_size 8 \
        --num_steps 128 \
        --num_batches 1 \
        --output_dir "${OUTPUT_DIR}"

    # Each rank writes *_rank*.csv shards; merge them into single CSVs.
    # The quoted 'PY' delimiter prevents shell expansion inside the heredoc,
    # so values are passed to Python via environment variables instead.
    export OUTPUT_DIR BASELINE
    python - <<'PY'
import glob
import os
import pandas as pd

out_dir = os.environ["OUTPUT_DIR"]
baseline = os.environ["BASELINE"]

def merge(pattern, output_name):
    files = sorted(glob.glob(os.path.join(out_dir, pattern)))
    if not files:
        return
    dfs = []
    for path in files:
        try:
            dfs.append(pd.read_csv(path))
        except Exception as exc:
            print(f"[merge] skip {path}: {exc}")
    if not dfs:
        return
    merged = pd.concat(dfs, ignore_index=True)
    merged.to_csv(os.path.join(out_dir, output_name), index=False)
    print(f"[merge] wrote {output_name} from {len(files)} shards")

merge(f"{baseline}_samples_rank*.csv", f"{baseline}_samples.csv")
merge("batch_times_rank*.csv", "batch_times.csv")
merge(f"{baseline}_metrics_rank*.csv", f"{baseline}_metrics.csv")
PY
    exit 0
fi

# Single-device path: run sampling directly on the requested device.
python "${SCRIPT_DIR}/sampling_setup.py" \
    --ckpt_path "${ROOT_DIR}/pretrained/peptune-pretrained.ckpt" \
    --device "${DEVICE}" \
    --baseline "${BASELINE}" \
    --targets_csv "${CSV_PATH}" \
    --batch_size 8 \
    --num_steps 128 \
    --num_batches 1 \
    --output_dir "${OUTPUT_DIR}"

# Usage examples (paths to be filled in):
# ./run.sh To Be Added peptune cuda:0 To Be Added
# ./run.sh To Be Added peptune cuda To Be Added 4 29501
# ./run.sh To Be Added tds cuda:1 To Be Added
# ./run.sh To Be Added smc cuda:2 To Be Added
# ./run.sh To Be Added cg cuda:3 To Be Added
# ./run.sh To Be Added unguided cuda:4 To Be Added
baselines/run_mcts_tr2d2.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import sys
5
+ from types import SimpleNamespace
6
+ from typing import Any, Dict, List, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
14
+ if ROOT_DIR not in sys.path:
15
+ sys.path.insert(0, ROOT_DIR)
16
+
17
+ from diffusion import Diffusion
18
+ from configs.finetune_config import (
19
+ DiffusionConfig,
20
+ RoFormerConfig,
21
+ NoiseConfig,
22
+ TrainingConfig,
23
+ SamplingConfig,
24
+ EvalConfig,
25
+ OptimConfig,
26
+ MCTSConfig,
27
+ )
28
+ from finetune_utils import load_tokenizer
29
+ from finetune_distributed_utils import setup_distributed, cleanup_distributed, is_main_process
30
+ from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
31
+ from td3b.direction_oracle import DirectionalOracle
32
+ from finetune_multi_target_tr2d2_ddp import TR2D2GatedReward, TargetDataset, create_tr2d2_mcts
33
+ from utils.app import PeptideAnalyzer
34
+
35
+
36
+ def _load_checkpoint(ckpt_path: str, device: torch.device) -> Dict[str, Any]:
37
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
38
+ if not isinstance(ckpt, dict):
39
+ raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
40
+ return ckpt
41
+
42
+
43
+ def _extract_state_and_config(ckpt: Dict[str, Any]) -> Dict[str, Any]:
44
+ state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
45
+ config = ckpt.get("config") or {}
46
+ return {"state_dict": state_dict, "config": config}
47
+
48
+
49
+ def _build_args(cfg: Dict[str, Any], cli: argparse.Namespace) -> argparse.Namespace:
50
+ defaults = {
51
+ "base_path": "To Be Added",
52
+ "seq_length": 200,
53
+ "sampling_eps": 1e-3,
54
+ "total_num_steps": 128,
55
+ "alpha": 0.1,
56
+ "hidden_dim": 768,
57
+ "num_layers": 8,
58
+ "num_heads": 8,
59
+ "min_affinity_threshold": 0.0,
60
+ "sigmoid_temperature": 0.1,
61
+ "val_samples_per_target": 8,
62
+ "direction_oracle_esm_name": "facebook/esm2_t33_650M_UR50D",
63
+ "direction_oracle_esm_cache_dir": None,
64
+ "direction_oracle_esm_local_files_only": False,
65
+ "direction_oracle_max_ligand_length": 768,
66
+ "direction_oracle_max_protein_length": 1024,
67
+ "direction_oracle_d_model": 256,
68
+ "direction_oracle_n_heads": 4,
69
+ "direction_oracle_n_self_attn_layers": 1,
70
+ "direction_oracle_n_bmca_layers": 2,
71
+ "direction_oracle_dropout": 0.3,
72
+ "num_iter": 20,
73
+ "num_children": 24,
74
+ "buffer_size": 32,
75
+ "exploration": 1.0,
76
+ }
77
+
78
+ merged = dict(defaults)
79
+ merged.update(cfg or {})
80
+
81
+ if cli.base_path is not None:
82
+ merged["base_path"] = cli.base_path
83
+ if cli.val_csv is not None:
84
+ merged["val_csv"] = cli.val_csv
85
+ if cli.save_path is not None:
86
+ merged["save_path"] = cli.save_path
87
+ if cli.device is not None:
88
+ merged["device"] = cli.device
89
+ if cli.val_samples_per_target is not None:
90
+ merged["val_samples_per_target"] = cli.val_samples_per_target
91
+ if cli.seq_length is not None:
92
+ merged["seq_length"] = cli.seq_length
93
+ if cli.total_num_steps is not None:
94
+ merged["total_num_steps"] = cli.total_num_steps
95
+ if cli.sampling_eps is not None:
96
+ merged["sampling_eps"] = cli.sampling_eps
97
+ if cli.alpha is not None:
98
+ merged["alpha"] = cli.alpha
99
+ if cli.num_iter is not None:
100
+ merged["num_iter"] = cli.num_iter
101
+ if cli.num_children is not None:
102
+ merged["num_children"] = cli.num_children
103
+ if cli.buffer_size is not None:
104
+ merged["buffer_size"] = cli.buffer_size
105
+ if cli.exploration is not None:
106
+ merged["exploration"] = cli.exploration
107
+ if cli.max_sequence_length is not None:
108
+ merged["max_sequence_length"] = cli.max_sequence_length
109
+
110
+ args = SimpleNamespace(**merged)
111
+
112
+ base_tr2d2_path = os.path.join(args.base_path, "tr2d2-pep")
113
+ if not getattr(args, "direction_oracle_ckpt", None):
114
+ args.direction_oracle_ckpt = os.path.join(base_tr2d2_path, "direction_oracle.pt")
115
+ if not getattr(args, "direction_oracle_tr2d2_checkpoint", None):
116
+ args.direction_oracle_tr2d2_checkpoint = os.path.join(
117
+ base_tr2d2_path, "pretrained", "peptune-pretrained.ckpt"
118
+ )
119
+ if not getattr(args, "direction_oracle_tokenizer_vocab", None):
120
+ args.direction_oracle_tokenizer_vocab = os.path.join(
121
+ base_tr2d2_path, "tokenizer", "new_vocab.txt"
122
+ )
123
+ if not getattr(args, "direction_oracle_tokenizer_splits", None):
124
+ args.direction_oracle_tokenizer_splits = os.path.join(
125
+ base_tr2d2_path, "tokenizer", "new_splits.txt"
126
+ )
127
+
128
+ if not getattr(args, "save_path", None):
129
+ args.save_path = os.path.join(base_tr2d2_path, "baselines", "outputs_mcts_tr2d2")
130
+ os.makedirs(args.save_path, exist_ok=True)
131
+ return args
132
+
133
+
134
def _build_model(args: argparse.Namespace, state_dict: Dict[str, Any], device: torch.device) -> Diffusion:
    """Instantiate the Diffusion policy model and load checkpoint weights.

    Loading is non-strict; mismatched key counts are reported but tolerated.
    The model is returned in eval mode on ``device``.
    """
    diffusion_cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads,
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps,
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=getattr(args, "learning_rate", 3e-4)),
        mcts=MCTSConfig(),
    )

    tok = load_tokenizer(args.base_path)
    net = Diffusion(
        config=diffusion_cfg,
        tokenizer=tok,
        device=device,
    ).to(device)
    result = net.load_state_dict(state_dict, strict=False)
    # Non-strict load: surface how many keys did not line up, if any.
    for label, keys in (("Missing", result.missing_keys), ("Unexpected", result.unexpected_keys)):
        if keys:
            print(f"[load] {label} keys: {len(keys)}")
    net.eval()
    return net
165
+
166
+
167
def _build_oracle(args: argparse.Namespace, device: torch.device) -> DirectionalOracle:
    """Construct the DirectionalOracle from args settings, in eval mode."""
    oracle_kwargs = {
        "model_ckpt": args.direction_oracle_ckpt,
        "tr2d2_checkpoint": args.direction_oracle_tr2d2_checkpoint,
        "tokenizer_vocab": args.direction_oracle_tokenizer_vocab,
        "tokenizer_splits": args.direction_oracle_tokenizer_splits,
        "esm_name": args.direction_oracle_esm_name,
        "d_model": args.direction_oracle_d_model,
        "n_heads": args.direction_oracle_n_heads,
        "n_self_attn_layers": args.direction_oracle_n_self_attn_layers,
        "n_bmca_layers": args.direction_oracle_n_bmca_layers,
        "dropout": args.direction_oracle_dropout,
        "max_ligand_length": args.direction_oracle_max_ligand_length,
        "max_protein_length": args.direction_oracle_max_protein_length,
        "device": device,
        "esm_cache_dir": args.direction_oracle_esm_cache_dir,
        "esm_local_files_only": args.direction_oracle_esm_local_files_only,
    }
    direction_oracle = DirectionalOracle(**oracle_kwargs)
    direction_oracle.eval()
    return direction_oracle
187
+
188
+
189
+ def _compute_direction_accuracy(directions: np.ndarray, d_star: float) -> np.ndarray:
190
+ if directions.size == 0:
191
+ return directions
192
+ acc = np.full(directions.shape, np.nan, dtype=np.float32)
193
+ valid = np.isfinite(directions)
194
+ if not valid.any():
195
+ return acc
196
+ if d_star > 0:
197
+ acc[valid] = (directions[valid] >= 0.5).astype(np.float32)
198
+ else:
199
+ acc[valid] = (directions[valid] < 0.5).astype(np.float32)
200
+ return acc
201
+
202
+
203
+ def _nanmean(values: np.ndarray) -> float:
204
+ if values.size == 0:
205
+ return 0.0
206
+ finite = values[np.isfinite(values)]
207
+ return float(np.mean(finite)) if finite.size else 0.0
208
+
209
+
210
+ def _nanstd(values: np.ndarray) -> float:
211
+ if values.size == 0:
212
+ return 0.0
213
+ finite = values[np.isfinite(values)]
214
+ return float(np.std(finite)) if finite.size else 0.0
215
+
216
+
217
def main() -> None:
    """Evaluate a fine-tuned TR2-D2 checkpoint with MCTS-guided generation.

    For every target protein in the validation CSV and for both directions
    (agonist d*=+1, antagonist d*=-1), runs MCTS sampling, scores the
    collected sequences with the gated reward, and writes one CSV of
    per-sequence records. Supports torchrun-style multi-GPU sharding via
    LOCAL_RANK/WORLD_SIZE environment variables.
    """
    parser = argparse.ArgumentParser(description="MCTS-based TR2-D2 evaluation.")
    parser.add_argument("--ckpt_path", required=True, help="Path to finetuned checkpoint (.ckpt)")
    parser.add_argument("--val_csv", required=True, help="Validation CSV path")
    parser.add_argument("--device", default="cuda", help="Device string (e.g., cuda:0 or cpu)")
    parser.add_argument("--base_path", default=None, help="Base path for TR2-D2")
    parser.add_argument("--save_path", default=None, help="Output directory for evaluation CSV")
    parser.add_argument("--epoch", type=int, default=0, help="Epoch number to label outputs")
    parser.add_argument("--val_samples_per_target", type=int, default=None, help="Samples per target (unused by MCTS)")
    parser.add_argument("--seq_length", type=int, default=None, help="Fallback sequence length")
    parser.add_argument("--total_num_steps", type=int, default=None, help="Diffusion steps")
    parser.add_argument("--sampling_eps", type=float, default=None, help="Sampling epsilon")
    parser.add_argument("--alpha", type=float, default=None, help="MCTS alpha temperature")
    parser.add_argument("--num_iter", type=int, default=None, help="MCTS iterations")
    parser.add_argument("--num_children", type=int, default=None, help="MCTS children per expand")
    parser.add_argument("--buffer_size", type=int, default=None, help="MCTS buffer size")
    parser.add_argument("--exploration", type=float, default=None, help="MCTS exploration constant")
    parser.add_argument("--max_sequence_length", type=int, default=1035)
    parser.add_argument("--max_attempts", type=int, default=3, help="Max MCTS attempts to reach target count")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    cli_args = parser.parse_args()

    # Distributed context comes from torchrun's environment variables.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        setup_distributed(rank, world_size)
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(cli_args.device)

    # Offset the seed by rank so each shard samples differently but reproducibly.
    if cli_args.seed is not None:
        torch.manual_seed(cli_args.seed + rank)
        np.random.seed(cli_args.seed + rank)

    ckpt = _load_checkpoint(cli_args.ckpt_path, device)
    payload = _extract_state_and_config(ckpt)
    args = _build_args(payload["config"], cli_args)

    tokenizer = load_tokenizer(args.base_path)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer)

    policy_model = _build_model(args, payload["state_dict"], device)

    # Binding-affinity scorer shares the policy model's backbone for embeddings.
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone,
    )

    directional_oracle = _build_oracle(args, device)
    analyzer = PeptideAnalyzer()

    # Strided sharding: rank r processes targets r, r+W, r+2W, ...
    val_targets = val_dataset.get_all_targets()
    if world_size > 1:
        my_targets = val_targets[rank::world_size]
    else:
        my_targets = val_targets

    records: List[Dict[str, Any]] = []
    # Cache encoded target proteins — the same protein is reused for both directions.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    with torch.no_grad():
        for target_seq in my_targets:
            target_tokens = protein_token_cache.get(target_seq)
            if target_tokens is None:
                target_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_tokens

            for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)
                if target_length > args.max_sequence_length:
                    target_length = args.max_sequence_length

                # Temporarily set the per-target sequence length; restored below.
                original_seq_length = args.seq_length
                args.seq_length = int(target_length)

                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)
                reward_model = TR2D2GatedReward(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    temperature=args.sigmoid_temperature,
                )

                mcts = create_tr2d2_mcts(
                    args=args,
                    policy_model=policy_model,
                    reward_function=reward_model,
                    buffer_size=args.buffer_size,
                )

                target_count = int(args.val_samples_per_target)
                collected_sequences: List[str] = []
                attempt_valid_fractions: List[float] = []

                # Retry MCTS (up to max_attempts) until enough sequences collected;
                # a failed attempt is logged and contributes an empty batch.
                for attempt in range(max(cli_args.max_attempts, 1)):
                    try:
                        _, _, _, _, sequences = mcts.forward(resetTree=True)
                    except Exception as exc:
                        print(f"[mcts] failed for target={target_seq[:12]} dir={direction_name}: {exc}")
                        sequences = []

                    attempt_valid = float(np.mean(mcts.valid_fraction_log)) if getattr(mcts, "valid_fraction_log", None) else 0.0
                    attempt_valid_fractions.append(attempt_valid)

                    if sequences:
                        collected_sequences.extend(sequences)

                    if len(collected_sequences) >= target_count:
                        break

                args.seq_length = original_seq_length

                # Run-level validity: mean of per-attempt valid fractions.
                valid_fraction = _nanmean(np.asarray(attempt_valid_fractions, dtype=np.float32))

                # Emit a placeholder record when generation produced nothing,
                # so every (target, direction) pair appears in the output.
                if not collected_sequences:
                    records.append(
                        {
                            "target": target_seq[:20],
                            "sequence": "",
                            "target_direction": d_star,
                            "is_valid": False,
                            "valid_fraction": valid_fraction,
                            "affinity": np.nan,
                            "gated_reward": np.nan,
                            "direction_oracle": np.nan,
                            "consistency_reward": np.nan,
                            "direction_accuracy": np.nan,
                            "success_rate": np.nan,
                        }
                    )
                    continue

                if len(collected_sequences) > target_count:
                    collected_sequences = collected_sequences[:target_count]

                # NOTE(review): `confidences` is computed but not recorded.
                gated_rewards, affinities, confidences, directions = reward_model.reward_fn.compute_gated_reward(collected_sequences)
                direction_accuracy = _compute_direction_accuracy(directions, d_star)
                consistency = d_star * (directions - 0.5)
                # Per-sample accuracy scaled by the run-level valid fraction.
                success_rate = direction_accuracy * valid_fraction

                valid_mask = np.array([analyzer.is_peptide(seq) for seq in collected_sequences], dtype=bool)

                for idx, seq in enumerate(collected_sequences):
                    records.append(
                        {
                            "target": target_seq[:20],
                            "sequence": seq,
                            "target_direction": d_star,
                            "is_valid": bool(valid_mask[idx]) if valid_mask.size else False,
                            "valid_fraction": valid_fraction,
                            "affinity": float(affinities[idx]) if len(affinities) else np.nan,
                            "gated_reward": float(gated_rewards[idx]) if len(gated_rewards) else np.nan,
                            "direction_oracle": float(directions[idx]) if len(directions) else np.nan,
                            "consistency_reward": float(consistency[idx]) if len(consistency) else np.nan,
                            "direction_accuracy": float(direction_accuracy[idx]) if len(direction_accuracy) else np.nan,
                            "success_rate": float(success_rate[idx]) if len(success_rate) else np.nan,
                        }
                    )

    # Gather every rank's records onto the main process; non-main ranks exit here.
    if world_size > 1:
        gathered: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, records)
        if is_main_process():
            records = [item for sub in gathered for item in sub]
        else:
            cleanup_distributed()
            return

    if is_main_process():
        df = pd.DataFrame(records)
        output_path = os.path.join(args.save_path, f"mcts_validation_epoch_{cli_args.epoch}.csv")
        df.to_csv(output_path, index=False)
        print(f"MCTS validation sequences saved to {output_path}")

        affinities = df["affinity"].to_numpy(dtype=np.float32)
        gated_rewards = df["gated_reward"].to_numpy(dtype=np.float32)
        directions = df["direction_oracle"].to_numpy(dtype=np.float32)
        target_directions = df["target_direction"].to_numpy(dtype=np.float32)
        direction_correct = df["direction_accuracy"].to_numpy(dtype=np.float32)
        valid_fractions = df["valid_fraction"].to_numpy(dtype=np.float32)

        pos_mask = target_directions == 1.0
        neg_mask = target_directions == -1.0

        print("MCTS validation summary")
        print(f" Affinity (d*=1): {_nanmean(affinities[pos_mask]):.4f} ± {_nanstd(affinities[pos_mask]):.4f}")
        print(f" Affinity (d*=-1): {_nanmean(affinities[neg_mask]):.4f} ± {_nanstd(affinities[neg_mask]):.4f}")
        print(f" Direction Accuracy (d*=1): {_nanmean(direction_correct[pos_mask]):.4f} ± {_nanstd(direction_correct[pos_mask]):.4f}")
        print(f" Direction Accuracy (d*=-1): {_nanmean(direction_correct[neg_mask]):.4f} ± {_nanstd(direction_correct[neg_mask]):.4f}")
        print(f" Gated Reward (overall): {_nanmean(gated_rewards):.4f} ± {_nanstd(gated_rewards):.4f}")
        print(f" Valid Fraction: {_nanmean(valid_fractions):.4f} ± {_nanstd(valid_fractions):.4f}")

    if world_size > 1:
        cleanup_distributed()


if __name__ == "__main__":
    main()
baselines/run_validation_td3b.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import sys
5
+ from types import SimpleNamespace
6
+ from typing import Any, Dict, List, Tuple
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
14
+ if ROOT_DIR not in sys.path:
15
+ sys.path.insert(0, ROOT_DIR)
16
+
17
+ from diffusion import Diffusion
18
+ from configs.finetune_config import (
19
+ DiffusionConfig,
20
+ RoFormerConfig,
21
+ NoiseConfig,
22
+ TrainingConfig,
23
+ SamplingConfig,
24
+ EvalConfig,
25
+ OptimConfig,
26
+ MCTSConfig,
27
+ )
28
+ from finetune_utils import load_tokenizer, create_reward_function
29
+ from finetune_multi_target import TargetDataset
30
+ from distributed_utils import setup_distributed, cleanup_distributed, is_main_process
31
+ from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
32
+ from td3b.direction_oracle import DirectionalOracle
33
+ from utils.app import PeptideAnalyzer
34
+
35
+
36
+ def _load_checkpoint(ckpt_path: str, device: torch.device) -> Dict[str, Any]:
37
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
38
+ if not isinstance(ckpt, dict):
39
+ raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
40
+ return ckpt
41
+
42
+
43
+ def _extract_state_and_config(ckpt: Dict[str, Any]) -> Dict[str, Any]:
44
+ state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
45
+ config = ckpt.get("config") or {}
46
+ return {"state_dict": state_dict, "config": config}
47
+
48
+
49
+ def _build_args(cfg: Dict[str, Any], cli: argparse.Namespace) -> argparse.Namespace:
50
+ defaults = {
51
+ "base_path": "To Be Added",
52
+ "seq_length": 200,
53
+ "sampling_eps": 1e-3,
54
+ "total_num_steps": 128,
55
+ "alpha": 0.1,
56
+ "hidden_dim": 768,
57
+ "num_layers": 8,
58
+ "num_heads": 8,
59
+ "min_affinity_threshold": 0.0,
60
+ "sigmoid_temperature": 0.1,
61
+ "val_samples_per_target": 8,
62
+ "direction_oracle_esm_name": "facebook/esm2_t33_650M_UR50D",
63
+ "direction_oracle_esm_cache_dir": None,
64
+ "direction_oracle_esm_local_files_only": False,
65
+ "direction_oracle_max_ligand_length": 768,
66
+ "direction_oracle_max_protein_length": 1024,
67
+ "direction_oracle_d_model": 256,
68
+ "direction_oracle_n_heads": 4,
69
+ "direction_oracle_n_self_attn_layers": 1,
70
+ "direction_oracle_n_bmca_layers": 2,
71
+ "direction_oracle_dropout": 0.3,
72
+ }
73
+
74
+ merged = dict(defaults)
75
+ merged.update(cfg or {})
76
+
77
+ if cli.base_path is not None:
78
+ merged["base_path"] = cli.base_path
79
+ if cli.val_csv is not None:
80
+ merged["val_csv"] = cli.val_csv
81
+ if cli.save_path is not None:
82
+ merged["save_path"] = cli.save_path
83
+ if cli.device is not None:
84
+ merged["device"] = cli.device
85
+ if cli.val_samples_per_target is not None:
86
+ merged["val_samples_per_target"] = cli.val_samples_per_target
87
+ if getattr(cli, "num_pool", None) is not None:
88
+ merged["num_pool"] = cli.num_pool
89
+ if cli.seq_length is not None:
90
+ merged["seq_length"] = cli.seq_length
91
+ if cli.total_num_steps is not None:
92
+ merged["total_num_steps"] = cli.total_num_steps
93
+ if cli.sampling_eps is not None:
94
+ merged["sampling_eps"] = cli.sampling_eps
95
+ if cli.seed is not None:
96
+ merged["seed"] = cli.seed
97
+
98
+ args = SimpleNamespace(**merged)
99
+
100
+ base_tr2d2_path = os.path.join(args.base_path, "tr2d2-pep")
101
+ if not getattr(args, "direction_oracle_ckpt", None):
102
+ args.direction_oracle_ckpt = os.path.join(base_tr2d2_path, "direction_oracle.pt")
103
+ if not getattr(args, "direction_oracle_tr2d2_checkpoint", None):
104
+ args.direction_oracle_tr2d2_checkpoint = os.path.join(
105
+ base_tr2d2_path, "pretrained", "peptune-pretrained.ckpt"
106
+ )
107
+ if not getattr(args, "direction_oracle_tokenizer_vocab", None):
108
+ args.direction_oracle_tokenizer_vocab = os.path.join(
109
+ base_tr2d2_path, "tokenizer", "new_vocab.txt"
110
+ )
111
+ if not getattr(args, "direction_oracle_tokenizer_splits", None):
112
+ args.direction_oracle_tokenizer_splits = os.path.join(
113
+ base_tr2d2_path, "tokenizer", "new_splits.txt"
114
+ )
115
+
116
+ if not getattr(args, "save_path", None):
117
+ args.save_path = os.path.join(base_tr2d2_path, "results", "validation_runs")
118
+
119
+ os.makedirs(args.save_path, exist_ok=True)
120
+ return args
121
+
122
+
123
def _build_model(args: argparse.Namespace, state_dict: Dict[str, Any], device: torch.device) -> Diffusion:
    """Build the Diffusion model from args and load checkpoint weights.

    Weights are loaded non-strictly; key mismatches are printed but not
    fatal. The returned model is on ``device`` and in eval mode.
    """
    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads,
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps,
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=getattr(args, "learning_rate", 3e-4)),
        mcts=MCTSConfig(),
    )

    tokenizer = load_tokenizer(args.base_path)
    diffusion_model = Diffusion(
        config=cfg,
        tokenizer=tokenizer,
        device=device,
    ).to(device)
    outcome = diffusion_model.load_state_dict(state_dict, strict=False)
    if outcome.missing_keys:
        print(f"[load] Missing keys: {len(outcome.missing_keys)}")
    if outcome.unexpected_keys:
        print(f"[load] Unexpected keys: {len(outcome.unexpected_keys)}")
    diffusion_model.eval()
    return diffusion_model
+
155
+
156
def _build_oracle(args: argparse.Namespace, device: torch.device) -> DirectionalOracle:
    """Create a DirectionalOracle from args-provided settings (eval mode)."""
    settings = {
        "model_ckpt": args.direction_oracle_ckpt,
        "tr2d2_checkpoint": args.direction_oracle_tr2d2_checkpoint,
        "tokenizer_vocab": args.direction_oracle_tokenizer_vocab,
        "tokenizer_splits": args.direction_oracle_tokenizer_splits,
        "esm_name": args.direction_oracle_esm_name,
        "d_model": args.direction_oracle_d_model,
        "n_heads": args.direction_oracle_n_heads,
        "n_self_attn_layers": args.direction_oracle_n_self_attn_layers,
        "n_bmca_layers": args.direction_oracle_n_bmca_layers,
        "dropout": args.direction_oracle_dropout,
        "max_ligand_length": args.direction_oracle_max_ligand_length,
        "max_protein_length": args.direction_oracle_max_protein_length,
        "device": device,
        "esm_cache_dir": args.direction_oracle_esm_cache_dir,
        "esm_local_files_only": args.direction_oracle_esm_local_files_only,
    }
    built = DirectionalOracle(**settings)
    built.eval()
    return built
176
+
177
+
178
def _sample_sequences(
    model: Diffusion,
    batch_size: int,
    seq_length: int,
    total_num_steps: int,
    sampling_eps: float,
) -> torch.Tensor:
    """Run the reverse diffusion chain and return sampled token tensors.

    Starts from the model's prior (fully masked) and applies
    ``total_num_steps`` reverse steps on a time grid from 1 down to
    ``sampling_eps``.
    """
    model.backbone.eval()
    model.noise.eval()

    # Prior sample: a fully-masked batch of shape (batch_size, seq_length).
    x_rollout = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)

    # total_num_steps+1 grid points, stepped uniformly from 1 to sampling_eps.
    timesteps = torch.linspace(1, sampling_eps, total_num_steps + 1, device=model.device)
    dt = torch.tensor((1 - sampling_eps) / total_num_steps, device=model.device)

    for i in range(total_num_steps):
        t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=model.device)
        _, x_next = model.single_reverse_step(x_rollout, t=t, dt=dt)
        x_rollout = x_next.to(model.device)

    # NOTE(review): this cleanup is read as running once AFTER the loop,
    # reusing the loop's final `t` to force-decode any tokens still masked —
    # confirm against the original (flattened) indentation.
    if (x_rollout == model.mask_index).any().item():
        _, x_next = model.single_noise_removal(x_rollout, t=t, dt=dt)
        x_rollout = x_next.to(model.device)

    return x_rollout
203
+
204
+
205
def _score_sequences(reward_model, sequences: List[str]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score sequences with ``reward_model``, tolerating per-item failures.

    Returns four float32 arrays aligned with ``sequences``:
    (total_rewards, affinity, directions, confidence). Empty input yields
    four empty arrays. The reward model may return either a bare array or a
    ``(rewards, info)`` tuple whose ``info`` dict can carry "affinities",
    "confidences", and "directions".
    """
    if not sequences:
        empty = np.array([], dtype=np.float32)
        return empty, empty, empty, empty

    try:
        # Fast path: score the whole batch in one call.
        result = reward_model(sequences)
        if isinstance(result, tuple):
            total_rewards, info = result
            # Missing info entries fall back to the rewards themselves /
            # all-ones confidence / all-zero directions.
            affinity = np.asarray(info.get("affinities", total_rewards), dtype=np.float32)
            confidence = np.asarray(info.get("confidences", np.ones_like(affinity)), dtype=np.float32)
            directions = np.asarray(info.get("directions", np.zeros_like(affinity)), dtype=np.float32)
        else:
            total_rewards = np.asarray(result, dtype=np.float32)
            # 2-D reward matrices use the first column as affinity.
            if total_rewards.ndim > 1:
                affinity = total_rewards[:, 0]
            else:
                affinity = total_rewards
            confidence = np.ones_like(affinity, dtype=np.float32)
            directions = np.zeros_like(affinity, dtype=np.float32)
        return np.asarray(total_rewards, dtype=np.float32), affinity, directions, confidence
    except Exception:
        # Fallback: batch scoring failed — retry one sequence at a time so a
        # single bad sequence only NaNs its own row (deliberate best-effort).
        total_rewards = np.full(len(sequences), np.nan, dtype=np.float32)
        affinity = np.full(len(sequences), np.nan, dtype=np.float32)
        directions = np.full(len(sequences), np.nan, dtype=np.float32)
        confidence = np.full(len(sequences), np.nan, dtype=np.float32)
        for idx, seq in enumerate(sequences):
            try:
                result = reward_model([seq])
                if isinstance(result, tuple):
                    rewards, info = result
                    total_rewards[idx] = float(np.asarray(rewards)[0])
                    affinity[idx] = float(np.asarray(info.get("affinities", rewards))[0])
                    confidence[idx] = float(np.asarray(info.get("confidences", [np.nan]))[0])
                    directions[idx] = float(np.asarray(info.get("directions", [np.nan]))[0])
                else:
                    reward = np.asarray(result)
                    total_rewards[idx] = float(reward[0]) if reward.size else np.nan
                    affinity[idx] = total_rewards[idx]
                    # NOTE(review): directions/confidence stay NaN on this
                    # branch — presumably intentional for bare-array models.
            except Exception:
                # Leave this row as NaN and keep scoring the rest.
                continue
        return total_rewards, affinity, directions, confidence
247
+
248
+
249
+ def _compute_direction_accuracy(directions: np.ndarray, d_star: float) -> np.ndarray:
250
+ if directions.size == 0:
251
+ return directions
252
+ acc = np.full(directions.shape, np.nan, dtype=np.float32)
253
+ valid = np.isfinite(directions)
254
+ if not valid.any():
255
+ return acc
256
+ if d_star > 0:
257
+ acc[valid] = (directions[valid] >= 0.5).astype(np.float32)
258
+ else:
259
+ acc[valid] = (directions[valid] < 0.5).astype(np.float32)
260
+ return acc
261
+
262
+
263
+ def _nanmean(values: np.ndarray) -> float:
264
+ return float(np.nanmean(values)) if values.size else float("nan")
265
+
266
+
267
+ def _nanstd(values: np.ndarray) -> float:
268
+ return float(np.nanstd(values)) if values.size else float("nan")
269
+
270
+
271
def main() -> None:
    """Validate a TD3B checkpoint: sample binders per target/direction, score
    them, optionally reward-resample, and write per-sequence CSVs plus a
    printed summary.  Supports single-process and torchrun multi-GPU runs
    (targets are sharded across ranks and results all-gathered to rank 0).
    """
    parser = argparse.ArgumentParser(description="Run TD3B validation from a saved checkpoint.")
    parser.add_argument("--ckpt_path", required=True, help="Path to saved checkpoint (.ckpt)")
    parser.add_argument("--val_csv", required=True, help="Validation CSV path")
    parser.add_argument("--device", default="cuda", help="Device string (e.g., cuda:0 or cpu)")
    parser.add_argument("--base_path", default=None, help="Base path for TR2-D2")
    parser.add_argument("--save_path", default=None, help="Output directory for validation CSV")
    parser.add_argument("--epoch", type=int, default=0, help="Epoch number to label outputs")
    parser.add_argument("--val_samples_per_target", type=int, default=None, help="Samples per target")
    parser.add_argument("--num_pool", type=int, default=None,
                        help="Number of candidate sequences to sample before resampling")
    parser.add_argument("--seq_length", type=int, default=None, help="Fallback sequence length")
    parser.add_argument("--total_num_steps", type=int, default=None, help="Diffusion steps")
    parser.add_argument("--sampling_eps", type=float, default=None, help="Sampling epsilon")
    parser.add_argument("--seed", type=int, default=None, help="Base random seed")
    parser.add_argument("--no_resample", action="store_true", help="Disable reward-weighted resampling")
    parser.add_argument("--resample_without_replacement", action="store_true",
                        help="Resample without replacement when possible")
    parser.add_argument("--resample_alpha", type=float, default=None,
                        help="Override alpha for resampling weights")
    cli_args = parser.parse_args()

    # torchrun sets LOCAL_RANK/WORLD_SIZE; defaults give a single-process run.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        setup_distributed(rank, world_size)
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(cli_args.device)

    if cli_args.seed is not None:
        # Offset the seed by rank so each shard samples different sequences.
        torch.manual_seed(cli_args.seed + rank)
        np.random.seed(cli_args.seed + rank)

    # Restore model weights + training-time config; CLI options override config.
    ckpt = _load_checkpoint(cli_args.ckpt_path, device)
    payload = _extract_state_and_config(ckpt)
    args = _build_args(payload["config"], cli_args)

    tokenizer = load_tokenizer(args.base_path)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer)

    policy_model = _build_model(args, payload["state_dict"], device)

    # Shared affinity predictor; per-target views are created inside the loop.
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone,
    )

    directional_oracle = _build_oracle(args, device)
    analyzer = PeptideAnalyzer()
    # Cache oracle protein encodings: each target is scored for both directions.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    resample_enabled = not cli_args.no_resample
    resample_with_replacement = not cli_args.resample_without_replacement
    resample_alpha = cli_args.resample_alpha if cli_args.resample_alpha is not None else args.alpha

    all_targets = val_dataset.get_all_targets()
    if world_size > 1:
        # Round-robin shard of targets across ranks.
        my_targets = all_targets[rank::world_size]
    else:
        my_targets = all_targets

    records: List[Dict[str, Any]] = []
    resampled_records: List[Dict[str, Any]] = []
    # Per-direction accumulators for the summary statistics.
    resampled_affinity_pos: List[float] = []
    resampled_affinity_neg: List[float] = []
    resampled_acc_pos: List[float] = []
    resampled_acc_neg: List[float] = []
    resampled_gated_rewards: List[float] = []

    with torch.no_grad():
        for target_seq in my_targets:
            target_protein_tokens = protein_token_cache.get(target_seq)
            if target_protein_tokens is None:
                target_protein_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_protein_tokens

            # Evaluate both intended behaviors for every target.
            for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)
                # Cap at the pretrained model's max position embeddings.
                max_len = 1035
                if target_length > max_len:
                    target_length = max_len

                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)
                reward_model = create_reward_function(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_protein_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    use_confidence_weighting=True,
                    temperature=args.sigmoid_temperature,
                )

                # Optionally over-sample a candidate pool that resampling
                # later reduces to val_samples_per_target.
                pool_size = args.val_samples_per_target
                if getattr(args, "num_pool", None) is not None:
                    pool_size = int(args.num_pool)
                    if pool_size < args.val_samples_per_target:
                        print(
                            f"[warn] num_pool ({pool_size}) < val_samples_per_target "
                            f"({args.val_samples_per_target}); using val_samples_per_target."
                        )
                        pool_size = args.val_samples_per_target

                x_eval = _sample_sequences(
                    policy_model,
                    batch_size=pool_size,
                    seq_length=target_length,
                    total_num_steps=args.total_num_steps,
                    sampling_eps=args.sampling_eps,
                )

                sequences = tokenizer.batch_decode(x_eval)
                valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences], dtype=bool)
                valid_fraction = float(valid_mask.mean()) if valid_mask.size else 0.0

                gated_rewards, affinities, directions, confidences = _score_sequences(reward_model, sequences)
                direction_accuracy = _compute_direction_accuracy(directions, d_star)
                # Signed agreement between oracle score and requested direction.
                consistency = d_star * (directions - 0.5)
                success_rate = direction_accuracy * valid_fraction

                if resample_enabled:
                    # Reward-weighted (softmax with temperature alpha) selection
                    # over candidates that received a finite reward.
                    finite_rewards = np.isfinite(gated_rewards)
                    if np.any(finite_rewards):
                        rewards_t = torch.as_tensor(gated_rewards[finite_rewards], device=device)
                        alpha = max(float(resample_alpha), 1e-6)
                        weights = torch.softmax(rewards_t / alpha, dim=0)
                        if resample_with_replacement:
                            num_samples = args.val_samples_per_target
                            idx = torch.multinomial(weights, num_samples=num_samples, replacement=True)
                        else:
                            # Without replacement we cannot draw more than the finite pool.
                            num_samples = min(args.val_samples_per_target, int(finite_rewards.sum()))
                            idx = torch.multinomial(weights, num_samples=num_samples, replacement=False)

                        # Map indices back into the full (unfiltered) pool.
                        valid_idx = np.where(finite_rewards)[0]
                        chosen = valid_idx[idx.detach().cpu().numpy()]
                        if d_star > 0:
                            resampled_affinity_pos.extend(affinities[chosen].tolist())
                            resampled_acc_pos.extend(direction_accuracy[chosen].tolist())
                        else:
                            resampled_affinity_neg.extend(affinities[chosen].tolist())
                            resampled_acc_neg.extend(direction_accuracy[chosen].tolist())
                        resampled_gated_rewards.extend(gated_rewards[chosen].tolist())

                        for picked in chosen.tolist():
                            resampled_records.append({
                                "target": target_seq[:20],
                                "sequence": sequences[picked],
                                "target_direction": d_star,
                                "is_valid": bool(valid_mask[picked]) if valid_mask.size else False,
                                "affinity": float(affinities[picked]) if affinities.size else np.nan,
                                "gated_reward": float(gated_rewards[picked]) if gated_rewards.size else np.nan,
                                "direction_oracle": float(directions[picked]) if directions.size else np.nan,
                                "consistency_reward": float(consistency[picked]) if consistency.size else np.nan,
                                "direction_accuracy": float(direction_accuracy[picked]) if direction_accuracy.size else np.nan,
                                "success_rate": float(success_rate[picked]) if success_rate.size else np.nan,
                            })

                # NOTE(review): `idx` here shadows the resampling tensor above —
                # harmless, but worth renaming.
                for idx, seq in enumerate(sequences):
                    records.append({
                        "target": target_seq[:20],
                        "sequence": seq,
                        "target_direction": d_star,
                        "is_valid": bool(valid_mask[idx]) if valid_mask.size else False,
                        "affinity": float(affinities[idx]) if affinities.size else np.nan,
                        "gated_reward": float(gated_rewards[idx]) if gated_rewards.size else np.nan,
                        "direction_oracle": float(directions[idx]) if directions.size else np.nan,
                        "consistency_reward": float(consistency[idx]) if consistency.size else np.nan,
                        "direction_accuracy": float(direction_accuracy[idx]) if direction_accuracy.size else np.nan,
                        "success_rate": float(success_rate[idx]) if success_rate.size else np.nan,
                    })

    # Gather per-rank full records on rank 0 (all_gather_object is collective:
    # every rank must participate).
    if world_size > 1:
        gathered: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, records)
        if is_main_process():
            all_records = [item for sub in gathered for item in sub]
        else:
            all_records = []
    else:
        all_records = records

    # Gather resampled records the same way.
    if world_size > 1:
        gathered_resampled_records: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_resampled_records, resampled_records)
        if is_main_process():
            all_resampled_records = [item for sub in gathered_resampled_records for item in sub]
        else:
            all_resampled_records = []
    else:
        all_resampled_records = resampled_records

    # Gather the scalar accumulators used for the printed summary.
    if world_size > 1:
        resampled_payload = {
            "aff_pos": resampled_affinity_pos,
            "aff_neg": resampled_affinity_neg,
            "acc_pos": resampled_acc_pos,
            "acc_neg": resampled_acc_neg,
            "gated": resampled_gated_rewards,
        }
        gathered_resampled = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_resampled, resampled_payload)
        if is_main_process():
            resampled_affinity_pos = []
            resampled_affinity_neg = []
            resampled_acc_pos = []
            resampled_acc_neg = []
            resampled_gated_rewards = []
            # NOTE(review): `payload` shadows the checkpoint payload from above.
            for payload in gathered_resampled:
                resampled_affinity_pos.extend(payload.get("aff_pos", []))
                resampled_affinity_neg.extend(payload.get("aff_neg", []))
                resampled_acc_pos.extend(payload.get("acc_pos", []))
                resampled_acc_neg.extend(payload.get("acc_neg", []))
                resampled_gated_rewards.extend(payload.get("gated", []))

    # Only rank 0 writes CSVs and prints the summary.
    if is_main_process():
        df = pd.DataFrame(all_records)
        output_path = os.path.join(args.save_path, f"validation_epoch_{cli_args.epoch}.csv")
        df.to_csv(output_path, index=False)
        print(f"Validation sequences saved to {output_path}")

        if resample_enabled:
            if all_resampled_records:
                resampled_df = pd.DataFrame(all_resampled_records)
                resampled_path = os.path.join(args.save_path, f"validation_epoch_{cli_args.epoch}_resampled.csv")
                resampled_df.to_csv(resampled_path, index=False)
                print(f"Resampled sequences saved to {resampled_path}")
            else:
                print("Resampling enabled but no finite rewards were available to select.")

        # Summary statistics: from the resampled subset when available,
        # otherwise from the full candidate pool.
        if resample_enabled and resampled_gated_rewards:
            aff_mean_pos = _nanmean(np.asarray(resampled_affinity_pos, dtype=np.float32))
            aff_std_pos = _nanstd(np.asarray(resampled_affinity_pos, dtype=np.float32))
            acc_mean_pos = _nanmean(np.asarray(resampled_acc_pos, dtype=np.float32))
            acc_std_pos = _nanstd(np.asarray(resampled_acc_pos, dtype=np.float32))

            aff_mean_neg = _nanmean(np.asarray(resampled_affinity_neg, dtype=np.float32))
            aff_std_neg = _nanstd(np.asarray(resampled_affinity_neg, dtype=np.float32))
            acc_mean_neg = _nanmean(np.asarray(resampled_acc_neg, dtype=np.float32))
            acc_std_neg = _nanstd(np.asarray(resampled_acc_neg, dtype=np.float32))

            gated = np.asarray(resampled_gated_rewards, dtype=np.float32)
            gated_mean = _nanmean(gated)
            gated_std = _nanstd(gated)
        else:
            def _stats_for_direction(d_star: float) -> Tuple[float, float, float, float]:
                # Mean/std of affinity and direction accuracy for one direction.
                subset = df[df["target_direction"] == d_star]
                affinity = subset["affinity"].to_numpy(dtype=np.float32)
                direction_acc = subset["direction_accuracy"].to_numpy(dtype=np.float32)
                return _nanmean(affinity), _nanstd(affinity), _nanmean(direction_acc), _nanstd(direction_acc)

            aff_mean_pos, aff_std_pos, acc_mean_pos, acc_std_pos = _stats_for_direction(1.0)
            aff_mean_neg, aff_std_neg, acc_mean_neg, acc_std_neg = _stats_for_direction(-1.0)
            gated = df["gated_reward"].to_numpy(dtype=np.float32)
            gated_mean = _nanmean(gated)
            gated_std = _nanstd(gated)

        print("Validation summary")
        print(f" Affinity (d*=1): {aff_mean_pos:.4f} ± {aff_std_pos:.4f}")
        print(f" Affinity (d*=-1): {aff_mean_neg:.4f} ± {aff_std_neg:.4f}")
        print(f" Direction Accuracy (d*=1): {acc_mean_pos:.4f} ± {acc_std_pos:.4f}")
        print(f" Direction Accuracy (d*=-1): {acc_mean_neg:.4f} ± {acc_std_neg:.4f}")
        print(f" Gated Reward (overall): {gated_mean:.4f} ± {gated_std:.4f}")

    if world_size > 1:
        cleanup_distributed()


if __name__ == "__main__":
    main()

# Running command:
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29501 run_validation_td3b.py --ckpt_path To Be Added --val_csv To Be Added --device cuda:0 --save_path To Be Added --epoch 99 --val_samples_per_target 8 --seed 42 --resample_alpha 0.1
baselines/sampling_setup.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Optional
7
+
8
+ import numpy as np
9
+ ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
10
+ if ROOT_DIR not in sys.path:
11
+ sys.path.insert(0, ROOT_DIR)
12
+
13
+ import torch
14
+ from hydra import compose, initialize_config_dir
15
+ from hydra.core.global_hydra import GlobalHydra
16
+
17
+ from diffusion import Diffusion
18
+ from scoring.scoring_functions import ScoringFunctions
19
+ from scoring.functions.binding import MultiTargetBindingAffinity
20
+ from td3b.direction_oracle import DirectionalOracle, resolve_device
21
+ from td3b.data_utils import peptide_seq_to_smiles, smiles_token_length
22
+
23
+ from baselines.baselines import (
24
+ RewardInputs,
25
+ RewardWrapper,
26
+ classifier_guidance,
27
+ peptune_mctg_sampling,
28
+ sequential_monte_carlo,
29
+ twisted_diffusion_sampler,
30
+ unguided_sampling,
31
+ )
32
+
33
+
34
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


@dataclass
class ProteinTokenizer:
    """Minimal residue-level tokenizer mapping one-letter amino-acid codes to ids."""

    # Residue letter -> integer id.
    aa_to_id: Dict[str, int]
    # Id used for padding and for residues missing from the vocabulary.
    pad_id: int = 0

    @classmethod
    def default(cls) -> "ProteinTokenizer":
        """Build the canonical 20-residue vocabulary (ids 1..20, pad/unknown = 0)."""
        vocab = {residue: position + 1 for position, residue in enumerate(AMINO_ACIDS)}
        return cls(aa_to_id=vocab, pad_id=0)

    def encode(self, seq: str) -> torch.Tensor:
        """Encode ``seq`` as a (1, len(seq)) long tensor; unknown residues map to pad_id."""
        lookup = self.aa_to_id.get
        token_ids = [lookup(residue, self.pad_id) for residue in seq]
        return torch.tensor([token_ids], dtype=torch.long)
50
+
51
+
52
def load_base_model(
    ckpt_path: str,
    device: str,
    config_name: str = "peptune_config.yaml",
) -> Diffusion:
    """Load the pre-trained Diffusion (MDLM) model in eval mode.

    Tries the Lightning ``load_from_checkpoint`` path first; if that fails
    (e.g. for a raw state-dict checkpoint), rebuilds the model from the Hydra
    config and loads the state dict non-strictly.
    """
    # Hydra keeps global state; clear any earlier initialization before composing.
    GlobalHydra.instance().clear()
    config_dir = os.path.join(os.path.dirname(__file__), "..", "configs")
    initialize_config_dir(config_dir=config_dir, job_name="load_model")
    cfg = compose(config_name=config_name)
    try:
        model = Diffusion.load_from_checkpoint(
            ckpt_path,
            config=cfg,
            mode="eval",
            device=device,
            map_location=device,
        )
        model.eval()
        return model
    except Exception as exc:
        # Fall through to a manual state-dict load for non-Lightning checkpoints.
        print(f"[load_base_model] Lightning load failed, falling back to raw state_dict: {exc}")

    # weights_only=False: the checkpoint may contain pickled non-tensor objects.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict):
        # Accept the two common wrapper keys, else assume the dict IS the state dict.
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model = Diffusion(
        config=cfg,
        mode="eval",
        device=device,
    )
    # strict=False tolerates renamed/extra keys; surface the counts for review.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[load_base_model] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[load_base_model] Unexpected keys: {len(unexpected)}")
    model.eval()
    model.to(device)
    return model
98
+
99
+
100
def load_reward_models(
    prot_seq: Optional[str],
    device: str,
    base_model: Optional[Diffusion] = None,
    base_path: Optional[str] = None,
    multi_target: bool = False,
    score_func_names: Optional[List[str]] = None,
):
    """Construct the reward scorer used during sampling.

    With ``multi_target=True`` returns a ``MultiTargetBindingAffinity`` built
    from the base model's tokenizer/backbone (``base_model`` and ``base_path``
    required).  Otherwise returns a ``ScoringFunctions`` bundle for the single
    target ``prot_seq``, defaulting to the five standard peptide scorers.
    """
    if multi_target:
        if base_model is None or base_path is None:
            raise ValueError("base_model and base_path are required for multi-target affinity.")
        return MultiTargetBindingAffinity(
            tokenizer=base_model.tokenizer,
            base_path=base_path,
            device=device,
            emb_model=base_model.backbone,
        )

    if prot_seq is None:
        raise ValueError("prot_seq is required for single-target scoring.")
    names = score_func_names
    if names is None:
        # Default scorer set for single-target sampling.
        names = [
            "binding_affinity1",
            "solubility",
            "hemolysis",
            "nonfouling",
            "permeability",
        ]
    return ScoringFunctions(names, prot_seqs=[prot_seq], device=device)
128
+
129
+
130
def load_direction_oracle(args, device: str) -> DirectionalOracle:
    """Instantiate the Direction Oracle from the ``direction_oracle_*`` CLI
    arguments and switch it to eval mode."""
    oracle_kwargs = dict(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only,
    )
    oracle = DirectionalOracle(**oracle_kwargs)
    oracle.eval()
    return oracle
150
+
151
+
152
def run_baseline(
    baseline: str,
    base_model: Diffusion,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    guidance_steps: Optional[int],
    mcts_iterations: int,
    num_children: int,
    sample_prob_weight: float,
    invalid_penalty: float,
    pareto_max_size: Optional[int],
) -> Dict[str, torch.Tensor]:
    """Dispatch to the requested baseline sampler (case-insensitive name).

    Raises ``ValueError`` for an unrecognized baseline name.
    """
    # Arguments shared by every sampler.
    common = dict(batch_size=batch_size, seq_length=seq_length, num_steps=num_steps)
    name = baseline.lower()

    if name == "cg":
        return classifier_guidance(
            base_model,
            reward_fn,
            guidance_scale=guidance_scale,
            guidance_steps=guidance_steps,
            **common,
        )
    if name == "unguided":
        return unguided_sampling(base_model, **common)
    if name == "smc":
        return sequential_monte_carlo(base_model, reward_fn, alpha=alpha, **common)
    if name == "tds":
        return twisted_diffusion_sampler(
            base_model,
            reward_fn,
            guidance_scale=guidance_scale,
            alpha=alpha,
            guidance_steps=guidance_steps,
            **common,
        )
    if name == "peptune":
        return peptune_mctg_sampling(
            base_model,
            reward_fn,
            mcts_iterations=mcts_iterations,
            num_children=num_children,
            alpha=alpha,
            sample_prob_weight=sample_prob_weight,
            invalid_penalty=invalid_penalty,
            pareto_max_size=pareto_max_size,
            **common,
        )
    raise ValueError(f"Unknown baseline: {name}")
221
+
222
+
223
def main():
    """Run one baseline sampler over one target (or a CSV of targets), score
    the generated binders for affinity and direction, and write per-sample,
    per-batch-timing, and per-batch-metric CSVs.

    Supports manual sharding (--shard_id/--num_shards) or torchrun env vars.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--baseline", type=str, default="cg", choices=["cg", "smc", "tds", "unguided", "peptune"])
    parser.add_argument("--prot_seq", type=str, default=None)
    parser.add_argument("--targets_csv", type=str, default=None)
    parser.add_argument("--d_star", type=float, default=1.0)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--seq_length", type=int, default=200)
    parser.add_argument("--binder_seq", type=str, default=None)
    parser.add_argument("--num_steps", type=int, default=128)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--reward_alpha", type=float, default=None)
    parser.add_argument("--mcts_iterations", type=int, default=20)
    parser.add_argument("--num_children", type=int, default=24)
    parser.add_argument("--sample_prob_weight", type=float, default=0.1)
    parser.add_argument("--invalid_penalty", type=float, default=1.0)
    parser.add_argument("--pareto_max_size", type=int, default=None)
    parser.add_argument("--guidance_steps", type=int, default=None)
    parser.add_argument("--fast_direction", action="store_true", default=False)
    parser.add_argument("--num_batches", type=int, default=1)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--shard_id", type=int, default=None)
    parser.add_argument("--num_shards", type=int, default=None)
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    parser.add_argument("--direction_oracle_tokenizer_vocab", type=str, default=None)
    parser.add_argument("--direction_oracle_tokenizer_splits", type=str, default=None)
    parser.add_argument("--direction_oracle_esm_name", type=str, default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--direction_oracle_esm_cache_dir", type=str, default=None)
    parser.add_argument("--direction_oracle_esm_local_files_only", action="store_true", default=False)
    parser.add_argument("--direction_oracle_max_ligand_length", type=int, default=768)
    parser.add_argument("--direction_oracle_max_protein_length", type=int, default=1024)
    parser.add_argument("--direction_oracle_d_model", type=int, default=256)
    parser.add_argument("--direction_oracle_n_heads", type=int, default=4)
    parser.add_argument("--direction_oracle_n_self_attn_layers", type=int, default=1)
    parser.add_argument("--direction_oracle_n_bmca_layers", type=int, default=2)
    parser.add_argument("--direction_oracle_dropout", type=float, default=0.3)
    args = parser.parse_args()

    # Sharding: torchrun env vars take precedence over --shard_id/--num_shards.
    rank_env = os.environ.get("LOCAL_RANK")
    world_env = os.environ.get("WORLD_SIZE")
    if rank_env is not None or world_env is not None:
        rank = int(rank_env or 0)
        world_size = int(world_env or 1)
    else:
        rank = int(args.shard_id) if args.shard_id is not None else 0
        world_size = int(args.num_shards) if args.num_shards is not None else 1
    if world_size < 1:
        world_size = 1
    # When sharded with a generic device string, pin each shard to its own GPU.
    if world_size > 1 and str(args.device).lower() in {"cuda", "cuda:0", "auto"}:
        args.device = f"cuda:{rank}"

    resolved_device = resolve_device(args.device)
    args.device = str(resolved_device)

    # Default oracle/tokenizer paths are resolved relative to the repo root.
    tr2d2_root = ROOT_DIR
    if args.direction_oracle_ckpt is None:
        args.direction_oracle_ckpt = os.path.join(
            tr2d2_root, "direction_oracle.pt"
        )
    if args.direction_oracle_tr2d2_checkpoint is None:
        args.direction_oracle_tr2d2_checkpoint = os.path.join(
            tr2d2_root, "pretrained", "peptune-pretrained.ckpt"
        )
    if args.direction_oracle_tokenizer_vocab is None:
        args.direction_oracle_tokenizer_vocab = os.path.join(
            tr2d2_root, "tokenizer", "new_vocab.txt"
        )
    if args.direction_oracle_tokenizer_splits is None:
        args.direction_oracle_tokenizer_splits = os.path.join(
            tr2d2_root, "tokenizer", "new_splits.txt"
        )

    if args.targets_csv is None and args.prot_seq is None:
        raise ValueError("--prot_seq is required when --targets_csv is not provided.")

    base_model = load_base_model(args.ckpt_path, args.device)
    base_path = os.path.abspath(os.path.join(ROOT_DIR, ".."))
    multi_target = args.targets_csv is not None
    scoring_fn = load_reward_models(
        args.prot_seq if not multi_target else None,
        args.device,
        base_model=base_model,
        base_path=base_path,
        multi_target=multi_target,
    )
    direction_oracle = load_direction_oracle(args, args.device)
    reward_alpha = args.reward_alpha if args.reward_alpha is not None else args.alpha

    # Build the target list: either from the CSV or a single CLI-provided pair.
    if args.targets_csv:
        import pandas as pd

        df = pd.read_csv(args.targets_csv)
        if "Target_Sequence" not in df.columns:
            raise ValueError("targets_csv must contain a 'Target_Sequence' column.")
        if "Ligand_Sequence" not in df.columns:
            raise ValueError("targets_csv must contain a 'Ligand_Sequence' column.")

        targets = []
        for row_idx, row in df.iterrows():
            target_seq = str(row["Target_Sequence"]) if pd.notna(row["Target_Sequence"]) else None
            if not target_seq:
                continue
            # NaN or blank ligand means "no reference binder" for this target.
            binder_seq = row["Ligand_Sequence"]
            if pd.isna(binder_seq):
                binder_seq = None
            else:
                binder_seq = str(binder_seq)
                if binder_seq.strip() == "":
                    binder_seq = None
            targets.append(
                {
                    "target_seq": target_seq,
                    "binder_seq": binder_seq,
                    "row_index": int(row_idx),
                }
            )
    else:
        targets = [{"target_seq": args.prot_seq, "binder_seq": args.binder_seq, "row_index": 0}]

    if world_size > 1:
        # Round-robin split of targets across shards.
        targets = [item for idx, item in enumerate(targets) if idx % world_size == rank]
        print(f"[shard] rank {rank}/{world_size}: {len(targets)} targets")

    output_dir = args.output_dir
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(__file__), "outputs")
    os.makedirs(output_dir, exist_ok=True)

    # Imported lazily to avoid pulling the analyzer stack in at module import.
    from utils.app import PeptideAnalyzer

    analyzer = PeptideAnalyzer()
    all_rows = []
    batch_rows = []
    metrics_rows = []

    def resolve_seq_length(binder_seq: Optional[str]) -> int:
        # Derive the generation length from the reference binder's SMILES
        # token count; fall back to --seq_length on any failure.
        if not binder_seq:
            return args.seq_length
        try:
            smiles = peptide_seq_to_smiles(binder_seq)
            if not smiles:
                return args.seq_length
            if base_model.tokenizer is None:
                return len(smiles)
            return smiles_token_length(smiles, base_model.tokenizer)
        except Exception as exc:
            print(f"Warning: failed to derive seq_length from binder_seq; using {args.seq_length}. Error: {exc}")
            return args.seq_length

    for target_idx, target_info in enumerate(targets):
        target_seq = target_info["target_seq"]
        binder_seq = target_info.get("binder_seq")
        row_index = target_info.get("row_index", target_idx)
        seq_length = resolve_seq_length(binder_seq)
        protein_tokens = direction_oracle.encode_protein(target_seq)
        # Every target is sampled for both behaviors.
        for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:

            reward_inputs = RewardInputs(
                protein_tokens=protein_tokens,
                d_star=d_star,
                protein_seq=target_seq,
            )
            reward_fn = RewardWrapper(
                scoring_fn=scoring_fn,
                direction_oracle=direction_oracle,
                base_model=base_model,
                tokenizer=base_model.tokenizer,
                reward_inputs=reward_inputs,
                device=torch.device(args.device),
                fast_direction=args.fast_direction,
                reward_alpha=reward_alpha,
            )

            # Multi-target runs do one batch per (target, direction).
            num_batches = 1 if multi_target else args.num_batches
            for batch_idx in range(num_batches):
                start = time.perf_counter()
                result = run_baseline(
                    args.baseline,
                    base_model,
                    reward_fn,
                    batch_size=args.batch_size,
                    seq_length=seq_length,
                    num_steps=args.num_steps,
                    guidance_scale=args.guidance_scale,
                    alpha=args.alpha,
                    guidance_steps=args.guidance_steps,
                    mcts_iterations=args.mcts_iterations,
                    num_children=args.num_children,
                    sample_prob_weight=args.sample_prob_weight,
                    invalid_penalty=args.invalid_penalty,
                    pareto_max_size=args.pareto_max_size,
                )
                elapsed = time.perf_counter() - start

                # Score the final tokens with an all-ones attention mask.
                scores = reward_fn.evaluate_tokens(
                    result["tokens"],
                    torch.ones_like(result["tokens"], device=result["tokens"].device),
                )
                sequences = scores["sequences"]
                affinity = scores["affinity"].detach().cpu().numpy()
                direction = scores["direction"].detach().cpu().numpy()
                gated_reward = scores["gated_reward"].detach().cpu().numpy()
                valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences], dtype=np.float32)
                valid_fraction = float(valid_mask.mean()) if len(valid_mask) else 0.0
                # Signed agreement between oracle score and requested direction.
                consistency = d_star * (direction - 0.5)
                if d_star > 0:
                    direction_correct = (direction >= 0.5).astype(np.float32)
                else:
                    direction_correct = (direction < 0.5).astype(np.float32)
                # "Success" requires both correct direction and a valid peptide.
                success = direction_correct * valid_mask
                direction_mean = float(np.mean(direction))
                direction_std = float(np.std(direction))
                affinity_mean = float(np.mean(affinity))
                affinity_std = float(np.std(affinity))
                consistency_mean = float(np.mean(consistency))
                consistency_std = float(np.std(consistency))
                gated_reward_mean = float(np.mean(gated_reward))
                gated_reward_std = float(np.std(gated_reward))
                direction_acc_mean = float(np.mean(direction_correct))
                direction_acc_std = float(np.std(direction_correct))
                success_rate_mean = float(np.mean(success))
                success_rate_std = float(np.std(success))
                batch_metrics = {
                    "direction_mean": direction_mean,
                    "direction_std": direction_std,
                    "affinity_mean": affinity_mean,
                    "affinity_std": affinity_std,
                    "consistency_mean": consistency_mean,
                    "consistency_std": consistency_std,
                    "gated_reward_mean": gated_reward_mean,
                    "gated_reward_std": gated_reward_std,
                    "direction_accuracy_mean": direction_acc_mean,
                    "direction_accuracy_std": direction_acc_std,
                    "valid_fraction": valid_fraction,
                    "success_rate_mean": success_rate_mean,
                    "success_rate_std": success_rate_std,
                }

                # One row per generated sequence (batch-level metrics repeated).
                for i, seq in enumerate(sequences):
                    all_rows.append(
                        {
                            "rank": rank,
                            "sequence": seq,
                            "affinity": float(affinity[i]),
                            "direction": float(direction[i]),
                            "d_star": float(d_star),
                            "direction_name": direction_name,
                            "target_seq": target_seq,
                            "target_index": target_idx,
                            "row_index": row_index,
                            "binder_seq": binder_seq,
                            "seq_length": seq_length,
                            "gated_reward": float(gated_reward[i]),
                            "consistency_reward": float(consistency[i]),
                            "direction_accuracy": float(direction_correct[i]),
                            "valid": float(valid_mask[i]),
                            "success": float(success[i]),
                            "batch_index": batch_idx,
                            "batch_time_sec": elapsed,
                            **batch_metrics,
                        }
                    )
                batch_rows.append(
                    {
                        "rank": rank,
                        "batch_index": batch_idx,
                        "batch_time_sec": elapsed,
                        "target_index": target_idx,
                        "row_index": row_index,
                        "binder_seq": binder_seq,
                        "seq_length": seq_length,
                        "direction_name": direction_name,
                    }
                )
                metrics_rows.append(
                    {
                        "rank": rank,
                        "target_index": target_idx,
                        "target_seq": target_seq,
                        "row_index": row_index,
                        "binder_seq": binder_seq,
                        "seq_length": seq_length,
                        "direction_name": direction_name,
                        "d_star": float(d_star),
                        "batch_index": batch_idx,
                        "num_samples": len(sequences),
                        **batch_metrics,
                    }
                )
                print(
                    f"Target {target_idx} dir {direction_name}: "
                    f"generated {len(sequences)} sequences in {elapsed:.3f}s"
                )

    import pandas as pd

    # Per-rank output files when sharded so writers never collide.
    if world_size > 1:
        output_csv = os.path.join(output_dir, f"{args.baseline}_samples_rank{rank}.csv")
        batch_csv = os.path.join(output_dir, f"batch_times_rank{rank}.csv")
        metrics_csv = os.path.join(output_dir, f"{args.baseline}_metrics_rank{rank}.csv")
    else:
        output_csv = os.path.join(output_dir, f"{args.baseline}_samples.csv")
        batch_csv = os.path.join(output_dir, "batch_times.csv")
        metrics_csv = os.path.join(output_dir, f"{args.baseline}_metrics.csv")
    pd.DataFrame(all_rows).to_csv(output_csv, index=False)
    pd.DataFrame(batch_rows).to_csv(batch_csv, index=False)
    pd.DataFrame(metrics_rows).to_csv(metrics_csv, index=False)

    print(f"Saved samples to {output_csv}")


if __name__ == "__main__":
    main()
configs/finetune_config.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Shared Configuration Classes for TD3B Finetuning
3
+
4
+ This module contains all configuration dataclasses used by both:
5
+ - finetune_v1.py (single-target training)
6
+ - finetune_multi_target.py (multi-target training)
7
+
8
+ Extracted to avoid code duplication and ensure consistency.
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Optional
13
+
14
+
15
@dataclass
class RoFormerConfig:
    """Architecture hyperparameters for the RoFormer backbone."""
    hidden_size: int
    n_layers: int
    n_heads: int
    max_position_embeddings: int = 1035  # Must match pretrained model


@dataclass(frozen=True)
class NoiseConfig:
    """Noise-schedule settings for the diffusion process."""
    type: str = 'loglinear'
    sigma_min: float = 1e-4
    sigma_max: float = 20.0


@dataclass(frozen=True)
class TrainingConfig:
    """Training-time parameters."""
    sampling_eps: float


@dataclass(frozen=True)
class SamplingConfig:
    """Sampling-time parameters."""
    steps: int
    sampling_eps: float
    predictor: str = 'ddpm_cache'


@dataclass(frozen=True)
class EvalConfig:
    """Evaluation parameters."""
    gen_ppl_eval_model_name_or_path: str = 'gpt2-large'


@dataclass(frozen=True)
class OptimConfig:
    """Optimizer parameters."""
    lr: float


@dataclass(frozen=True)
class MCTSConfig:
    """MCTS parameters."""
    sampling: int = 0  # 0 for Gumbel sampling


class DiffusionConfig:
    """
    Complete configuration for the Diffusion model.

    Converts the typed dataclasses above into lightweight attribute
    namespaces (e.g. ``cfg.roformer.hidden_size``) so existing Diffusion
    code that expects nested attribute access keeps working, and pins the
    parameters that are fixed for this project.
    """

    @staticmethod
    def _namespace(name, **attrs):
        # Anonymous mutable object exposing `attrs` as plain attributes;
        # identical in behavior to `type(name, (), attrs)()` used elsewhere.
        return type(name, (), attrs)()

    def __init__(
        self,
        roformer: RoFormerConfig,
        noise: NoiseConfig,
        training: TrainingConfig,
        sampling: SamplingConfig,
        eval_cfg: EvalConfig,
        optim: OptimConfig,
        mcts: MCTSConfig
    ):
        make = self._namespace

        # Anonymous objects for backward compatibility with nested-attribute access.
        self.roformer = make(
            'RoFormerObj',
            hidden_size=roformer.hidden_size,
            n_layers=roformer.n_layers,
            n_heads=roformer.n_heads,
            max_position_embeddings=roformer.max_position_embeddings,
        )
        self.noise = make(
            'NoiseObj',
            type=noise.type,
            sigma_min=noise.sigma_min,
            sigma_max=noise.sigma_max,
        )
        self.training = make('TrainingObj', sampling_eps=training.sampling_eps)
        self.sampling = make(
            'SamplingObj',
            steps=sampling.steps,
            sampling_eps=sampling.sampling_eps,
            predictor=sampling.predictor,
        )
        self.eval = make(
            'EvalObj',
            gen_ppl_eval_model_name_or_path=eval_cfg.gen_ppl_eval_model_name_or_path,
        )
        self.optim = make('OptimObj', lr=optim.lr)
        self.mcts = make('MCTSObj', sampling=mcts.sampling)

        # Fixed parameters
        self.backbone = 'roformer'
        self.parameterization = 'subs'
        self.time_conditioning = False
        self.T = 0
configs/peptune_config.yaml ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ noise:
2
+ type: loglinear
3
+ sigma_min: 1e-4
4
+ sigma_max: 20
5
+ state_dependent: True
6
+
7
+ mode: ppl_eval # train / ppl_eval / sample_eval
8
+ diffusion: absorbing_state
9
+ vocab: old_smiles # old_smiles / new_smiles / selfies / helm
10
+ backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
11
+ parameterization: subs # subs
12
+ time_conditioning: False
13
+ T: 0 # 0 (continuous time) / 1000
14
+ subs_masking: False
15
+
16
+ seed: 42
17
+
18
+ mcts:
19
+ num_children: 50
20
+ num_objectives: 5
21
+ topk: 100
22
+ mask_token: 4
23
+ num_iter: 128
24
+ sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
25
+ invalid_penalty: 0.5
26
+ sample_prob: 1.0
27
+ perm: True
28
+ dual: False
29
+ single: False
30
+ time_dependent: True
31
+
32
+ lr_scheduler:
33
+ _target_: transformers.get_constant_schedule_with_warmup
34
+ num_warmup_steps: 2500
35
+
36
+ data:
37
+ train: To Be Added
38
+ valid: To Be Added
39
+ batching: wrapping # padding / wrapping
40
+
41
+ loader:
42
+ global_batch_size: 64
43
+ eval_global_batch_size: ${.global_batch_size}
44
+ # Note: batch_size and eval_batch_size are **per machine**
45
+ batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
46
+ eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
47
+ num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
48
+ pin_memory: True
49
+
50
+ sampling:
51
+ predictor: ddpm_cache # analytic, ddpm, ddpm_cache
52
+ num_sequences: 100
53
+ sampling_eps: 1e-3
54
+ steps: 128
55
+ seq_length: 100
56
+ noise_removal: True
57
+ num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
58
+ num_sample_log: 2
59
+ stride_length: 1
60
+ num_strides: 1
61
+
62
+ training:
63
+ antithetic_sampling: True
64
+ sampling_eps: 1e-3
65
+ focus_mask: False
66
+ #dynamic_batching: True
67
+ accumulator: False
68
+
69
+ eval:
70
+ checkpoint_path:
71
+ disable_ema: False
72
+ compute_generative_perplexity: False
73
+ perplexity_batch_size: 8
74
+ compute_perplexity_on_sanity: False
75
+ gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
76
+ generate_samples: True
77
+ generation_model:
78
+
79
+ optim:
80
+ weight_decay: 0.075
81
+ lr: 3e-4
82
+ beta1: 0.9
83
+ beta2: 0.999
84
+ eps: 1e-8
85
+
86
+ pepclm:
87
+ hidden_size: 768
88
+ cond_dim: 256
89
+ n_heads: 20
90
+ n_blocks: 4
91
+ dropout: 0.5
92
+ length: 512
93
+ #scale_by_sigma: True
94
+
95
+ model:
96
+ type: ddit
97
+ hidden_size: 768
98
+ cond_dim: 128
99
+ length: 512
100
+ n_blocks: 12
101
+ n_heads: 12
102
+ scale_by_sigma: True
103
+ dropout: 0.1
104
+
105
+ roformer:
106
+ hidden_size: 768
107
+ n_layers: 8
108
+ n_heads: 8
109
+ max_position_embeddings: 1035
110
+
111
+ helmgpt:
112
+ hidden_size: 256
113
+ embd_pdrop: 0.1
114
+ resid_pdrop: 0.1
115
+ attn_pdrop: 0.1
116
+ ff_dropout: 0.
117
+ block_size: 140
118
+ n_layer: 8
119
+ n_heads: 8
120
+
121
+
122
+ trainer:
123
+ _target_: lightning.Trainer
124
+ accelerator: cuda
125
+ num_nodes: 1
126
+ devices: ${device_count:}
127
+ accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
128
+ gradient_clip_val: 1.0
129
+ precision: 64-true
130
+ num_sanity_val_steps: 2
131
+ max_epochs: 100
132
+ max_steps: 1_000_000
133
+ log_every_n_steps: 10
134
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
135
+ limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
136
+ #val_check_interval: 40 #954
137
+ check_val_every_n_epoch: 1
138
+
139
+ hydra:
140
+ run:
141
+ dir: ./${now:%Y.%m.%d}/
142
+ job:
143
+ chdir: True
144
+
145
+ checkpointing:
146
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
147
+ save_dir: ${cwd:}
148
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
149
+ resume_from_ckpt: True
150
+ resume_ckpt_path:
151
+
152
+ callbacks:
153
+ model_checkpoint:
154
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
155
+ every_n_epochs: 1
156
+ monitor: "val/nll"
157
+ save_top_k: 10
158
+ mode: "min"
159
+ dirpath:
diffusion.py ADDED
@@ -0,0 +1,1588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import sys
3
+ import itertools
4
+ import time
5
+ import torch
6
+ from torch import Tensor
7
+ import math
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import random as rd
11
+ import lightning as L
12
+ import torchmetrics
13
+ from dataclasses import dataclass
14
+ import gc
15
+ import utils.utils as utils
16
+
17
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
18
+ import noise_schedule
19
+ from torch.optim.lr_scheduler import _LRScheduler
20
+ import roformer as roformer
21
+ from utils.app import PeptideAnalyzer
22
+ import pandas as pd
23
+
24
+ base_path = 'To Be Added'
25
+
26
+ def _sample_categorical(categorical_probs):
27
+ gumbel_norm = (
28
+ 1e-10
29
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
30
+ return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
31
+
32
+ def _sample_categorical_gradient(categorical_probs, temp = 1.0):
33
+ gumbel_norm = (
34
+ 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
35
+ output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
36
+ return output
37
+
38
+ def _unsqueeze(x, reference):
39
+ return x.view(
40
+ * x.shape,
41
+ * ((1,) * (len(reference.shape) - len(x.shape))))
42
+
43
def sample_batched_categorical(categorical_probs, batch_size):
    """
    Generates `batch_size` sequences sampled from categorical probabilities
    using the Gumbel-max trick, which ensures randomness while following
    the given per-position probabilities.

    Args:
        categorical_probs (torch.Tensor): tensor of shape
            (1, sequence_length, vocab_size) — a batch dim of 1 that is
            broadcast across the sampled batch — with per-position
            categorical probabilities.
        batch_size (int): number of sequences to sample.

    Returns:
        torch.Tensor: tensor of shape (batch_size, sequence_length) of
        sampled category indices (dtype long).
    """
    _, sequence_length, vocab_size = categorical_probs.shape

    # add Gumbel noise and sample batch_size sequences; create the noise
    # directly on the input's device to avoid a host->device copy
    gumbel_noise = -torch.log(
        -torch.log(
            torch.rand(batch_size, sequence_length, vocab_size,
                       device=categorical_probs.device) + 1e-10
        ) + 1e-10
    )
    # add Gumbel noise to log probabilities; broadcasts over the batch dim
    noisy_scores = torch.log(categorical_probs) + gumbel_noise

    # select the highest score (most likely category after Gumbel noise)
    sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long)  # (batch_size, sequence_length)

    return sampled_sequences
67
+
68
def sample_batched_top_k(categorical_probs, batch_size, k):
    """
    Generates `batch_size` sequences sampled from the top-k noisy scores of
    each position, using Gumbel noise to add randomness and reduce bias
    towards the most likely options.

    Args:
        categorical_probs (torch.Tensor): tensor of shape
            (1, sequence_length, vocab_length) — a batch dim of 1 that is
            broadcast across the sampled batch — with per-position
            categorical probabilities.
        batch_size (int): number of sequences to sample.
        k (int): number of top candidates to consider per position.

    Returns:
        torch.Tensor: tensor of shape (batch_size, sequence_length) of
        sampled category indices (dtype long).
    """
    _, sequence_length, vocab_length = categorical_probs.shape

    # Add Gumbel noise to the log probabilities (noise built on-device)
    gumbel_noise = -torch.log(
        -torch.log(
            torch.rand(batch_size, sequence_length, vocab_length,
                       device=categorical_probs.device) + 1e-10
        ) + 1e-10
    )
    # BUGFIX: the original indexed `categorical_probs[None, :, :]`, adding a
    # 4th dimension to an already-3D tensor; the final torch.gather then
    # failed with a dimension mismatch (4-D input vs 3-D index). Plain
    # broadcasting over the existing batch dim is correct.
    noisy_scores = torch.log(categorical_probs) + gumbel_noise  # (batch_size, L, V)

    # Get the top-k categories based on noisy scores
    top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1)  # (batch_size, L, k)

    # Convert top-k scores back to probabilities and normalize
    top_k_probs = torch.softmax(top_k_scores, dim=-1)  # (batch_size, L, k)

    # Sample randomly from the top-k probabilities
    sampled_indices_in_top_k = torch.multinomial(
        top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1)
    sampled_indices_in_top_k = sampled_indices_in_top_k.view(
        batch_size, sequence_length)  # (batch_size, L)

    # Map sampled indices back to the original vocabulary indices
    sampled_sequences = torch.gather(
        top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)
    ).squeeze(-1).to(dtype=torch.long)

    return sampled_sequences
103
+
104
+ @dataclass
105
+ class Loss:
106
+ loss: torch.FloatTensor
107
+ nlls: torch.FloatTensor
108
+ attn_mask: torch.FloatTensor
109
+
110
+
111
+ class NLL(torchmetrics.aggregation.MeanMetric):
112
+ pass
113
+
114
+
115
+ class BPD(NLL):
116
+ def compute(self) -> Tensor:
117
+ """Computes the bits per dimension.
118
+
119
+ Returns:
120
+ bpd
121
+ """
122
+ return self.mean_value / self.weight / math.log(2)
123
+
124
+
125
+ class Perplexity(NLL):
126
+ def compute(self) -> Tensor:
127
+ """Computes the Perplexity.
128
+
129
+ Returns:
130
+ Perplexity
131
+ """
132
+ return torch.exp(self.mean_value / self.weight)
133
+
134
+
135
+ class Diffusion(L.LightningModule):
136
+ def __init__(
137
+ self,
138
+ config,
139
+ tokenizer = None,
140
+ mode="finetune",
141
+ device=None,
142
+ ):
143
+
144
+ super().__init__()
145
+ self.config = config
146
+ #self.save_hyperparameters()
147
+
148
+ # PeptideCLM tokenizer
149
+ if tokenizer is None:
150
+ self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tr2d2-pep/tokenizer/new_vocab.txt',
151
+ f'{base_path}/tr2d2-pep/tokenizer/new_splits.txt')
152
+ else:
153
+ self.tokenizer = tokenizer
154
+
155
+ self.vocab_size = self.tokenizer.vocab_size
156
+ self.mask_index = self.tokenizer.mask_token_id
157
+ self.sampler = self.config.sampling.predictor
158
+ self.analyzer = PeptideAnalyzer()
159
+
160
+ # backbone LM PeptideCLM model
161
+ self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device)
162
+ if mode == "finetune":
163
+ self.backbone.freeze_model()
164
+ self.backbone.unfreeze_n_layers(n=8)
165
+ elif mode == "eval":
166
+ self.backbone.freeze_model()
167
+ self.backbone.requires_grad_(False)
168
+ self.backbone.eval()
169
+ elif mode == "train":
170
+ self.backbone.requires_grad_(True)
171
+ self.backbone.train()
172
+
173
+ self.neg_infinity = -1000000.0
174
+ self.T = config.T
175
+ # noise schedule for non-peptide bond tokens (default to log-linear)
176
+ self.noise = noise_schedule.get_noise(config)
177
+
178
+ # noise schedule for peptide bonds (log-polynomial)
179
+ self.bond_noise = noise_schedule.LogPolyNoise()
180
+ self.time_conditioning = self.config.time_conditioning
181
+ self.fast_forward_epochs = None
182
+ self.fast_forward_batches = None
183
+
184
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
185
+ self.gen_ppl_metric = Perplexity()
186
+
187
+ self.lr = self.config.optim.lr
188
+ self.sampling_eps = self.config.training.sampling_eps
189
+
190
+ metrics = torchmetrics.MetricCollection({
191
+ 'nll': NLL(),
192
+ 'bpd': BPD(),
193
+ 'ppl': Perplexity(),
194
+ })
195
+ metrics.set_dtype(torch.float64)
196
+ self.train_metrics = metrics.clone(prefix='trainer/')
197
+ self.valid_metrics = metrics.clone(prefix='val/')
198
+ self.test_metrics = metrics.clone(prefix='test/')
199
+
200
+ ### FOR THE EXPANSION AND ROLLOUT STEP ###
201
+ def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
202
+ num_steps = args.total_num_steps
203
+ B = args.batch_size
204
+ x_rollout = self.sample_prior(
205
+ B, args.seq_length).to(self.device)
206
+
207
+ log_rnd = torch.zeros(args.batch_size, device=self.device)
208
+
209
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
210
+ dt = (1 - eps) / num_steps
211
+
212
+ for i in range(num_steps):
213
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
214
+
215
+ log_p, x_next, log_policy_step, log_pretrained_step = \
216
+ self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)
217
+
218
+ log_rnd += log_pretrained_step - log_policy_step
219
+
220
+ x_rollout = x_next
221
+
222
+ # if mask token remains, fully unmask
223
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
224
+
225
+ # does **any** mask remain in any sequence
226
+ any_mask_global = mask_positions.any().item() # true if mask remains
227
+ if any_mask_global:
228
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
229
+
230
+ x_rollout = x_next
231
+
232
+ childSequences = self.tokenizer.batch_decode(x_rollout)
233
+
234
+ # change rewards for peptides
235
+ valid_x_final = []
236
+ validSequences = []
237
+ valid_log_rnd = []
238
+
239
+ for i in range(B):
240
+ # string sequence
241
+ childSeq = childSequences[i]
242
+
243
+ # check if the peptide is valid
244
+ if self.analyzer.is_peptide(childSeq):
245
+ valid_x_final.append(x_rollout[i])
246
+ validSequences.append(childSeq)
247
+ valid_log_rnd.append(log_rnd[i])
248
+
249
+ # compute multi-objective rewards
250
+ score_vectors = reward_model(input_seqs=validSequences)
251
+ scalar_rewards = np.sum(score_vectors, axis=-1)
252
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device)
253
+
254
+ print(f"scalar reward dim{len(scalar_rewards)}")
255
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
256
+
257
+ log_rnd = valid_log_rnd + (scalar_rewards / args.alpha) # scale down by alpha
258
+ valid_x_final = torch.stack(valid_x_final, dim=0)
259
+
260
+ return valid_x_final, log_rnd, scalar_rewards
261
+
262
+ def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5):
263
+ torch.cuda.empty_cache()
264
+ self.backbone.eval()
265
+ self.noise.eval()
266
+ print(f"device:{self.device}")
267
+
268
+ if batch_size is None:
269
+ batch_size = args.batch_size
270
+
271
+ num_steps = args.total_num_steps
272
+ x_rollout = self.sample_prior(
273
+ batch_size,
274
+ args.seq_length).to(self.device, dtype=torch.long)
275
+
276
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
277
+ dt = torch.tensor((1 - eps) / num_steps, device=self.device)
278
+
279
+ for i in range(num_steps):
280
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
281
+
282
+ log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
283
+
284
+ x_rollout = x_next
285
+ x_rollout = x_rollout.to(self.device)
286
+
287
+ # if mask token remains, fully unmask
288
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
289
+
290
+ # does **any** mask remain in any sequence
291
+ any_mask_global = mask_positions.any().item() # true if mask remains
292
+ if any_mask_global:
293
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
294
+
295
+ x_rollout = x_next
296
+ x_rollout = x_rollout.to(self.device)
297
+
298
+ childSequences = self.tokenizer.batch_decode(x_rollout)
299
+ valid_x_final = []
300
+ validSequences = []
301
+
302
+ for idx, seq in enumerate(childSequences):
303
+ if self.analyzer.is_peptide(seq):
304
+ valid_x_final.append(x_rollout[idx])
305
+ validSequences.append(seq)
306
+
307
+ valid_fraction = len(validSequences) / batch_size
308
+
309
+ if (len(validSequences) != 0):
310
+ # add scores to log
311
+ result = reward_model(input_seqs=validSequences)
312
+
313
+ # Handle both TD3B (returns tuple) and base ScoringFunctions (returns array directly)
314
+ if isinstance(result, tuple):
315
+ # TD3BRewardFunction returns (total_rewards, info) tuple
316
+ # info contains 'score_vectors' which is (N, 2) array [affinities, total_rewards]
317
+ total_rewards, info = result
318
+ affinity = info['affinities']
319
+ # TD3B doesn't compute sol/hemo/nf/permeability, set to zeros
320
+ sol = np.zeros_like(affinity)
321
+ hemo = np.zeros_like(affinity)
322
+ nf = np.zeros_like(affinity)
323
+ permeability = np.zeros_like(affinity)
324
+ else:
325
+ # Base scoring functions return (N, num_objectives) array directly
326
+ score_vectors = np.asarray(result)
327
+ if score_vectors.ndim == 1:
328
+ score_vectors = score_vectors[:, None]
329
+ average_scores = score_vectors.T
330
+
331
+ affinity = average_scores[0] if average_scores.shape[0] > 0 else np.zeros((0,))
332
+ sol = average_scores[1] if average_scores.shape[0] > 1 else np.zeros_like(affinity)
333
+ hemo = average_scores[2] if average_scores.shape[0] > 2 else np.zeros_like(affinity)
334
+ nf = average_scores[3] if average_scores.shape[0] > 3 else np.zeros_like(affinity)
335
+ permeability = average_scores[4] if average_scores.shape[0] > 4 else np.zeros_like(affinity)
336
+
337
+ else:
338
+ zeros = [0.0]
339
+
340
+ affinity = zeros
341
+ sol = zeros
342
+ hemo = zeros
343
+ nf = zeros
344
+ permeability = zeros
345
+
346
+ if dataframe:
347
+ df = pd.DataFrame({
348
+ "Peptide Sequence": validSequences,
349
+ "Binding Affinity": affinity if len(validSequences) else [0.0],
350
+ "Solubility": sol if len(validSequences) else [0.0],
351
+ "Hemolysis": hemo if len(validSequences) else [0.0],
352
+ "Nonfouling": nf if len(validSequences) else [0.0],
353
+ "Permeability": permeability if len(validSequences) else [0.0],
354
+ })
355
+ return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df
356
+
357
+ return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction
358
+
359
+ def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None):
360
+ torch.cuda.empty_cache()
361
+ self.backbone.eval()
362
+ self.noise.eval()
363
+
364
+ sigma_t, _ = self.noise(t)
365
+
366
+ if token_array.ndim == 1:
367
+ token_array = token_array.unsqueeze(0)
368
+
369
+ if x_next.ndim == 1:
370
+ x_next = x_next.unsqueeze(0)
371
+
372
+ if t.ndim > 1:
373
+ t = t.squeeze(-1)
374
+ assert t.ndim == 1
375
+
376
+ change_prob_t = t[:, None, None]
377
+ change_prob_s = (t - dt)[:, None, None]
378
+
379
+ assert change_prob_t.ndim == 3, change_prob_t.shape
380
+
381
+ if attn_mask is None:
382
+ attn_mask = torch.ones_like(token_array).to(self.device)
383
+
384
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
385
+ p_x0 = log_p.exp()
386
+
387
+ assert change_prob_t.ndim == p_x0.ndim
388
+
389
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
390
+
391
+ # zero-masking probability
392
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
393
+
394
+ copy_flag = (token_array != self.mask_index)
395
+
396
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
397
+ changed_mask = (~copy_flag)
398
+
399
+ # compute the per-sequence log-probability under the pretrained model
400
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
401
+
402
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
403
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
404
+
405
+ # returns:
406
+ # log_policy_step (B, ) log probability x_next tokens under policy
407
+ if log_policy_step.ndim == 1:
408
+ log_policy_step = log_policy_step.squeeze(0)
409
+
410
+ return log_policy_step
411
+
412
+
413
+ def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None):
414
+ torch.cuda.empty_cache()
415
+ dev = self.device
416
+ self.backbone.to(dev).eval()
417
+ self.noise.eval()
418
+
419
+ t = t.to(dev)
420
+ dt = torch.as_tensor(dt, device=dev, dtype=t.dtype)
421
+ assert self.config.noise.type == 'loglinear'
422
+ sigma_t, _ = self.noise(t)
423
+ sigma_t = sigma_t.to(dev)
424
+
425
+ if t.ndim > 1:
426
+ t = t.squeeze(-1)
427
+ assert t.ndim == 1
428
+
429
+ change_prob_t = t[:, None, None]
430
+ change_prob_s = (t - dt)[:, None, None]
431
+
432
+ assert change_prob_t.ndim == 3, change_prob_t.shape
433
+
434
+ if attn_mask is None:
435
+ attn_mask = torch.ones_like(token_array, device=dev, dtype=torch.long)
436
+ else:
437
+ attn_mask = attn_mask.to(dev)
438
+
439
+ if p_x0 is None:
440
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
441
+ p_x0 = log_p.exp()
442
+ else:
443
+ # ensure provided p_x0 is on dev
444
+ log_p = None
445
+ p_x0 = p_x0.to(dev)
446
+
447
+ assert change_prob_t.ndim == p_x0.ndim
448
+
449
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
450
+
451
+ # zero-masking probability
452
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
453
+
454
+ x_changed = _sample_categorical(q_xs)
455
+ if x_changed.device != dev or x_changed.dtype != token_array.dtype:
456
+ x_changed = x_changed.to(dev, dtype=token_array.dtype)
457
+
458
+ copy_flag = (token_array != self.mask_index)
459
+
460
+ int_copy_flag = copy_flag.to(token_array.dtype)
461
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
462
+
463
+ # returns:
464
+ # log_p (B, L, D) log probabilties of each token under the policy model
465
+ # x_next (B, L) next sequences
466
+ return log_p, x_next
467
+
468
+
469
+ def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None):
470
+ torch.cuda.empty_cache()
471
+ self.backbone.eval()
472
+ self.noise.eval()
473
+
474
+ assert self.config.noise.type == 'loglinear'
475
+ sigma_t, _ = self.noise(t)
476
+
477
+ if t.ndim > 1:
478
+ t = t.squeeze(-1)
479
+ assert t.ndim == 1
480
+
481
+ change_prob_t = t[:, None, None]
482
+ change_prob_s = (t - dt)[:, None, None]
483
+
484
+ assert change_prob_t.ndim == 3, change_prob_t.shape
485
+
486
+ if attn_mask is None:
487
+ attn_mask = torch.ones_like(token_array).to(self.device)
488
+
489
+ if p_x0 is None:
490
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
491
+ p_x0 = log_p.exp()
492
+
493
+ assert change_prob_t.ndim == p_x0.ndim
494
+
495
+ # changed for noise removal
496
+ p_x0 = p_x0.clone()
497
+ p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
498
+ p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
499
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
500
+
501
+ x_changed = _sample_categorical(q_xs)
502
+
503
+ copy_flag = (token_array != self.mask_index)
504
+
505
+ int_copy_flag = copy_flag.to(token_array.dtype)
506
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
507
+
508
+ # returns:
509
+ # log_p (B, L, D) log probabilties of each token under the policy model
510
+ # x_next (B, L) next sequences
511
+ return log_p, x_next
512
+
513
+ def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
514
+ torch.cuda.empty_cache()
515
+ self.backbone.eval()
516
+ self.noise.eval()
517
+ assert self.config.noise.type == 'loglinear'
518
+ sigma_t, _ = self.noise(t)
519
+
520
+ if t.ndim > 1:
521
+ t = t.squeeze(-1)
522
+ assert t.ndim == 1
523
+
524
+ change_prob_t = t[:, None, None]
525
+ change_prob_s = (t - dt)[:, None, None]
526
+
527
+ assert change_prob_t.ndim == 3, change_prob_t.shape
528
+
529
+ if attn_mask is None:
530
+ attn_mask = torch.ones_like(token_array).to(self.device)
531
+
532
+ if p_x0 is None:
533
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
534
+ p_x0 = log_p.exp()
535
+
536
+ assert change_prob_t.ndim == p_x0.ndim
537
+
538
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
539
+
540
+ # zero-masking probability
541
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
542
+
543
+ x_changed = _sample_categorical(q_xs)
544
+
545
+ copy_flag = (token_array != self.mask_index)
546
+
547
+ int_copy_flag = copy_flag.to(token_array.dtype)
548
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
549
+
550
+ # compute the log-probability under pretrained model at each step
551
+ with torch.no_grad():
552
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
553
+ log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
554
+
555
+ # log-prob of the *sampled token* at each position
556
+ log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
557
+
558
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
559
+
560
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
561
+ changed_mask = (~copy_flag)
562
+ # mask of tokens that were unmasked in this step
563
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
564
+
565
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
566
+
567
+ # compute the per-sequence log-probability under the pretrained model
568
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
569
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
570
+
571
+ # returns:
572
+ # log_p (B, L, D) log probabilties of each token under the policy model
573
+ # x_next (B, L) next sequences
574
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
575
+ # log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
576
+ return log_p, x_next, log_policy_step, log_pretrained_step
577
+
578
def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
    """
    Final MCTS denoising step: force every remaining MASK token to resolve to a
    real token (the MASK probability is zeroed and the distribution renormalized).

    Args:
        token_array: (B, L) current token ids (may contain MASK tokens).
        t: current timestep tensor; squeezed to 1-D.
        dt: scalar step size.
        pretrained: frozen reference model scored on the same transition.
        p_x0: optional cached (B, L, V) denoiser probabilities for token_array.
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        log_p: (B, L, V) log-probabilities of each token under the policy model.
        x_next: (B, L) next sequences.
        log_policy_step: (B,) log-prob of this step's unmasked tokens under the policy.
        log_pretrained_step: (B,) same quantity under the pretrained model.
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()

    # the closed-form probabilities below assume the log-linear schedule (alpha = 1 - t)
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)

    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1

    # unmasking probabilities at times t and s = t - dt
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape

    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)

    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # BUG FIX: log_p was previously undefined whenever a cached p_x0 was
        # passed in, raising a NameError at the log_p.gather(...) call below.
        # Recover log-probs from the cached probabilities (clamped to avoid -inf*0).
        log_p = p_x0.clamp_min(1e-12).log()

    assert change_prob_t.ndim == p_x0.ndim

    # noise removal: zero the MASK probability and renormalize so every masked
    # position must resolve to a real token this step
    p_x0 = p_x0.clone()
    p_x0[:, :, self.mask_index] = 0.0  # prevent remaining a mask
    p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12)  # renorm over non-MASK
    q_xs = p_x0 * (change_prob_t - change_prob_s)

    x_changed = _sample_categorical(q_xs)

    # carry-over: positions already unmasked keep their token
    copy_flag = (token_array != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed

    # score the same transition under the frozen pretrained model
    with torch.no_grad():
        # pretrained outputs log-probs over the vocab given the *parent* (masked) input
        log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)

    # log-prob of the *sampled token* at each position
    log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)

    # sum only over the sites actually sampled this step (parent was MASK,
    # child is a real token)
    assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
    changed_mask = (~copy_flag)
    unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)

    log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

    # same per-sequence quantity under the policy model
    log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)  # (B, L)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)

    return log_p, x_next, log_policy_step, log_pretrained_step
644
+
645
# first step in expansion
def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None):
    """
    First expansion step of MCTS: sample batch_size distinct children from one parent.

    Args:
        token_array: (L,) or (1, L) parent token ids (may contain MASK tokens).
        t: current timestep tensor; squeezed to 1-D.
        dt: scalar step size.
        batch_size: number of child sequences to generate.
        pretrained: frozen reference model scored on the same transition.
        p_x0: optional cached (1, L, V) denoiser probabilities for the parent.
        attn_mask: optional attention mask; defaults to all ones.

    Returns:
        log_p: (B, L, V) policy log-probabilities (repeated per child).
        x_children: (B, L) child sequences.
        log_policy_step: (B,) log-prob of this step's unmasked tokens under the policy.
        log_pretrained_step: (B,) same quantity under the pretrained model.

    Raises:
        ValueError: if token_array contains token ids outside [0, vocab_size).
    """
    torch.cuda.empty_cache()
    self.backbone.eval()
    self.noise.eval()

    # the closed-form probabilities below assume the log-linear schedule (alpha = 1 - t)
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)

    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1

    # unmasking probabilities at times t and s = t - dt
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    assert change_prob_t.ndim == 3, change_prob_t.shape

    if token_array.dim() == 1:
        token_array = token_array.unsqueeze(0)

    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)

    token_array = token_array.to(self.device)
    sigma_t = sigma_t.to(self.device)

    # ====== INPUT VALIDATION for batch_mcts_reverse_step ======
    token_min = token_array.min().item()
    token_max = token_array.max().item()
    if token_min < 0 or token_max >= self.vocab_size:
        raise ValueError(
            f"batch_mcts_reverse_step: Invalid token IDs in token_array: "
            f"min={token_min}, max={token_max}, vocab_size={self.vocab_size}"
        )

    if p_x0 is None:
        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()
    else:
        # BUG FIX: log_p was previously undefined whenever a cached p_x0 was
        # passed in, raising a NameError at the log_p.repeat(...) call below.
        # Recover log-probs from the cached probabilities (clamped to avoid -inf*0).
        log_p = p_x0.clamp_min(1e-12).log()

    assert change_prob_t.ndim == p_x0.ndim

    q_xs = p_x0 * (change_prob_t - change_prob_s)
    # probability of remaining masked at time s
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

    # repeat the parent so each copy can be unmasked into a distinct child
    token_array_expanded = token_array.repeat(batch_size, 1)

    if self.config.mcts.sampling == 0:
        x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
    else:
        x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)

    # carry-over: already-unmasked positions keep the parent token
    copy_flag = (token_array_expanded != self.mask_index)
    int_copy_flag = copy_flag.to(token_array.dtype)
    x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed

    # score the same transition under the frozen pretrained model
    with torch.no_grad():
        # pretrained outputs log-probs over the vocab given the *parent* (masked) input
        log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        # expand to match the shape of x_children
        log_pre = log_pre.repeat(batch_size, 1, 1)

    # log-prob of the *sampled token* at each position
    log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # (B, L)

    # sum only over the sites actually sampled this step (parent was MASK,
    # child is a real token)
    assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
    changed_mask = (~copy_flag)
    unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)

    log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

    # per-child log-probability under the policy model
    log_p = log_p.repeat(batch_size, 1, 1)
    log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # (B, L)
    log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)

    return log_p, x_children, log_policy_step, log_pretrained_step
739
+
740
+
741
def compute_invalid_loss(self, logits, k=None, temp=None):
    """
    Penalize logits whose greedy decoding fails the `is_peptide` validity check.

    The penalty for an invalid sequence is scaled, per position, by the softmax
    probability of the greedily chosen token, so confidently wrong predictions
    are punished harder. Valid sequences incur zero loss.

    Args:
        logits: Tensor of shape (batch_size, seq_len, vocab_size).
        k: unused; kept for interface compatibility (Gumbel-Rao sampling).
        temp: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (batch_size, seq_len) with the scaled penalties.
    """
    # greedy decode: most likely token at every position
    token_ids = logits.argmax(dim=-1).to(self.device)
    decoded = self.tokenizer.batch_decode(token_ids)

    # 1.0 for each invalid sequence, 0.0 for each valid one (not differentiable)
    invalid_flags = [0.0 if self.analyzer.is_peptide(seq) else 1.0 for seq in decoded]
    penalties = torch.tensor(invalid_flags, dtype=torch.float32, device=self.device)

    # probability the model assigned to each greedily chosen token (batch, seq_len)
    probs = torch.softmax(logits, dim=-1)
    chosen_probs = probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)

    # broadcast the per-sequence flag across positions
    return (penalties[:, None] * chosen_probs).to(self.device)
776
+
777
+ ### DIFFUSION LOSS ###
778
+
779
def sample_t(self, n, device):
    """
    Draw n diffusion timesteps for batch training, mapped into [sampling_eps, 1).

    With antithetic sampling enabled, the draws are stratified into n evenly
    spaced bins of width 1/n, reducing the variance of the loss estimate.
    """
    # uniform draws in [0, 1)
    t_raw = torch.rand(n, device=device)

    if self.config.training.antithetic_sampling:
        # stratify: place exactly one sample inside each bin [i/n, (i+1)/n)
        bin_offsets = torch.arange(n, device=device) / n
        t_raw = (t_raw / n + bin_offsets) % 1

    # affine map into [eps, 1) so t is never exactly 0 or 1
    eps = self.config.training.sampling_eps
    return (1 - eps) * t_raw + eps
796
+
797
+ """def mask_samples(self, x0, mask_prob):
798
+
799
+ # generate array of values in range [0, 1] uniformly at random
800
+ # will be used to determine which tokens are masked
801
+ mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
802
+
803
+ # select tokens to mask if the random value in mask_indices is less than mask_prob
804
+ # this will mask approximately the fraction of tokens indicated by mask_prob
805
+ zt = torch.where(mask_indices < mask_prob, self.mask_index, x0)
806
+
807
+ return zt"""
808
+
809
def q_xt(self, x, mask_prob):
    """
    Compute the noised sample xt by masking tokens of x, capping the number of
    masked positions at 75% of each sequence's non-pad length.

    Args:
        x: int tensor (batch_size, seq_len) of token ids; id 0 is treated as
            padding when measuring sequence length.
        mask_prob: float tensor broadcastable to x's shape with per-token
            masking probabilities.

    Returns:
        Tensor like x with the selected positions replaced by the MASK token id.
    """
    # non-pad length per sequence; the masking budget is 75% of it
    seq_lengths = (x != 0).sum(dim=-1, keepdim=True)
    mask_budget = (seq_lengths * 0.75).long()

    # independent Bernoulli draw per position
    candidate = torch.rand(*x.shape, device=x.device) < mask_prob

    keep = torch.zeros_like(candidate, dtype=torch.bool)
    for row in range(x.shape[0]):
        hit_positions = torch.where(candidate[row])[0]
        if len(hit_positions) > mask_budget[row]:
            # over budget: keep only the earliest candidate positions
            keep[row, hit_positions[:mask_budget[row].item()]] = True
        else:
            keep[row] = candidate[row]

    return torch.where(keep, self.tokenizer.mask_token_id, x)
838
+
839
+
840
def sample_prior(self, *batch_dims):
    """
    Return the diffusion prior: an int64 tensor of the given shape with every
    position set to the MASK token index.
    """
    return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)
845
+
846
+
847
+ ### COMPUTING LOSS ###
848
+
849
def compute_diffusion_loss(self, model_output, xt, x0, t):
    """
    Computes the diffusion loss term of the ELBO for the discrete-time case
    (how accurately the model predicts token probabilities at each timestep).

    Args:
        model_output: (batch, seq_len, vocab_size) log-probabilities per position.
        xt: corrupted (partially masked) version of x0 at timestep t.
        x0: original input sequence, (batch, seq_len) int tensor.
        t: timestep; broadcast against x0 below — assumed (batch, 1) or scalar,
            TODO confirm shape at call sites.

    Returns:
        (batch, seq_len) tensor of per-token variational-bound terms, scaled by T.
        Only positions masked in xt contribute.
    """
    # interval between adjacent discrete timesteps
    dt = 1 / self.T

    # vectorized alpha terms at times t and s = t - dt
    # (log-linear schedule: alpha = 1 - t); + zeros_like broadcasts to x0's shape
    alpha_t = 1 - t + torch.zeros_like(x0)
    alpha_s = 1 - (t - dt) + torch.zeros_like(x0)

    # log <x_theta, x0>: log-prob assigned to the true token, shape (B, L, 1)
    log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None])
    # log <x_theta, m>: log-prob of the MASK token at each position, (B, L)
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    # <x_theta, m> in probability space
    x_theta_at_m = log_x_theta_at_m.exp()

    # first term of the diffusion loss
    term_1_coef = dt / t
    term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
    term_1_log_denom = log_x_theta_at_x0

    # second term of the diffusion loss
    term_2_coef = 1 - (dt / t)
    term_2_log_numerator = term_1_log_numerator
    term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)

    L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
                   term_2_coef * (term_2_log_numerator - term_2_log_denom))

    # multiply by the <xt, m> indicator: only masked positions contribute
    L_vb = L_vb_masked * (xt == self.mask_index)

    # scale by the number of timesteps T
    return self.T * L_vb
896
+
897
def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
    """
    One training forward pass: corrupt x0 at a random timestep and score the
    model's reconstruction.

    Args:
        x0: (batch, seq_len) clean token ids.
        attn_mask: (batch, seq_len) attention mask.
        bond_mask: optional (batch, seq_len) indicator of positions that use the
            state-dependent (log-polynomial) noise schedule; 1 marks those
            positions — presumably peptide-bond tokens, per the comments below.
        mask: optional explicit corruption mask; where mask == 1 the token is
            kept, elsewhere it is replaced by MASK (overrides random masking).

    Returns:
        (batch, seq_len) per-token loss: the discrete-time diffusion loss when
        self.T > 0, otherwise the schedule-weighted NLL plus the
        invalid-sequence penalty.
    """
    # randomly sample a timestep for each sequence in the batch
    t = self.sample_t(x0.shape[0], self.device)

    # discrete-time case: snap t onto the grid {1/T, ..., 1}
    if self.T > 0:
        # scale by total timesteps T and cast to integer
        t = (t * self.T).to(torch.int)
        # scale down by T to get a multiple of 1/T
        t = t / self.T
        # add 1/T to ensure no 0 values
        t += (1 / self.T)

    # noise level and its rate at timestep t
    # log-linear schedule: sigma = -log(1-t); dsigma = 1 / (1-t)
    sigma, dsigma = self.noise(t)
    time_conditioning = sigma[:, None]

    # per-token masking probability: 1 - alpha = 1 - exp(-sigma) (= t for log-linear)
    base_mask_prob = 1 - torch.exp(-sigma[:, None])  # (batch_size, 1)

    if self.config.noise.state_dependent and (bond_mask is not None):
        # log-polynomial masking schedule for bond positions: alpha = 1 - t^w
        # bond_sigma = -log(1-t^w) for w = 3 (default)
        # bond_dsigma = -w t^(w-1) / (1-t^w)
        bond_sigma, bond_dsigma = self.bond_noise(t)
        # expand dimensions for broadcasting to (B, L)
        bond_sigma = bond_sigma[:, None]
        bond_dsigma = bond_dsigma[:, None]
        sigma = sigma[:, None]
        dsigma = dsigma[:, None]

        # masking probability for bond positions: 1 - bond_alpha = t^w
        bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
        # per-position schedule: bond positions use the bond schedule,
        # everything else the base schedule
        mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
        dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
        sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
    else:
        mask_prob = base_mask_prob.to(self.device)

    # corrupt x0: random masking, or the explicit mask when provided
    if mask is None:
        zt = self.q_xt(x0, mask_prob).to(self.device)
    else:
        zt = x0.where(mask==1, torch.full_like(x0, self.mask_index)).to(self.device)

    model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)

    # debugging
    assert not torch.isnan(model_output).any()
    assert model_output.is_cuda
    utils.print_nans(model_output, 'model_output')

    # penalty for decodings that fail the peptide-validity check
    invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device)  # (B, L)

    if self.T > 0:
        # discrete-time diffusion loss (note: invalid_loss is not applied here)
        diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
        return diffusion_loss

    # continuous-time case: -log p_theta(x0) at every position, (B, L)
    log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device)

    if self.config.noise.state_dependent and (bond_mask is not None):
        # dsigma/sigma were already expanded to (B, L) in the branch above
        return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
    else:
        return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
976
+
977
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
    """
    Run the diffusion forward pass and reduce the per-token NLL to a scalar.

    Returns:
        Loss tuple with the mean per-token NLL, the per-token NLL tensor, and
        the attention mask used for the reduction.
    """
    per_token = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)

    # mask out padding before reduction
    nlls = per_token * attn_mask
    token_count = attn_mask.sum()

    # batch-level and per-token negative log-likelihood
    mean_nll = nlls.sum() / token_count

    return Loss(loss=mean_nll.to(self.device),
                nlls=nlls.to(self.device),
                attn_mask=attn_mask.to(self.device))
992
+
993
def _compute_loss(self, batch, prefix, bond_mask=None):
    """
    Compute the loss for one batch and update the split-specific metrics.

    Args:
        batch: dict with 'input_ids', 'attention_mask', and optionally
            'mask' and 'bond_mask' tensors.
        prefix: one of 'train', 'val', 'test' — selects the metric collection.
        bond_mask: optional fallback bond mask, used only when the batch does
            not carry its own 'bond_mask'.

    Returns:
        Scalar mean per-token loss.

    Raises:
        ValueError: for an unknown prefix.
    """
    attn_mask = batch['attention_mask'].to(self.device)

    mask = batch['mask'].to(self.device) if 'mask' in batch else None

    # BUG FIX: the bond_mask argument was previously overwritten unconditionally
    # (reset to None when the batch lacked 'bond_mask'), so callers passing
    # bond_mask explicitly had it silently ignored. The batch value still takes
    # precedence; the argument now serves as a fallback.
    if 'bond_mask' in batch:
        bond_mask = batch['bond_mask'].to(self.device)
    elif bond_mask is not None:
        bond_mask = bond_mask.to(self.device)

    losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
    loss = losses.loss

    if prefix == 'train':
        self.train_metrics.update(
            losses.nlls.to(self.device),
            losses.attn_mask.to(self.device)
        )
        metrics = self.train_metrics
    elif prefix == 'val':
        self.valid_metrics.update(
            losses.nlls.to(self.device),
            losses.attn_mask.to(self.device)
        )
        metrics = self.valid_metrics
    elif prefix == 'test':
        self.test_metrics.update(losses.nlls, losses.attn_mask)
        metrics = self.test_metrics
    else:
        raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)

    return loss
1034
+
1035
+
1036
+ ### SAMPLING ###
1037
+
1038
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
    """
    Generate sequences by reverse diffusion starting from fully masked inputs.

    Args:
        num_samples: number of sequences to generate. NOTE(review): there is no
            config fallback here — a None value reaches sample_prior unchanged
            and would fail; callers must supply it.
        seq_length: sequence length; defaults to config.sampling.seq_length.
        sample_steps: number of reverse steps; defaults to config.sampling.steps.
        eps: smallest timestep, keeps t away from exactly 0.

    Returns:
        (num_samples, seq_length) tensor of generated token ids.
    """
    # get number of timesteps
    if sample_steps is None:
        sample_steps = self.config.sampling.steps

    if seq_length is None:
        seq_length = self.config.sampling.seq_length

    # start from the fully masked prior
    z = self.sample_prior(num_samples, seq_length).to(self.device)

    # decreasing timesteps from 1 down to eps
    timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)

    # interval between consecutive timesteps
    dt = (1 - eps) / sample_steps

    for i in range(sample_steps):
        t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)

        # NOTE(review): single_reverse_step is commented out in this file —
        # confirm it is defined elsewhere before calling this method.
        z = self.single_reverse_step(z, t, dt)

    return z
1061
+
1062
+
1063
+ ### SAMPLING STEP ###
1064
+ """
1065
+ def single_reverse_step(self, zt, t, dt, attn_mask=None):
1066
+ # get sigma values that determine masking prob
1067
+ sigma_t, _ = self.noise(t)
1068
+ sigma_s, _ = self.noise(t - dt)
1069
+
1070
+ # reshape sigmas
1071
+ if sigma_t.ndim > 1:
1072
+ sigma_t = sigma_t.squeeze(-1)
1073
+ if sigma_s.ndim > 1:
1074
+ sigma_s = sigma_s.squeeze(-1)
1075
+ assert sigma_t.ndim == 1, sigma_t.shape
1076
+ assert sigma_s.ndim == 1, sigma_s.shape
1077
+
1078
+ # compute masking probabilities for each timestep
1079
+ change_prob_t = 1 - torch.exp(-sigma_t)
1080
+ change_prob_s = 1 - torch.exp(-sigma_s)
1081
+
1082
+ # expand dimensions
1083
+ change_prob_t = change_prob_t[:, None, None]
1084
+ change_prob_s = change_prob_s[:, None, None]
1085
+
1086
+ # get prodiction model that outputs token probabilities
1087
+ log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
1088
+
1089
+ # check dimensions match
1090
+ assert change_prob_t.ndim == log_p_x0.ndim
1091
+
1092
+ # compute reverse diffusion probability of being unmasked at timestep s
1093
+ # (sigma_s - sigma_t)*x_theta
1094
+ q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
1095
+
1096
+ # compute reverse diffusion probability of remaining masked at timestep s
1097
+ # (1 - sigma_s)*m
1098
+ q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0]
1099
+
1100
+ # sample sequence at timestep s from categorical distribution of q_zs
1101
+ z_changed = _sample_categorical(q_zs)
1102
+
1103
+ copy_flag = (zt != self.mask_index).to(zt.dtype)
1104
+ return (copy_flag * zt) + ((1 - copy_flag) * z_changed)"""
1105
+
1106
def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
    """
    One reverse-diffusion step from time t to s = t - dt, reusing a cached
    denoiser output p_x0 when the caller provides one.

    Args:
        x: (batch, seq_len) current token ids (may contain MASK tokens).
        t: current timesteps; squeezed to 1-D.
        dt: scalar step size.
        p_x0: optional cached (batch, seq_len, vocab) denoiser probabilities.
        attn_mask: optional attention mask for the forward pass.

    Returns:
        (p_x0, x_next): the (possibly newly computed) denoiser probabilities and
        the partially unmasked next sequence.
    """
    # the closed-form probabilities below assume the log-linear schedule (alpha = 1 - t)
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)

    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1

    # unmasking probabilities at times t and s = t - dt
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]

    assert change_prob_t.ndim == 3, change_prob_t.shape

    if p_x0 is None:
        p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()

    assert change_prob_t.ndim == p_x0.ndim

    # reverse kernel: unmask into token y with probability p_x0(y) * (t - s) ...
    q_xs = p_x0 * (change_prob_t - change_prob_s)

    # ... and remain masked with probability s
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

    x_changed = _sample_categorical(q_xs)

    # carry-over: positions already unmasked keep their token
    copy_flag = (x != self.mask_index).to(x.dtype)

    return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
1134
+
1135
# first step in expansion
def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
    """
    Generate batch_size different samples from the same starting point for the
    first expansion step of MCTS.

    Args:
        token_array: (L,) or (1, L) parent token ids.
        t: current timestep tensor; squeezed to 1-D.
        dt: scalar step size.
        batch_size: number of children to sample from the parent.
        p_x0: optional cached (1, L, vocab) denoiser probabilities.
        attn_mask: optional attention mask for the forward pass; defaults to
            all ones.

    Returns:
        (p_x0, x_children): the denoiser probabilities for the parent and the
        (batch_size, L) tensor of child sequences.
    """
    # the closed-form probabilities below assume the log-linear schedule (alpha = 1 - t)
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)

    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1

    # unmasking probabilities at times t and s = t - dt
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]

    assert change_prob_t.ndim == 3, change_prob_t.shape

    if token_array.dim() == 1:
        token_array = token_array.unsqueeze(0)

    # BUG FIX: the attn_mask argument was previously overwritten unconditionally,
    # so a caller-provided mask was ignored; now it is respected (consistent
    # with batch_mcts_reverse_step).
    if attn_mask is None:
        attn_mask = torch.ones_like(token_array).to(self.device)

    if p_x0 is None:
        p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()

    assert change_prob_t.ndim == p_x0.ndim

    q_xs = p_x0 * (change_prob_t - change_prob_s)

    # probability of remaining masked at time s
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

    # repeat the parent so each copy can be unmasked into a distinct child
    token_array = token_array.repeat(batch_size, 1)

    if self.config.mcts.sampling == 0:
        x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
    else:
        x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)

    # carry-over: already-unmasked positions keep the parent token
    copy_flag = (token_array != self.mask_index).to(token_array.dtype)

    return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
1181
+
1182
+ def _process_sigma(self, sigma):
1183
+ if sigma.ndim > 1:
1184
+ sigma = sigma.squeeze(-1)
1185
+ if not self.time_conditioning:
1186
+ sigma = torch.zeros_like(sigma)
1187
+ assert sigma.ndim == 1, sigma.shape
1188
+ return sigma
1189
+
1190
def forward(self, zt, attn_mask, sigma):
    """
    Predict token log-probabilities for zt at noise level sigma.

    Args:
        zt: (batch, seq_len) int tensor of token ids (may contain MASK tokens).
        attn_mask: attention mask matching zt.
        sigma: noise level(s); squeezed (and zeroed when time conditioning is
            off) by _process_sigma.

    Returns:
        (batch, seq_len, vocab_size) log-probabilities after SUBS
        parameterization.

    Raises:
        ValueError: if zt contains out-of-range token ids, or if the sequence
            length exceeds the backbone's max_position_embeddings.
    """
    sigma = self._process_sigma(sigma)

    # ====== INPUT VALIDATION (CPU-side) ======
    # Check 1: Token IDs must be in valid range [0, vocab_size - 1]
    zt_min = zt.min().item()
    zt_max = zt.max().item()
    if zt_min < 0 or zt_max >= self.vocab_size:
        raise ValueError(
            f"Invalid token IDs in zt: min={zt_min}, max={zt_max}, "
            f"vocab_size={self.vocab_size}. Token IDs must be in [0, {self.vocab_size-1}]"
        )

    # Check 2: Sequence length must not exceed model's max_position_embeddings
    seq_len = zt.shape[1]
    max_pos = getattr(self.backbone.model.config, 'max_position_embeddings', 512)
    if seq_len > max_pos:
        raise ValueError(
            f"Sequence length {seq_len} exceeds model's max_position_embeddings {max_pos}. "
            f"Input shape: {zt.shape}"
        )

    # NOTE(review): autocast with dtype=torch.float32 effectively disables mixed
    # precision, and torch.cuda.amp.autocast is deprecated in favor of
    # torch.amp.autocast('cuda', ...) — confirm intent before changing.
    with torch.cuda.amp.autocast(dtype=torch.float32):
        logits = self.backbone.forward(input_ids=zt, attn_mask=attn_mask).to(self.device)

    return self.subs_parameterization(logits, zt)
1219
+
1220
def subs_parameterization(self, logits, zt):
    """
    Apply the SUBS parameterization to raw logits:
      - zero masking probabilities: the MASK token gets -inf log-probability,
        so reverse diffusion never predicts MASK;
      - carry-over unmasking: positions already unmasked in zt keep their token
        with probability 1 (log-prob 0, -inf for every other token).

    Args:
        logits: (batch, seq_len, vocab_size) raw logits; modified in place and
            returned as normalized log-probabilities.
        zt: (batch, seq_len) partially unmasked sequence at the current step.

    Returns:
        (batch, seq_len, vocab_size) log-probabilities on self.device.
    """
    # forbid predicting the MASK token
    logits[:, :, self.mask_index] += self.neg_infinity

    # normalize to log-probabilities
    logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)

    # positions whose token is already known
    unmasked_indices = (zt != self.mask_index).to(self.device)

    # carry-over: everything at an unmasked position becomes -inf ...
    logits[unmasked_indices] = self.neg_infinity

    # Clip token indices to the valid vocab range to prevent index-out-of-bounds
    # (can happen with variable-length sequences or corrupted tokens).
    # NOTE: the previously computed-but-unused batch_idx/seq_idx/tokens locals
    # were removed; behavior is unchanged.
    tokens_for_indexing = zt[unmasked_indices]
    valid_token_mask = tokens_for_indexing < logits.shape[-1]

    if not valid_token_mask.all():
        import logging
        logger = logging.getLogger(__name__)
        invalid_count = (~valid_token_mask).sum().item()
        max_invalid_token = tokens_for_indexing[~valid_token_mask].max().item()
        logger.warning(f"Found {invalid_count} invalid token indices (max={max_invalid_token}, vocab_size={logits.shape[-1]}). Clipping to valid range.")
        tokens_for_indexing = torch.clamp(tokens_for_indexing, 0, logits.shape[-1] - 1)

    # ... except the observed token itself, which gets log-prob 0
    logits[unmasked_indices, tokens_for_indexing] = 0
    # return logits with SUBS parameterization
    return logits.to(self.device)
1271
+
1272
+ """SAMPLING"""
1273
@torch.no_grad()
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
    """
    Generate samples by iterative reverse diffusion.

    Args:
        num_steps: number of reverse steps; defaults to config.sampling.steps.
        eps: smallest timestep (keeps t away from exactly 0).
        x_input: optional dict with 'input_ids'/'attention_mask' to continue
            denoising from a partially masked input instead of the full prior.

    Returns:
        (batch, seq_len) tensor of sampled token ids on self.device.
    """
    batch_size_per_gpu = self.config.eval.perplexity_batch_size

    if num_steps is None:
        num_steps = self.config.sampling.steps

    if x_input is not None:
        x = x_input['input_ids'].to(self.device)
        attn_mask = x_input['attention_mask'].to(self.device)
    else:
        # start from the fully masked prior
        x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
        attn_mask = torch.ones_like(x).to(self.device)

    # decreasing timesteps from 1 down to eps
    timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None
    generation_history = []  # used to track which tokens are unmasked

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
        if self.sampler == 'ddpm':
            # NOTE(review): single_reverse_step is commented out in this file —
            # confirm it is defined elsewhere before using this sampler.
            x = self.single_reverse_step(x, t, dt).to(self.device)

        elif self.sampler == 'ddpm_cache':
            p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                # sequence changed (or time conditioning is active): the cached
                # denoiser output is stale, so disable caching for the next step
                p_x0_cache = None
            x = x_next.to(self.device)
        else:
            x = self._analytic_update(x, t, dt, attn_mask).to(self.device)

    if self.config.sampling.noise_removal:
        # final step: force any remaining MASK tokens to resolve
        t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
        if self.sampler == 'analytic':
            x = self._denoiser_update(x, t).to(self.device)
        else:
            time_conditioning = self.noise(t)[0].to(self.device)
            x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
    return x.to(self.device)
1320
+
1321
+
1322
def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Switch to eval mode, draw samples, restore train mode, and return them."""
    self.backbone.eval()
    self.noise.eval()
    generated = self._sample(num_steps=num_steps, eps=eps)
    self.backbone.train()
    self.noise.train()
    return generated
1330
+
1331
def get_score(self, zt, sigma, attn_mask=None):
    """
    Compute probability-space score ratios p_t(y)/p_t(x) for the analytic sampler.

    Args:
        zt: (batch, seq_len) current token ids.
        sigma: noise level tensor.
        attn_mask: optional attention mask for the forward pass.

    Returns:
        (batch, seq_len, vocab_size) tensor of exp(log-score) values.
    """
    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)

    # case 1: x = masked
    # (i) y = unmasked
    # log score(x, t) = log p_\theta(x)|_y + log k
    # where k = exp(- sigma) / (1 - exp(- sigma))
    # (ii) y = masked
    # log score(x, t) = 0

    # case 2: x = unmasked
    # (i) y != masked, y != x
    # log score(x_i, t) = - inf
    # (ii) y = x
    # log score(x_i, t) = 0
    # (iii) y = masked token
    # log score(x_i, t) = - log k
    # where k = exp(- sigma) / (1 - exp(- sigma))

    model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)

    # log k = -log(exp(sigma) - 1), i.e. k = exp(-sigma)/(1 - exp(-sigma))
    log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1

    # case 1: masked positions — score p_theta(y) * k for real y, 1 for MASK
    masked_score = model_output + log_k[:, None, None]
    masked_score[:, :, self.mask_index] = 0

    # case 2: unmasked positions — only y = x (score 1) and y = MASK (score 1/k)
    unmasked_score = self.neg_infinity * torch.ones_like(model_output)
    unmasked_score = torch.scatter(
        unmasked_score, -1,
        zt[..., None],
        torch.zeros_like(unmasked_score[..., :1]))

    unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt))

    # select the applicable case per position
    masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None]

    model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))

    return model_output.exp()
1373
+
1374
+ def _staggered_score(self, score, dsigma):
1375
+ score = score.clone()
1376
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
1377
+ score *= dsigma.exp()[:, None]
1378
+ score[..., self.mask_index] += extra_const
1379
+ return score
1380
+
1381
def _analytic_update(self, x, t, step_size, attn_mask=None):
    """One analytic reverse-diffusion step from time ``t`` to ``t - step_size``.

    Args:
        x: Current token ids.
        t: Current diffusion time.
        step_size: Time decrement for this update.
        attn_mask: Optional attention mask forwarded to the score model.

    Returns:
        Token ids sampled from the analytic transition distribution.
    """
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    # BUG FIX: was `self.get_score(x, attn_mask, curr_sigma)`, which passed the
    # attention mask into get_score's `sigma` parameter and the noise level
    # into `attn_mask`. get_score's signature is (zt, sigma, attn_mask=None).
    score = self.get_score(x, curr_sigma, attn_mask=attn_mask)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs)
1389
+
1390
def _denoiser_update(self, x, t):
    """Final denoising step: sample tokens while forbidding the mask token."""
    sigma, _ = self.noise(t)
    raw_score = self.get_score(x, sigma)
    adjusted = self._staggered_score(raw_score, sigma)
    transition_probs = adjusted * self._transp_transition(x, sigma)
    # Zero probability for re-emitting the mask token at the final step.
    transition_probs[..., self.mask_index] = 0
    return _sample_categorical(transition_probs)
1398
+
1399
def _transp_transition(self, i, sigma):
    """Transposed transition rows for tokens ``i`` at noise level ``sigma``.

    Each row is exp(-sigma) times the one-hot encoding of the token; for
    positions holding the mask token, the remaining 1 - exp(-sigma) mass is
    added uniformly across the vocabulary.
    """
    # Broadcast sigma up to the rank of i[..., None] so it lines up with the
    # one-hot encoding below.
    sigma = unsqueeze(sigma, reference=i[..., None])
    edge = torch.exp(-sigma) * F.one_hot(
        i, num_classes=self.vocab_size)
    edge += torch.where(i == self.mask_index,
                        1 - torch.exp(-sigma).squeeze(-1),
                        0)[..., None]
    return edge
1407
+
1408
+
1409
+ """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
1410
+
1411
def on_train_epoch_start(self):
    """Free cached GPU memory and put all trainable modules in train mode."""
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.train()
1415
+
1416
+
1417
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and log loss plus throughput.

    SMILES vocabularies carry a per-token bond mask that the loss needs;
    other vocabularies use the plain loss path.
    """
    tic = time.time()

    uses_bond_mask = self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles'
    if uses_bond_mask:
        loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='train')

    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    # Tokens processed per wall-clock second for this step.
    total_tokens = batch['input_ids'].numel()
    self.log(name='trainer/throughput',
             value=total_tokens / (time.time() - tic),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    return loss
1444
+
1445
+
1446
def on_load_checkpoint(self, checkpoint):
    """Record completed epoch/batch counts so a resumed run can fast-forward.

    Reads Lightning's fit-loop progress counters out of the checkpoint dict.
    """
    fit_loop_state = checkpoint['loops']['fit_loop']
    self.fast_forward_epochs = fit_loop_state['epoch_progress']['current']['completed']
    self.fast_forward_batches = fit_loop_state['epoch_loop.batch_progress']['current']['completed']
1449
+
1450
+ ### VALIDATION ###
1451
def on_validation_epoch_start(self):
    """Reclaim memory, switch to eval mode, and verify the NLL metric was reset."""
    gc.collect()
    torch.cuda.empty_cache()
    for module in (self.backbone, self.noise):
        module.eval()
    # The validation NLL accumulator must start from a clean slate.
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0
1458
+
1459
def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch."""
    # SMILES vocabularies require the per-token bond mask in the loss.
    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='val')

    self.log(name='trainer/val_loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             sync_dist=True)
    return loss
1472
+
1473
def on_validation_epoch_end(self):
    """Release Python garbage and cached CUDA blocks after validation."""
    gc.collect()
    torch.cuda.empty_cache()
1476
+
1477
+ ### OPTIMIZATION ###
1478
+
1479
def optimizer_step(self, *args, **kwargs):
    """Run the standard optimizer step, then aggressively reclaim memory."""
    super().optimizer_step(*args, **kwargs)

    # Free Python garbage and cached CUDA blocks after every update to keep
    # peak memory down during long fine-tuning runs.
    gc.collect()
    torch.cuda.empty_cache()
1484
+
1485
def configure_optimizers(self):
    """Build the AdamW optimizer and a per-step cosine-warmup LR schedule.

    Returns the Lightning-style ([optimizers], [scheduler dicts]) pair.
    """
    optim_cfg = self.config.optim
    trainable_params = itertools.chain(self.backbone.parameters(), self.noise.parameters())
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
        eps=optim_cfg.eps,
        weight_decay=optim_cfg.weight_decay
    )

    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(optimizer,
                             warmup_steps=self.config.lr_scheduler.num_warmup_steps,
                             total_steps=self.total_steps)

    return [optimizer], [{
        'scheduler': scheduler,
        'interval': 'step',
        'frequency': 1,
        'monitor': 'val/loss',
        'name': 'trainer/lr'
    }]
1508
+
1509
@torch.no_grad()
def compute_masked_perplexity(self, generated_ids, input_ids):
    """Compute pseudo-perplexity of generated sequences at masked positions.

    For every generated sequence, the backbone scores the (partially masked)
    ``input_ids``; cross-entropy is accumulated only where ``input_ids`` holds
    the mask token, using the generated ids as targets. Normalization divides
    by the count of non-padding tokens (BOS/EOS included).

    Args:
        generated_ids: Iterable of token-id sequences produced by sampling.
        input_ids: Masked template token ids shared by all sequences.

    Returns:
        The pseudo-perplexity as a Python float. Also updates
        ``self.gen_ppl_metric``.
    """
    total_nll = 0.0
    total_tokens = 0

    input_ids = torch.tensor(input_ids).to(self.device)
    # The attention mask is identical for every sequence; hoist it out of the loop.
    attn_mask = torch.ones_like(input_ids).to(self.device)

    for sequence in generated_ids:
        gt_ids = torch.tensor(sequence).to(self.device)

        # FIX: the original branched on self.config.mode, but every branch ran
        # this identical forward call — and any other mode left `outputs`
        # unbound (NameError). A single unconditional call is equivalent.
        outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)

        # Flatten to (batch * seq_len, vocab_size) logits vs flat targets.
        logits = outputs.view(-1, outputs.size(-1))
        gt_ids = gt_ids.view(-1)

        # Only masked positions contribute; everywhere else gets the
        # cross_entropy ignore_index of -100.
        loss = F.cross_entropy(
            logits,
            gt_ids.where(input_ids == self.mask_index, torch.full_like(gt_ids, -100)).view(-1),
            reduction='sum')

        total_nll += loss.item()
        # Count all non-padding tokens (BOS/EOS included).
        total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item()

    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    self.gen_ppl_metric.update(pseudo_perplexity)

    return pseudo_perplexity.item()
1568
+
1569
+
1570
def unsqueeze(x, reference):
    """Append trailing singleton dims to ``x`` until it matches ``reference``'s rank."""
    trailing = [1] * (len(reference.shape) - len(x.shape))
    return x.view(*x.shape, *trailing)
1572
+
1573
class CosineWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay toward ``eta_ratio * base_lr``."""

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Ratio of the minimum (fully decayed) LR to the maximum LR.
        self.eta_ratio = eta_ratio
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        # Linear ramp from 0 to base_lr over the warmup window.
        if step < self.warmup_steps:
            return [lr * step / self.warmup_steps for lr in self.base_lrs]

        # Cosine anneal from base_lr down to eta_ratio * base_lr.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_term = 0.5 * (1 + np.cos(np.pi * progress))
        scale = (1 - self.eta_ratio) * cosine_term + self.eta_ratio
        return [scale * lr for lr in self.base_lrs]
distributed_utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Minimal distributed training utilities."""
2
+ import os
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Initialize the distributed process group; no-op for single-process runs."""
    if world_size <= 1:
        return
    # Respect variables already set by a launcher; fall back to single-node defaults.
    for var, default in (("MASTER_ADDR", "localhost"), ("MASTER_PORT", "29500")):
        os.environ.setdefault(var, default)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
16
+
17
+
18
def cleanup_distributed() -> None:
    """Tear down the distributed process group if one is active."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
22
+
23
+
24
def is_main_process() -> bool:
    """Return True on rank 0, or always when not running distributed."""
    return (not dist.is_initialized()) or dist.get_rank() == 0
env.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: td3b
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ dependencies:
7
+ - python=3.10
8
+ - pip
9
+ - pytorch
10
+ - torchvision
11
+ - pytorch-cuda=12.1
12
+ - rdkit
13
+ - numpy
14
+ - pandas
15
+ - scikit-learn
16
+ - jupyterlab
17
+ - matplotlib-base
18
+ - seaborn
19
+ - tqdm
20
+ - pyyaml
21
+ - pip:
22
+ - pytorch-lightning==2.5.5
23
+ - lightning==2.5.5
24
+ - fair-esm==2.0.0
25
+ - transformers==4.56.2
26
+ - SmilesPE==0.0.3
27
+ - scipy==1.13.1
28
+ - wandb==0.22.0
29
+ - hydra-core==1.3.2
30
+ - hydra-submitit-launcher==1.2.0
31
+ - pathos==0.3.4
32
+ - matplotlib==3.10.1
33
+ - pandas==2.2.2
34
+ - seaborn==0.13.2
35
+ - timm==1.0.20
36
+ - xgboost==3.0.5
37
+ - loguru==0.7.3
finetune_multi_target.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Target TD3B Fine-Tuning Script
3
+
4
+ Trains TD3B on multiple protein targets with random sampling strategy.
5
+ Uses the GPCR directional oracle for direction-aware gating.
6
+
7
+ Architecture: Transition-Directed Discrete Diffusion for Binders (TD3B)
8
+ Training: Random K-target sampling + MCTS-guided trajectory optimization + contrastive learning
9
+
10
+ Key Features:
11
+ - Random K targets sampled per MCTS round
12
+ - Small-batch training to prevent OOM
13
+ - Periodic validation on held-out targets
14
+ - Checkpoint saving with validation metrics
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import argparse
20
+ import logging
21
+ import warnings
22
+ from typing import List, Tuple, Dict, Optional
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import numpy as np
29
+ import pandas as pd
30
+ import wandb
31
+ from tqdm import tqdm
32
+
33
+ # Add project root to path
34
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
35
+
36
+ from diffusion import Diffusion
37
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
38
+ from utils.app import PeptideAnalyzer
39
+ from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
40
+ from td3b.data_utils import peptide_seq_to_smiles, smiles_token_length
41
+
42
+ # TD3B imports
43
+ from td3b.td3b_losses import TD3BTotalLoss
44
+ from td3b.td3b_finetune import (
45
+ extract_embeddings_from_mdlm,
46
+ add_td3b_sampling_to_model
47
+ )
48
+ from td3b.direction_oracle import DirectionalOracle
49
+
50
+ # Import shared configuration classes
51
+ from configs.finetune_config import (
52
+ RoFormerConfig,
53
+ NoiseConfig,
54
+ TrainingConfig,
55
+ SamplingConfig,
56
+ EvalConfig,
57
+ OptimConfig,
58
+ MCTSConfig,
59
+ DiffusionConfig
60
+ )
61
+
62
+ # Import shared utilities
63
+ from finetune_utils import (
64
+ load_tokenizer,
65
+ initialize_device,
66
+ create_output_directory,
67
+ save_model,
68
+ setup_wandb,
69
+ cleanup_wandb,
70
+ create_mcts_instance,
71
+ create_reward_function,
72
+ )
73
+
74
+ # Configure logging
75
+ logging.basicConfig(
76
+ level=logging.INFO,
77
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
78
+ )
79
+ logger = logging.getLogger(__name__)
80
+
81
+ # Suppress warnings
82
+ warnings.filterwarnings('ignore', category=FutureWarning)
83
+ warnings.filterwarnings('ignore', category=UserWarning)
84
+
85
+ # Constants
86
# Visual separator used to delimit sections in log output.
SEPARATOR_LINE = "=" * 80
# Small numerical constant (e.g. sampling epsilon lower bound).
eps = 1e-5
88
+
89
class TargetDataset:
    """Dataset handler for multi-target training.

    Groups a (target, ligand, label) CSV by target protein and precomputes,
    per target, the median binder token length for each direction so that
    generation can match the lengths seen in the data.
    """

    def __init__(self, csv_path: str, tokenizer: Optional[SMILES_SPE_Tokenizer] = None):
        """
        Load target dataset from CSV.

        Args:
            csv_path: Path to CSV file with columns:
                - Target_Sequence: Protein target sequence
                - Ligand_Sequence: Binder sequence (for length reference)
                - label: 'agonist' or 'antagonist'
            tokenizer: Tokenizer used to compute SMILES token length
        """
        self.df = pd.read_csv(csv_path)
        logger.info(f"Loaded {len(self.df)} samples from {csv_path}")
        self.tokenizer = tokenizer

        # Group by target: one summary dict per unique protein sequence.
        self.targets = {}
        for target_seq in self.df['Target_Sequence'].unique():
            target_df = self.df[self.df['Target_Sequence'] == target_seq]

            # Get binder lengths for each direction
            agonist_binders = target_df[target_df['label'] == 'agonist']['Ligand_Sequence'].tolist()
            antagonist_binders = target_df[target_df['label'] == 'antagonist']['Ligand_Sequence'].tolist()

            # Store actual sequence lengths (tokenized SMILES lengths).
            agonist_lengths = [self._binder_length(seq) for seq in agonist_binders] if agonist_binders else []
            antagonist_lengths = [self._binder_length(seq) for seq in antagonist_binders] if antagonist_binders else []

            # Use the median length for each direction so generation mirrors
            # the lengths of the provided binders.
            if agonist_lengths:
                agonist_target_length = int(np.median(agonist_lengths))
            else:
                # Default to antagonist length if no agonist, or 50 if neither
                agonist_target_length = int(np.median(antagonist_lengths)) if antagonist_lengths else 50

            if antagonist_lengths:
                antagonist_target_length = int(np.median(antagonist_lengths))
            else:
                # Default to agonist length if no antagonist, or 50 if neither
                antagonist_target_length = int(np.median(agonist_lengths)) if agonist_lengths else 50

            self.targets[target_seq] = {
                'sequence': target_seq,
                'agonist_length': agonist_target_length,  # Target length for agonist generation
                'antagonist_length': antagonist_target_length,  # Target length for antagonist generation
                'agonist_count': len(agonist_binders),
                'antagonist_count': len(antagonist_binders)
            }

        logger.info(f"Found {len(self.targets)} unique targets")

    def _binder_length(self, binder_seq: str) -> int:
        """Length of the binder in SMILES tokens (character length if no tokenizer)."""
        smiles = peptide_seq_to_smiles(binder_seq)
        if self.tokenizer is None:
            return len(smiles)
        return smiles_token_length(smiles, self.tokenizer)

    def sample_targets(self, k: int, random_state: Optional[int] = None) -> List[str]:
        """
        Randomly sample K targets (without replacement).

        Args:
            k: Number of targets to sample (clamped to the number available)
            random_state: Random seed for reproducibility
                NOTE(review): seeds NumPy's global RNG as a side effect.

        Returns:
            List of target sequences
        """
        if random_state is not None:
            np.random.seed(random_state)

        target_seqs = list(self.targets.keys())
        k = min(k, len(target_seqs))
        return np.random.choice(target_seqs, size=k, replace=False).tolist()

    def get_target_info(self, target_seq: str) -> Dict:
        """Get information for a specific target."""
        return self.targets[target_seq]

    def get_sequence_length(self, target_seq: str, direction: str) -> int:
        """
        Get the target sequence length for generation.

        Args:
            target_seq: Target protein sequence
            direction: 'agonist' (also accepts 1.0 or '+1'); anything else
                is treated as antagonist
        Returns:
            Target binder sequence length
        """
        target_info = self.targets[target_seq]
        if direction == 'agonist' or direction == 1.0 or direction == '+1':
            return target_info['agonist_length']
        else:  # antagonist
            return target_info['antagonist_length']

    def get_all_targets(self) -> List[str]:
        """Get all target sequences."""
        return list(self.targets.keys())
192
+
193
+
194
def run_validation(
    policy_model: Diffusion,
    multi_target_affinity: MultiTargetBindingAffinity,
    directional_oracle: DirectionalOracle,
    tokenizer: SMILES_SPE_Tokenizer,
    val_dataset: TargetDataset,
    args: argparse.Namespace,
    epoch: int,
    device: torch.device,
    protein_token_cache: Optional[Dict[str, torch.Tensor]] = None
) -> Dict:
    """
    Run validation on all targets in the validation dataset.

    For every target, generates binders in both directions (agonist d*=+1,
    antagonist d*=-1), scores them with the binding-affinity gate and the
    direction oracle, aggregates the metrics, and writes per-sequence
    results to a CSV under args.save_path.

    Args:
        policy_model: Trained diffusion model
        multi_target_affinity: Multi-target binding affinity predictor
        directional_oracle: Directional oracle
        tokenizer: Tokenizer
        val_dataset: Validation dataset
        args: Training arguments (args.seq_length is temporarily mutated per
            target/direction and restored afterwards)
        epoch: Current epoch
        device: Device
        protein_token_cache: Optional cache of encoded target proteins;
            mutated in place as new targets are encoded

    Returns:
        Dictionary with validation metrics
    """
    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info(f"Running validation at epoch {epoch}")
    logger.info(f"{SEPARATOR_LINE}")

    policy_model.eval()

    # Per-sequence accumulators; all stay aligned at length == total valid sequences.
    all_sequences = []
    all_affinities = []
    all_gated_rewards = []
    all_directions = []
    all_target_directions = []  # d* for each sequence
    all_valid_fractions = []
    all_valid_fractions_per_sample = []
    all_target_names = []

    val_targets = val_dataset.get_all_targets()

    if protein_token_cache is None:
        protein_token_cache = {}

    with torch.no_grad():
        for target_seq in tqdm(val_targets, desc="Validating targets"):
            target_info = val_dataset.get_target_info(target_seq)
            # Encode the target protein once and cache it across epochs.
            target_protein_tokens = protein_token_cache.get(target_seq)
            if target_protein_tokens is None:
                target_protein_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_protein_tokens

            # Generate for both agonist and antagonist
            for direction_name, d_star in [('agonist', 1.0), ('antagonist', -1.0)]:
                # Get the target sequence length for this direction
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)

                # Temporarily set args.seq_length for this generation
                original_seq_length = args.seq_length
                args.seq_length = target_length

                # Create target-specific affinity predictor for this target
                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)

                # Create reward model for this target+direction
                reward_model = create_reward_function(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_protein_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    use_confidence_weighting=True,
                    temperature=args.sigmoid_temperature
                )

                # Sample sequences with the correct length
                x_eval, eval_metrics = policy_model.sample_finetuned_td3b(
                    args,
                    reward_model,
                    batch_size=args.val_samples_per_target,
                    dataframe=False
                )

                # Restore original seq_length
                args.seq_length = original_seq_length

                # Decode sequences
                sequences = tokenizer.batch_decode(x_eval)

                # Get metrics
                affinities = eval_metrics.get('affinity', [])
                gated_rewards = eval_metrics.get('gated_reward', [])
                directions = eval_metrics.get('direction_predictions', [])
                valid_fraction = eval_metrics.get('valid_fraction', 0.0)

                # CRITICAL FIX: Metrics are only computed for valid sequences
                # So we should extend based on the length of metrics arrays, not all sequences
                num_valid = len(affinities)  # Number of valid sequences with metrics

                # Filter to only valid sequences (metrics are only for valid ones)
                from utils.app import PeptideAnalyzer
                analyzer = PeptideAnalyzer()
                valid_sequences = [seq for seq in sequences if analyzer.is_peptide(seq)][:num_valid]

                # Store (all arrays must have the same length = num_valid)
                all_sequences.extend(valid_sequences)  # Only valid sequences
                all_affinities.extend(affinities)
                all_gated_rewards.extend(gated_rewards)
                all_directions.extend(directions)
                all_target_directions.extend([d_star] * num_valid)
                all_valid_fractions.append(valid_fraction)
                all_valid_fractions_per_sample.extend([valid_fraction] * num_valid)
                all_target_names.extend([target_seq[:20]] * num_valid)

    # Compute validation metrics
    all_affinities = np.array(all_affinities)
    all_gated_rewards = np.array(all_gated_rewards)
    all_directions = np.array(all_directions)
    all_target_directions = np.array(all_target_directions)

    # Direction is "correct" when the oracle output agrees with the sign of d*.
    if all_directions.size == 0:
        direction_correct = np.array([], dtype=np.float32)
    else:
        direction_correct = np.where(
            all_target_directions > 0,
            all_directions >= 0.5,
            all_directions < 0.5
        ).astype(np.float32)

    # Consistency rewards: d* × (f_φ - 0.5)
    consistency_rewards = all_target_directions * (all_directions - 0.5)  # range from -1 to 1.
    success_rates = direction_correct * np.array(all_valid_fractions_per_sample, dtype=np.float32)

    # Separate by direction
    agonist_mask = all_target_directions == 1.0
    antagonist_mask = all_target_directions == -1.0

    consistency_agonist = consistency_rewards[agonist_mask]
    consistency_antagonist = consistency_rewards[antagonist_mask]

    val_metrics = {
        'affinity_mean': np.mean(all_affinities),
        'affinity_std': np.std(all_affinities),
        'gated_reward_mean': np.mean(all_gated_rewards),
        'gated_reward_std': np.std(all_gated_rewards),
        'direction_oracle_mean': np.mean(all_directions),
        'direction_oracle_std': np.std(all_directions),
        'consistency_reward_mean': np.mean(consistency_rewards),
        'consistency_reward_std': np.std(consistency_rewards),
        'consistency_agonist_mean': np.mean(consistency_agonist) if len(consistency_agonist) > 0 else 0.0,
        'consistency_agonist_std': np.std(consistency_agonist) if len(consistency_agonist) > 0 else 0.0,
        'consistency_antagonist_mean': np.mean(consistency_antagonist) if len(consistency_antagonist) > 0 else 0.0,
        'consistency_antagonist_std': np.std(consistency_antagonist) if len(consistency_antagonist) > 0 else 0.0,
        'valid_fraction_mean': np.mean(all_valid_fractions),
        'valid_fraction_std': np.std(all_valid_fractions),
        'direction_accuracy_mean': np.mean(direction_correct) if direction_correct.size else 0.0,
        'direction_accuracy_std': np.std(direction_correct) if direction_correct.size else 0.0,
        'success_rate_mean': np.mean(success_rates) if success_rates.size else 0.0,
        'success_rate_std': np.std(success_rates) if success_rates.size else 0.0
    }

    # Log validation metrics
    logger.info(f"\nValidation Results (Epoch {epoch}):")
    logger.info(f"  Affinity: {val_metrics['affinity_mean']:.4f} ± {val_metrics['affinity_std']:.4f}")
    logger.info(f"  Gated Reward: {val_metrics['gated_reward_mean']:.4f} ± {val_metrics['gated_reward_std']:.4f}")
    logger.info(f"  Direction Oracle: {val_metrics['direction_oracle_mean']:.4f} ± {val_metrics['direction_oracle_std']:.4f}")
    logger.info(f"  Consistency Reward: {val_metrics['consistency_reward_mean']:.4f} ± {val_metrics['consistency_reward_std']:.4f}")
    logger.info(f"  Consistency (d*=+1): {val_metrics['consistency_agonist_mean']:.4f} ± {val_metrics['consistency_agonist_std']:.4f}")
    logger.info(f"  Consistency (d*=-1): {val_metrics['consistency_antagonist_mean']:.4f} ± {val_metrics['consistency_antagonist_std']:.4f}")
    logger.info(f"  Valid Fraction: {val_metrics['valid_fraction_mean']:.4f} ± {val_metrics['valid_fraction_std']:.4f}")
    logger.info(f"  Direction Accuracy: {val_metrics['direction_accuracy_mean']:.4f} ± {val_metrics['direction_accuracy_std']:.4f}")
    logger.info(f"  Success Rate: {val_metrics['success_rate_mean']:.4f} ± {val_metrics['success_rate_std']:.4f}")

    # Save validation sequences to file
    val_df = pd.DataFrame({
        'target': all_target_names,
        'sequence': all_sequences,
        'target_direction': all_target_directions,
        'affinity': all_affinities,
        'gated_reward': all_gated_rewards,
        'direction_oracle': all_directions,
        'consistency_reward': consistency_rewards,
        'direction_accuracy': direction_correct,
        'success_rate': success_rates
    })

    val_output_path = os.path.join(args.save_path, f'validation_epoch_{epoch}.csv')
    val_df.to_csv(val_output_path, index=False)
    logger.info(f"Validation sequences saved to {val_output_path}")

    policy_model.train()

    return val_metrics
392
+
393
+
394
+ def parse_args():
395
+ """Parse command-line arguments."""
396
+ parser = argparse.ArgumentParser(description='Multi-Target TD3B Fine-Tuning')
397
+
398
+ # Paths
399
+ path_group = parser.add_argument_group('Paths')
400
+ path_group.add_argument('--base_path', type=str, required=True,
401
+ help='Base path for TR2-D2 project')
402
+ path_group.add_argument('--train_csv', type=str, required=True,
403
+ help='Path to training CSV file')
404
+ path_group.add_argument('--val_csv', type=str, default=None,
405
+ help='Path to validation CSV file (optional)')
406
+ path_group.add_argument('--pretrained_checkpoint', type=str, required=True,
407
+ help='Path to pretrained diffusion model checkpoint')
408
+ path_group.add_argument('--run_name', type=str, required=True,
409
+ help='Name for this training run')
410
+ path_group.add_argument('--device', type=str, default='cuda',
411
+ help='Device to use (cuda or cpu)')
412
+
413
+ # Multi-target sampling
414
+ target_group = parser.add_argument_group('Multi-Target Sampling')
415
+ target_group.add_argument('--targets_per_mcts', type=int, default=5,
416
+ help='Number of targets to sample per MCTS round (K)')
417
+ target_group.add_argument('--resample_targets_every', type=int, default=1,
418
+ help='Resample targets every N epochs')
419
+
420
+ # Training hyperparameters
421
+ train_group = parser.add_argument_group('Training')
422
+ train_group.add_argument('--num_epochs', type=int, default=200,
423
+ help='Total number of training epochs')
424
+ train_group.add_argument('--learning_rate', type=float, default=3e-4,
425
+ help='Learning rate for optimizer')
426
+ train_group.add_argument('--train_batch_size', type=int, default=16,
427
+ help='Batch size for training (small to prevent OOM)')
428
+ train_group.add_argument('--gradient_accumulation_steps', type=int, default=4,
429
+ help='Accumulate gradients over N steps')
430
+ train_group.add_argument('--resample_every_n_step', type=int, default=10,
431
+ help='Resample MCTS every N epochs')
432
+ train_group.add_argument('--save_every_n_epochs', type=int, default=20,
433
+ help='Save checkpoint every N epochs')
434
+ train_group.add_argument('--validate_every_n_epochs', type=int, default=20,
435
+ help='Run validation every N epochs')
436
+ train_group.add_argument('--num_epoch_for_sampling', type=int, default=5,
437
+ help='Run evaluation sampling every N epochs (set <=0 to disable)')
438
+ train_group.add_argument('--reset_every_n_step', type=int, default=50,
439
+ help='Reset MCTS tree every N epochs')
440
+
441
+ # MCTS hyperparameters
442
+ mcts_group = parser.add_argument_group('MCTS')
443
+ mcts_group.add_argument('--num_iter', type=int, default=50,
444
+ help='MCTS iterations per resample (v1 default: 50, reduce for multi-target)')
445
+ mcts_group.add_argument('--num_children', type=int, default=30,
446
+ help='Children per MCTS expansion')
447
+ mcts_group.add_argument('--buffer_size', type=int, default=50,
448
+ help='Pareto buffer size (v1 default: 50)')
449
+ mcts_group.add_argument('--replay_buffer_size', type=int, default=0,
450
+ help='Max replay buffer size across resamples (0 disables replay)')
451
+ mcts_group.add_argument('--replay_buffer_strategy', type=str, default='fifo',
452
+ choices=['fifo', 'random'],
453
+ help='Replay buffer eviction strategy when full')
454
+ mcts_group.add_argument('--alpha', type=float, default=0.1,
455
+ help='Temperature for importance weighting')
456
+ mcts_group.add_argument('--exploration', type=float, default=1.0,
457
+ help='UCB exploration constant')
458
+
459
+ # TD3B loss hyperparameters
460
+ loss_group = parser.add_argument_group('TD3B Loss')
461
+ loss_group.add_argument('--contrastive_weight', type=float, default=0.1,
462
+ help='Weight for contrastive loss (v1 default: 0.1)')
463
+ loss_group.add_argument('--contrastive_margin', type=float, default=1.0,
464
+ help='Margin for contrastive loss')
465
+ loss_group.add_argument('--contrastive_type', type=str, default='triplet',
466
+ choices=['triplet', 'ntxent', 'supcon'],
467
+ help='Type of contrastive loss')
468
+ loss_group.add_argument('--kl_beta', type=float, default=0.1,
469
+ help='KL divergence regularization coefficient (v1 default: 0.1)')
470
+ loss_group.add_argument('--min_affinity_threshold', type=float, default=0.0,
471
+ help='Minimum affinity threshold for allosteric control (CRITICAL)')
472
+ loss_group.add_argument('--sigmoid_temperature', type=float, default=0.1,
473
+ help='Temperature for sigmoid gating')
474
+
475
+ # Validation
476
+ val_group = parser.add_argument_group('Validation')
477
+ val_group.add_argument('--val_samples_per_target', type=int, default=20,
478
+ help='Number of sequences to generate per target during validation')
479
+
480
+ # Architecture
481
+ arch_group = parser.add_argument_group('Architecture')
482
+ arch_group.add_argument('--seq_length', type=int, default=200,
483
+ help='Maximum sequence length')
484
+ arch_group.add_argument('--embedding_pool_method', type=str, default='cls',
485
+ choices=['cls', 'mean', 'max'],
486
+ help='Pooling method for embeddings')
487
+ arch_group.add_argument('--hidden_dim', type=int, default=768,
488
+ help='Hidden dimension size')
489
+ arch_group.add_argument('--num_layers', type=int, default=8,
490
+ help='Number of transformer layers (v1 default: 8)')
491
+ arch_group.add_argument('--num_heads', type=int, default=8,
492
+ help='Number of attention heads (v1 default: 8)')
493
+ arch_group.add_argument('--sampling_eps', type=float, default=1e-3,
494
+ help='Sampling epsilon (v1 default: 1e-3)')
495
+ arch_group.add_argument('--total_num_steps', type=int, default=128,
496
+ help='Total number of diffusion steps (v1 default: 128)')
497
+
498
+ # Optimization
499
+ opt_group = parser.add_argument_group('Optimization')
500
+ opt_group.add_argument('--grad_clip', action='store_true',
501
+ help='Enable gradient clipping')
502
+ opt_group.add_argument('--gradnorm_clip', type=float, default=1.0,
503
+ help='Gradient norm clipping threshold')
504
+ opt_group.add_argument('--wdce_num_replicates', type=int, default=16,
505
+ help='Number of replicates for WDCE loss (v1 default: 16)')
506
+ opt_group.add_argument('--centering', action='store_true',
507
+ help='Enable centering in WDCE loss')
508
+
509
+ # Logging
510
+ log_group = parser.add_argument_group('Logging')
511
+ log_group.add_argument('--wandb_project', type=str, default='TD3B-multi-target',
512
+ help='W&B project name')
513
+ log_group.add_argument('--wandb_entity', type=str, default='phos_zj',
514
+ help='W&B entity name')
515
+
516
+ # Directional oracle
517
+ oracle_group = parser.add_argument_group('Directional Oracle')
518
+ oracle_group.add_argument('--direction_oracle_ckpt', type=str, default=None,
519
+ help='Path to directional oracle checkpoint')
520
+ oracle_group.add_argument('--direction_oracle_tr2d2_checkpoint', type=str, default=None,
521
+ help='Path to TR2D2 checkpoint used by the oracle')
522
+ oracle_group.add_argument('--direction_oracle_tokenizer_vocab', type=str, default=None,
523
+ help='Path to SMILES tokenizer vocab for oracle')
524
+ oracle_group.add_argument('--direction_oracle_tokenizer_splits', type=str, default=None,
525
+ help='Path to SMILES tokenizer splits for oracle')
526
+ oracle_group.add_argument('--direction_oracle_esm_name', type=str,
527
+ default='facebook/esm2_t33_650M_UR50D',
528
+ help='ESM model name or local path')
529
+ oracle_group.add_argument('--direction_oracle_esm_cache_dir', type=str, default=None,
530
+ help='Optional cache directory for ESM model')
531
+ oracle_group.add_argument('--direction_oracle_esm_local_files_only', action='store_true',
532
+ help='Load ESM from local cache only (no network)')
533
+ oracle_group.add_argument('--direction_oracle_max_ligand_length', type=int, default=768,
534
+ help='Max SMILES token length for oracle')
535
+ oracle_group.add_argument('--direction_oracle_max_protein_length', type=int, default=1024,
536
+ help='Max protein token length for oracle')
537
+ oracle_group.add_argument('--direction_oracle_d_model', type=int, default=256,
538
+ help='Oracle hidden dimension (must match checkpoint)')
539
+ oracle_group.add_argument('--direction_oracle_n_heads', type=int, default=4,
540
+ help='Oracle attention heads (must match checkpoint)')
541
+ oracle_group.add_argument('--direction_oracle_n_self_attn_layers', type=int, default=1,
542
+ help='Oracle self-attention layers (must match checkpoint)')
543
+ oracle_group.add_argument('--direction_oracle_n_bmca_layers', type=int, default=2,
544
+ help='Oracle cross-attention layers (must match checkpoint)')
545
+ oracle_group.add_argument('--direction_oracle_dropout', type=float, default=0.3,
546
+ help='Oracle dropout (must match checkpoint)')
547
+
548
+ args = parser.parse_args()
549
+
550
+ # Resolve default oracle paths relative to base_path
551
+ base_tr2d2_path = os.path.join(args.base_path, 'tr2d2-pep')
552
+ if args.direction_oracle_ckpt is None:
553
+ args.direction_oracle_ckpt = os.path.join(
554
+ base_tr2d2_path, 'best_model_tr2d2_gpcr_fixed.pt'
555
+ )
556
+ if args.direction_oracle_tr2d2_checkpoint is None:
557
+ args.direction_oracle_tr2d2_checkpoint = os.path.join(
558
+ base_tr2d2_path, 'pretrained', 'peptune-pretrained.ckpt'
559
+ )
560
+ if args.direction_oracle_tokenizer_vocab is None:
561
+ args.direction_oracle_tokenizer_vocab = os.path.join(
562
+ base_tr2d2_path, 'tokenizer', 'new_vocab.txt'
563
+ )
564
+ if args.direction_oracle_tokenizer_splits is None:
565
+ args.direction_oracle_tokenizer_splits = os.path.join(
566
+ base_tr2d2_path, 'tokenizer', 'new_splits.txt'
567
+ )
568
+
569
+ # Add derived attributes (required by MCTS)
570
+ args.time_conditioning = False
571
+ args.num_obj = 5 # Must match padded score vector size
572
+ args.scalarization = "sum"
573
+
574
+ # Create save path
575
+ args.save_path = create_output_directory(
576
+ args.base_path,
577
+ args.run_name,
578
+ add_timestamp=True
579
+ )
580
+
581
+ return args
582
+
583
+
584
def main():
    """Entry point: multi-target TD3B fine-tuning.

    Pipeline:
      1. Parse args, set up device/W&B/tokenizer and load datasets.
      2. Build policy + frozen reference diffusion models from a pretrained checkpoint.
      3. Build the multi-target affinity predictor and directional oracle rewards.
      4. Alternate MCTS sampling (fills a sequence buffer) with mini-batch
         training on WDCE + KL + optional contrastive losses.
      5. Periodically validate, checkpoint, and save a final model.
    """
    args = parse_args()

    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Multi-Target TD3B Fine-Tuning")
    logger.info(f"{SEPARATOR_LINE}\n")

    # Set device
    device = initialize_device(args.device)

    # Initialize W&B
    setup_wandb(
        project=args.wandb_project,
        name=args.run_name,
        config=vars(args),
        entity=args.wandb_entity
    )

    # Tokenizer
    tokenizer = load_tokenizer(args.base_path)

    # Load datasets
    logger.info("\n[1/6] Loading datasets...")
    train_dataset = TargetDataset(args.train_csv, tokenizer=tokenizer)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer) if args.val_csv else None

    # Load models
    logger.info("\n[2/6] Loading models...")

    # Create diffusion config
    config = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=args.learning_rate),
        mcts=MCTSConfig()
    )

    # Policy model
    policy_model = Diffusion(
        config=config,
        tokenizer=tokenizer,
        device=device
    ).to(device)

    # Load pretrained checkpoint
    checkpoint = torch.load(args.pretrained_checkpoint, map_location=device, weights_only=False)

    # Handle different checkpoint formats (like v1)
    CHECKPOINT_KEYS = ('state_dict', 'model_state_dict')
    state_dict = None
    for key in CHECKPOINT_KEYS:
        if key in checkpoint:
            state_dict = checkpoint[key]
            logger.info(f"Loading checkpoint from key: {key}")
            break

    if state_dict is None:
        # Assume checkpoint is already a state_dict
        state_dict = checkpoint
        logger.info("Loading checkpoint as direct state_dict")

    policy_model.load_state_dict(state_dict, strict=False)
    logger.info(f"Loaded pretrained checkpoint from {args.pretrained_checkpoint}")

    # Reference model (frozen) — used only for KL regularization toward the prior.
    reference_model = Diffusion(
        config=config,
        tokenizer=tokenizer,
        device=device
    ).to(device)
    reference_model.load_state_dict(state_dict, strict=False)
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
    logger.info("Created reference model (frozen)")

    # Add TD3B sampling method, fix bugs, sampling sequences with w(t) as condition
    policy_model = add_td3b_sampling_to_model(policy_model)

    # Multi-target affinity predictor
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone  # Use backbone Roformer model (matches v1)
    )
    logger.info("Created multi-target binding affinity predictor")

    # Directional oracle (GPCR classifier) — fail fast on missing files.
    for path_label, path in [
        ("direction_oracle_ckpt", args.direction_oracle_ckpt),
        ("direction_oracle_tr2d2_checkpoint", args.direction_oracle_tr2d2_checkpoint),
        ("direction_oracle_tokenizer_vocab", args.direction_oracle_tokenizer_vocab),
        ("direction_oracle_tokenizer_splits", args.direction_oracle_tokenizer_splits),
    ]:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Missing {path_label}: {path}")

    directional_oracle = DirectionalOracle(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only
    )
    directional_oracle.eval()

    # Cache per-target protein encodings so each target is encoded only once.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    def get_protein_tokens(target_seq: str) -> torch.Tensor:
        cached = protein_token_cache.get(target_seq)
        if cached is None:
            cached = directional_oracle.encode_protein(target_seq)
            protein_token_cache[target_seq] = cached
        return cached

    # Loss function
    logger.info("\n[3/6] Creating loss function...")
    td3b_loss_fn = TD3BTotalLoss(
        contrastive_weight=args.contrastive_weight,
        contrastive_margin=args.contrastive_margin,
        kl_beta=args.kl_beta,
        reference_model=reference_model,
        adaptive_margin=True
    )

    # WDCE loss
    from finetune_utils import loss_wdce

    logger.info("\n[4/6] Setting up training...")
    policy_model.train()
    torch.set_grad_enabled(True)
    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)

    # Training logs
    batch_losses = []
    batch_wdce_losses = []
    batch_contrastive_losses = []
    batch_kl_losses = []

    # Multi-target buffer
    # We'll store sequences from all sampled targets here
    buffer_sequences = []  # List of (x, log_rnd, reward, directional_label, confidence)
    current_targets = []

    def trim_replay_buffer(items, max_size, strategy):
        # Evict down to max_size; 'fifo' keeps the newest entries,
        # otherwise keep a uniform random subset.
        if max_size <= 0 or len(items) <= max_size:
            return items
        if strategy == "fifo":
            return items[-max_size:]
        indices = np.random.choice(len(items), size=max_size, replace=False)
        return [items[i] for i in indices]

    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Starting Training")
    logger.info(f"{SEPARATOR_LINE}\n")

    # Training loop
    pbar = tqdm(range(args.num_epochs))

    for epoch in pbar:
        # Sample new targets if needed
        if epoch % args.resample_targets_every == 0 or len(current_targets) == 0:
            current_targets = train_dataset.sample_targets(
                k=args.targets_per_mcts,
                random_state=epoch
            )
            logger.info(f"\nEpoch {epoch}: Sampled {len(current_targets)} targets for training")

        # MCTS sampling phase (less frequent) - this is when we regenerate sequences
        if epoch % args.resample_every_n_step == 0:
            if args.replay_buffer_size <= 0:
                # Clear buffer only when regenerating with new MCTS if replay is disabled
                buffer_sequences = []
            else:
                logger.info(
                    f"Epoch {epoch}: Replay buffer enabled, keeping {len(buffer_sequences)} sequences before refresh"
                )
            logger.info(f"Epoch {epoch}: Running MCTS for {len(current_targets)} targets...")
            mcts_valid_total = 0
            mcts_run_count = 0
            mcts_empty_runs = 0

            with torch.no_grad():
                for target_seq in current_targets:
                    target_info = train_dataset.get_target_info(target_seq)

                    # Sample both agonist and antagonist
                    for direction_name, d_star in [('agonist', 1.0), ('antagonist', -1.0)]:
                        # Get the target sequence length for this direction
                        target_length = train_dataset.get_sequence_length(target_seq, direction_name)

                        # Temporarily set args.seq_length for this generation
                        original_seq_length = args.seq_length
                        args.seq_length = target_length

                        # Create target-specific affinity predictor for this target
                        target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)

                        # Create reward model for this target
                        reward_model = create_reward_function(
                            affinity_predictor=target_affinity,
                            directional_oracle=directional_oracle,
                            target_direction=d_star,
                            target_protein_tokens=get_protein_tokens(target_seq),
                            tokenizer=tokenizer,
                            device=device,
                            min_affinity_threshold=args.min_affinity_threshold,
                            use_confidence_weighting=True,
                            temperature=args.sigmoid_temperature
                        )

                        # Create MCTS using shared utility
                        mcts = create_mcts_instance(
                            args=args,
                            policy_model=policy_model,
                            reward_function=reward_model,
                            tokenizer=tokenizer,
                            buffer_size=args.buffer_size
                        )

                        # Run MCTS
                        reset_tree = (epoch % args.reset_every_n_step == 0)
                        results = mcts.forward(resetTree=reset_tree)

                        # Restore original seq_length
                        args.seq_length = original_seq_length

                        # Unpack results
                        if len(results) == 7:
                            x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = results

                            # Skip if MCTS returned empty buffer (no valid sequences found)
                            if len(x_final) == 0:
                                logger.warning(f"MCTS returned empty buffer for target={target_seq[:20]}, direction={direction_name}")
                                mcts_run_count += 1
                                mcts_empty_runs += 1
                                continue
                            mcts_run_count += 1
                            mcts_valid_total += len(sequences)

                            # Add to buffer
                            for i in range(len(x_final)):
                                buffer_sequences.append({
                                    'x': x_final[i],
                                    'log_rnd': log_rnd[i],
                                    'reward': final_rewards[i],
                                    'directional_label': d_star,
                                    'confidence': confidences[i] if isinstance(confidences, np.ndarray) else 1.0
                                })

            if args.replay_buffer_size > 0:
                buffer_sequences = trim_replay_buffer(
                    buffer_sequences,
                    args.replay_buffer_size,
                    args.replay_buffer_strategy
                )

            logger.info(
                f"Epoch {epoch}: MCTS runs={mcts_run_count}, "
                f"valid_sequences={mcts_valid_total}, empty_runs={mcts_empty_runs}"
            )
            logger.info(f"Epoch {epoch}: Buffer size: {len(buffer_sequences)} sequences")

        # Training phase: sample mini-batches from buffer
        if len(buffer_sequences) == 0:
            logger.warning(f"Epoch {epoch}: Buffer is empty, skipping training")
            continue

        # Shuffle buffer
        np.random.shuffle(buffer_sequences)

        # Mini-batch training
        num_batches = max(1, len(buffer_sequences) // args.train_batch_size)
        epoch_loss = 0.0
        epoch_wdce_loss = 0.0
        epoch_contrastive_loss = 0.0
        epoch_kl_loss = 0.0

        optimizer.zero_grad()

        for batch_idx in range(num_batches):
            start_idx = batch_idx * args.train_batch_size
            end_idx = min(start_idx + args.train_batch_size, len(buffer_sequences))
            batch_data = buffer_sequences[start_idx:end_idx]

            # Pad sequences to the same length (efficient batching for variable-length sequences)
            # Use padding to handle different sequence lengths from different targets
            x_list = [item['x'] for item in batch_data]
            log_rnd_list = [item['log_rnd'] for item in batch_data]  # Scalars, not vectors!

            # Pad x_batch: pad with mask_index (typically 0 or a special token)
            mask_index = policy_model.mask_index if hasattr(policy_model, 'mask_index') else 0
            max_len = max(x.shape[0] for x in x_list)

            # Create padded tensors
            x_batch = torch.full(
                (len(x_list), max_len),
                fill_value=mask_index,
                dtype=x_list[0].dtype,
                device=device
            )

            # Create attention mask: 1 for real tokens, 0 for padding
            # This tells the model which positions are valid vs padded
            attn_mask = torch.zeros(
                (len(x_list), max_len),
                dtype=torch.long,
                device=device
            )

            # Fill in the real sequences and mark valid positions
            for i, x in enumerate(x_list):
                seq_len = x.shape[0]
                x_batch[i, :seq_len] = x.to(device)
                attn_mask[i, :seq_len] = 1  # Mark valid positions

            # log_rnd is a SCALAR per sequence, not a vector - just stack them
            log_rnd_batch = torch.stack([lr.to(device) if isinstance(lr, torch.Tensor) else torch.tensor(lr, device=device) for lr in log_rnd_list])

            directional_labels_batch = torch.tensor(
                [item['directional_label'] for item in batch_data],
                dtype=torch.float32,
                device=device
            )

            # WDCE loss (with attention mask to handle variable-length sequences)
            wdce_loss = loss_wdce(
                policy_model,
                log_rnd_batch,
                x_batch,
                num_replicates=args.wdce_num_replicates,
                centering=args.centering,
                attn_mask=attn_mask  # Pass attention mask to avoid computing loss on padding
            )

            # KL loss: mask each position with probability lamda and compute the
            # KL term at noise level sigma(lamda).
            # BUG FIX: this previously used an undefined name `eps`
            # (NameError at runtime); use the configured sampling epsilon,
            # which matches loss_wdce's eps default (1e-3).
            lamda = torch.rand(x_batch.shape[0], device=device)
            sigma_kl = -torch.log1p(-(1 - args.sampling_eps) * lamda)
            # Reuse the guarded mask_index computed above for padding (consistent
            # with the fallback-to-0 behavior when the attribute is absent).
            masked_index = torch.rand(*x_batch.shape, device=device) < lamda[..., None]
            perturbed_batch = torch.where(masked_index, mask_index, x_batch)
            # Use the actual attention mask (not all ones) to handle variable-length sequences
            attn_mask_kl = attn_mask.to(device)

            kl_loss = td3b_loss_fn.compute_kl_loss(
                policy_model,
                perturbed_batch,
                attn_mask_kl,
                sigma_kl
            )

            # Contrastive loss (if we have multiple directions)
            if len(torch.unique(directional_labels_batch)) > 1:
                embeddings = extract_embeddings_from_mdlm(
                    policy_model,
                    x_batch,
                    pool_method=args.embedding_pool_method
                )

                debug_mode = (epoch < 3) or (epoch > 0 and batch_contrastive_losses and batch_contrastive_losses[-1] < 1e-6)

                total_loss, loss_dict = td3b_loss_fn.compute_loss(
                    wdce_loss,
                    embeddings,
                    directional_labels_batch,
                    kl_loss=kl_loss,
                    debug=debug_mode
                )
            else:
                # Only WDCE + KL if no contrastive
                total_loss = wdce_loss + args.kl_beta * kl_loss
                loss_dict = {
                    'total_loss': total_loss.item(),
                    'wdce_loss': wdce_loss.item(),
                    'contrastive_loss': 0.0,
                    'kl_loss': kl_loss.item()
                }

            # Scale loss for gradient accumulation
            scaled_loss = total_loss / args.gradient_accumulation_steps
            scaled_loss.backward()

            # Accumulate losses
            epoch_loss += loss_dict['total_loss']
            epoch_wdce_loss += loss_dict['wdce_loss']
            epoch_contrastive_loss += loss_dict['contrastive_loss']
            epoch_kl_loss += loss_dict['kl_loss']

            # Gradient accumulation
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0 or (batch_idx + 1) == num_batches:
                if args.grad_clip:
                    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
                optimizer.step()
                optimizer.zero_grad()

        # Average losses
        epoch_loss /= num_batches
        epoch_wdce_loss /= num_batches
        epoch_contrastive_loss /= num_batches
        epoch_kl_loss /= num_batches

        batch_losses.append(epoch_loss)
        batch_wdce_losses.append(epoch_wdce_loss)
        batch_contrastive_losses.append(epoch_contrastive_loss)
        batch_kl_losses.append(epoch_kl_loss)

        # Validation
        if val_dataset is not None and (epoch + 1) % args.validate_every_n_epochs == 0:
            val_metrics = run_validation(
                policy_model,
                multi_target_affinity,
                directional_oracle,
                tokenizer,
                val_dataset,
                args,
                epoch,
                device,
                protein_token_cache=protein_token_cache
            )

            # Log to W&B
            wandb.log({
                "epoch": epoch,
                "val/affinity_mean": val_metrics['affinity_mean'],
                "val/affinity_std": val_metrics['affinity_std'],
                "val/gated_reward_mean": val_metrics['gated_reward_mean'],
                "val/gated_reward_std": val_metrics['gated_reward_std'],
                "val/direction_oracle_mean": val_metrics['direction_oracle_mean'],
                "val/direction_oracle_std": val_metrics['direction_oracle_std'],
                "val/consistency_reward_mean": val_metrics['consistency_reward_mean'],
                "val/consistency_reward_std": val_metrics['consistency_reward_std'],
                "val/consistency_agonist_mean": val_metrics['consistency_agonist_mean'],
                "val/consistency_antagonist_mean": val_metrics['consistency_antagonist_mean'],
                "val/valid_fraction_mean": val_metrics['valid_fraction_mean'],
                "val/direction_accuracy_mean": val_metrics['direction_accuracy_mean'],
                "val/direction_accuracy_std": val_metrics['direction_accuracy_std'],
                "val/success_rate_mean": val_metrics['success_rate_mean'],
                "val/success_rate_std": val_metrics['success_rate_std']
            })

        # Save checkpoint
        if (epoch + 1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_epoch_{epoch}.ckpt')
            save_model(policy_model, model_path, config=vars(args), epoch=epoch)

    # Final save
    final_model_path = os.path.join(args.save_path, 'model_final.ckpt')
    save_model(policy_model, final_model_path, config=vars(args))

    cleanup_wandb()
    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Training completed!")
    logger.info(f"{SEPARATOR_LINE}\n")
+
1059
+
1060
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
finetune_utils.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for TD3B finetuning and sampling."""
2
+
3
+ import logging
4
+ import os
5
+ import random
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional, Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn.functional as F
14
+ import wandb
15
+ from torch.utils.data import DataLoader, TensorDataset
16
+ from tqdm import tqdm
17
+
18
+ from diffusion import Diffusion
19
+ from td3b.td3b_mcts import create_td3b_mcts
20
+ from td3b.td3b_scoring import TD3BRewardFunction
21
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
22
+ from utils.utils import sample_categorical_logits
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Standard checkpoint keys to try when loading
27
+ CHECKPOINT_KEYS = ("state_dict", "model_state_dict")
28
+
29
+
30
def to_one_hot(x_idx, num_classes=4):
    """Convert an integer index tensor to a float one-hot encoding.

    Args:
        x_idx: Tensor of class indices (any shape, any integer dtype).
        num_classes: Size of the appended one-hot dimension.

    Returns:
        Float tensor of shape ``x_idx.shape + (num_classes,)``.
    """
    return F.one_hot(x_idx.long(), num_classes=num_classes).float()
33
+
34
+
35
def rnd(model, reward_model, batch_size, scale=1, device="cuda:0"):
    r"""
    Run random order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory
    reward_model: r(X)

    return:
    - x: the final samples, [B, D]
    - log_rnd: the log RND along this trajectory, [B]
    """
    # Unwrap a DDP-style wrapper so attributes like `length`/`vocab_size` resolve.
    if hasattr(model, "module"):
        model = model.module

    # Start from the fully-masked state; the mask token is assumed to be the
    # last vocabulary index (vocab_size - 1).
    x = torch.full((batch_size, model.length), model.vocab_size - 1).to(device=device, dtype=torch.int64)
    batch_arange = torch.arange(batch_size, device=device)
    # Each row of jump_pos is a uniformly random permutation of positions:
    # the order in which positions will be unmasked.
    jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
    # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
    # jump_times: Unif[0,1] in increasing order
    # jump_pos: random permutation of range(D)
    log_rnd = torch.zeros(batch_size, device=device)  # [B]
    for d in range(model.length - 1, -1, -1):
        # jump at time jump_times[:, d] at position jump_pos[:, d]
        # NOTE(review): this indexes model(x) directly while `sampling` below
        # calls model.logits(x) — presumably both return per-token (log-)scores
        # with the mask token last; confirm the two entry points agree.
        logits = model(x)[:, :, :-1]  # [B, D, N-1]
        update = sample_categorical_logits(logits[batch_arange, jump_pos[:, d]])  # [B]
        if torch.is_grad_enabled():  # avoid issues with in-place operations
            x = x.clone()
        x[batch_arange, jump_pos[:, d]] = update
        # Accumulate log(uniform reference over N-1 non-mask tokens) minus the
        # model's score of the sampled token at the unmasked position.
        log_rnd += -np.log(model.vocab_size - 1) - logits[batch_arange, jump_pos[:, d], update]
    # Terminal reward term of the log RND: scale * r(X) on the final sample.
    log_rnd += scale * reward_model(x)  # [B]
    return x, log_rnd
64
+
65
+
66
@torch.no_grad()
def sampling(model, batch_size, rounds=1, device="cuda:0"):
    """Any order autoregressive sampling.

    Draws ``rounds`` batches of ``batch_size`` sequences by unmasking positions
    in a uniformly random order, and returns them concatenated as
    a (rounds * batch_size, length) tensor.
    """
    # Unwrap a DDP-style wrapper so attributes like `length`/`vocab_size` resolve.
    if hasattr(model, "module"):
        model = model.module
    batch_arange = torch.arange(batch_size, device=device)
    all_samples = []
    for _ in tqdm(range(rounds), leave=False):
        # Start fully masked; the mask token is the last vocabulary index.
        x = torch.full((batch_size, model.length), model.vocab_size - 1).to(device=device, dtype=torch.int64)
        jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
        # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
        # jump_times: Unif[0,1] in increasing order
        # jump_pos: random permutation of range(D)
        for d in tqdm(range(model.length - 1, -1, -1), leave=False):
            # jump at time jump_times[:, d] at position jump_pos[:, d]
            logits = model.logits(x)[:, :, :-1]  # [B, D, N-1], not log-softmaxed but fine
            update = sample_categorical_logits(logits[batch_arange, jump_pos[:, d]])  # [B]
            x[batch_arange, jump_pos[:, d]] = update
        all_samples.append(x)
    return torch.cat(all_samples)  # (rounds * B, L)
86
+
87
+
88
def loss_ce(log_rnd):
    """Cross entropy loss KL(P^*||P^u).

    Weights each trajectory's log Radon-Nikodym derivative with detached,
    self-normalized importance weights and sums over the batch.
    """
    importance = torch.softmax(log_rnd.detach(), dim=-1)
    return torch.sum(log_rnd * importance)
92
+
93
+
94
def loss_lv(log_rnd):
    r"""Log variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}.

    Returns the (unbiased, Bessel-corrected) batch variance of the log RND.
    """
    return torch.var(log_rnd)
97
+
98
+
99
def loss_re_rf(log_rnd, const=0):
    r"""Relative entropy loss KL(P^u||P^*) with REINFORCE trick.

    ``const`` acts as a baseline subtracted from the detached log RND to
    reduce gradient variance without changing the gradient in expectation.
    """
    advantage = const - log_rnd.detach()
    return torch.mean(-log_rnd * advantage)
102
+
103
+
104
+ def loss_wdce(
105
+ policy_model,
106
+ log_rnd,
107
+ x,
108
+ num_replicates=16,
109
+ weight_func=lambda l: 1 / l,
110
+ eps=1e-3,
111
+ centering=False,
112
+ attn_mask=None,
113
+ ):
114
+ r"""
115
+ Weighted denoising cross entropy loss
116
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
117
+
118
+ log_rnd: [B]; x: [B, L] (no mask)
119
+ num_replicates: R, number of replicates of each row in x
120
+ weight_func: w(lambda) for each sample, 1/lambda by default
121
+ attn_mask: [B, L] attention mask (1 for real tokens, 0 for padding) - IMPORTANT for variable-length sequences
122
+ """
123
+ mask_index = policy_model.mask_index
124
+ if hasattr(policy_model, "module"):
125
+ policy_model = policy_model.module
126
+
127
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
128
+
129
+ batch_weights = log_rnd.detach_().softmax(dim=-1) # [B*R]
130
+ if centering:
131
+ batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
132
+
133
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
134
+
135
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
136
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
137
+
138
+ masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None] # [B*R, D]
139
+ perturbed_batch = torch.where(masked_index, mask_index, batch)
140
+
141
+ # add time conditioning
142
+ t = lamda
143
+ sigma_t = -torch.log1p(-(1 - eps) * t)
144
+
145
+ # Use provided attention mask or create default (all ones for fixed-length)
146
+ if attn_mask is not None:
147
+ attn_mask = attn_mask.repeat_interleave(num_replicates, dim=0).to(policy_model.device)
148
+ else:
149
+ attn_mask = torch.ones_like(perturbed_batch).to(policy_model.device)
150
+
151
+ # compute logits
152
+ logits = policy_model(perturbed_batch, attn_mask=attn_mask, sigma=sigma_t)
153
+ losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype) # [B*R, D]
154
+ losses[masked_index] = torch.gather(
155
+ input=logits[masked_index], dim=-1, index=batch[masked_index][..., None]
156
+ ).squeeze(-1)
157
+
158
+ # Apply attention mask to exclude padding tokens from loss computation.
159
+ losses = losses * attn_mask
160
+
161
+ return -((losses.sum(dim=-1) * lamda_weights * batch_weights).mean())
162
+
163
+
164
+ def loss_dce(model, x, weight_func=lambda l: 1 / l):
165
+ r"""
166
+ Denoising cross entropy loss, x [B, D] are ground truth samples
167
+ weight_func: w(lambda) for each sample, 1/lambda by default
168
+ """
169
+ lamda = torch.rand(x.shape[0], device=x.device) # [B]
170
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B]
171
+ masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None] # [B, D]
172
+ perturbed_batch = torch.where(masked_index, model.vocab_size - 1, x)
173
+ logits = model(perturbed_batch)
174
+ losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype) # [B, D]
175
+ losses[masked_index] = torch.gather(
176
+ input=logits[masked_index], dim=-1, index=x[masked_index][..., None]
177
+ ).squeeze(-1)
178
+ return -((losses.sum(dim=-1) * lamda_weights).mean())
179
+
180
+
181
def load_tokenizer(base_path: str) -> SMILES_SPE_Tokenizer:
    """
    Load the peptide tokenizer from the standard location.

    Args:
        base_path: Base directory containing the ``tr2d2-pep/tokenizer`` tree.

    Returns:
        Loaded SMILES_SPE_Tokenizer instance.

    Raises:
        FileNotFoundError: If either tokenizer file is missing.
    """
    tokenizer_dir = Path(base_path) / "tr2d2-pep" / "tokenizer"
    vocab_path = tokenizer_dir / "new_vocab.txt"
    spe_path = tokenizer_dir / "new_splits.txt"

    # Fail fast with an explicit path in the message if either file is absent.
    if not vocab_path.exists():
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
    if not spe_path.exists():
        raise FileNotFoundError(f"SPE splits file not found: {spe_path}")

    loaded = SMILES_SPE_Tokenizer(str(vocab_path), str(spe_path))
    logger.info("Loaded tokenizer with vocab_size=%s", loaded.vocab_size)

    return loaded
207
+
208
+
209
def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    device: torch.device,
    strict: bool = True,
) -> Dict[str, Any]:
    """
    Load model weights from checkpoint with automatic key detection.

    Handles different checkpoint formats:
    - {'state_dict': ...}
    - {'model_state_dict': ...}
    - Direct state_dict

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model to load weights into
        device: Device to load checkpoint onto
        strict: Whether to strictly enforce state_dict keys match

    Returns:
        Full checkpoint dictionary (for accessing metadata like epoch, config, etc.)

    Raises:
        FileNotFoundError: If checkpoint file doesn't exist
        RuntimeError: If loading the state dict fails (original error chained)
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info("Loading checkpoint from: %s", checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Try to find state_dict under the standard wrapper keys.
    state_dict = None
    for key in CHECKPOINT_KEYS:
        if key in checkpoint:
            state_dict = checkpoint[key]
            logger.info("Found state_dict at checkpoint key: '%s'", key)
            break

    # If not found under a known key, assume the checkpoint IS the state_dict.
    if state_dict is None:
        state_dict = checkpoint
        logger.info("Loading checkpoint as direct state_dict")

    # Keep the try body minimal and chain the cause so the original traceback
    # is preserved (fix: previously `raise RuntimeError(...)` without `from exc`).
    try:
        incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    except Exception as exc:
        raise RuntimeError(f"Failed to load checkpoint: {exc}") from exc

    if not strict and (incompatible_keys.missing_keys or incompatible_keys.unexpected_keys):
        logger.warning("Incompatible keys when loading checkpoint:")
        if incompatible_keys.missing_keys:
            logger.warning("  Missing keys: %s...", incompatible_keys.missing_keys[:5])
        if incompatible_keys.unexpected_keys:
            logger.warning("  Unexpected keys: %s...", incompatible_keys.unexpected_keys[:5])
    else:
        logger.info("Checkpoint loaded successfully")

    return checkpoint
275
+
276
+
277
def initialize_device(device_str: str = "cuda") -> torch.device:
    """
    Resolve a compute device, falling back to CPU when CUDA is unavailable
    or the requested device string/index is invalid.

    Args:
        device_str: Requested device ('cuda', 'cuda:0', 'cpu', or 'auto')

    Returns:
        Torch device object
    """
    requested = device_str
    # 'auto' / None: pick cuda:0 when a CUDA device is actually visible.
    if requested is None or str(requested).lower() == "auto":
        has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
        requested = "cuda:0" if has_cuda else "cpu"

    try:
        device = torch.device(requested)
    except Exception as exc:
        logger.warning("Invalid device '%s': %s. Falling back to CPU.", requested, exc)
        return torch.device("cpu")

    # Non-CUDA devices need no further validation.
    if device.type != "cuda":
        logger.info("Using device: %s", device)
        return device

    if not (torch.cuda.is_available() and torch.cuda.device_count() > 0):
        logger.warning("CUDA requested but not available, falling back to CPU")
        return torch.device("cpu")

    # Clamp an out-of-range CUDA index to cuda:0.
    index = 0 if device.index is None else device.index
    if not (0 <= index < torch.cuda.device_count()):
        logger.warning(
            "CUDA device %s requested but only %d visible; using cuda:0",
            index,
            torch.cuda.device_count(),
        )
        device = torch.device("cuda:0")

    logger.info("Using device: %s (%s)", device, torch.cuda.get_device_name(device.index or 0))
    return device
319
+
320
+
321
def create_output_directory(base_path: str, run_name: str, add_timestamp: bool = True) -> str:
    """
    Create (and return the path of) the results directory for a run.

    The directory is ``<base_path>/tr2d2-pep/results/<run_name>[_<timestamp>]``
    and is created idempotently.

    Args:
        base_path: Base directory.
        run_name: Name for this training run.
        add_timestamp: Append a ``YYYYmmdd_HHMMSS`` suffix when True.

    Returns:
        Path to the created output directory.
    """
    dir_name = run_name
    if add_timestamp:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        dir_name = f"{run_name}_{stamp}"

    output_dir = os.path.join(base_path, "tr2d2-pep", "results", dir_name)
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Created output directory: %s", output_dir)
    return output_dir
348
+
349
+
350
def save_model(
    model: torch.nn.Module,
    save_path: str,
    config: Optional[Dict[str, Any]] = None,
    epoch: Optional[int] = None,
    optimizer_state: Optional[Dict] = None,
) -> None:
    """
    Save a model checkpoint, attaching whichever metadata was supplied.

    Args:
        model: Model to save.
        save_path: Destination path for the checkpoint file.
        config: Optional configuration dictionary to embed.
        epoch: Optional epoch number.
        optimizer_state: Optional optimizer state dict.
    """
    checkpoint: Dict[str, Any] = {"model_state_dict": model.state_dict()}

    # Only persist metadata entries the caller actually provided.
    optional_entries = (
        ("config", config),
        ("epoch", epoch),
        ("optimizer_state_dict", optimizer_state),
    )
    for key, value in optional_entries:
        if value is not None:
            checkpoint[key] = value

    torch.save(checkpoint, save_path)
    logger.info("Model saved: %s", save_path)
381
+
382
+
383
def setup_wandb(project: str, name: str, config: Dict[str, Any], entity: Optional[str] = None) -> None:
    """
    Start a Weights & Biases run.

    Args:
        project: W&B project name.
        name: Run name.
        config: Configuration dictionary to log with the run.
        entity: Optional W&B team/entity name (omitted when falsy).
    """
    init_kwargs: Dict[str, Any] = {"project": project, "name": name, "config": config}
    if entity:
        init_kwargs["entity"] = entity

    wandb.init(**init_kwargs)
    logger.info("Initialized W&B: project=%s, run=%s", project, name)
407
+
408
+
409
def cleanup_wandb() -> None:
    """Close the active Weights & Biases run and log the shutdown."""
    wandb.finish()
    logger.info("Finished W&B logging")
413
+
414
+
415
def get_mask_index(tokenizer: SMILES_SPE_Tokenizer) -> int:
    """
    Return the mask token id, regardless of which tokenizer API is present.

    Prefers a ``mask_token_id`` attribute; otherwise resolves the mask token
    string through ``convert_tokens_to_ids``. Standardizes mask-index
    retrieval across code paths.
    """
    # EAFP: attempt the direct attribute, fall back to token conversion.
    try:
        return tokenizer.mask_token_id
    except AttributeError:
        return tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
431
+
432
+
433
def create_mcts_instance(
    args,
    policy_model: Diffusion,
    reward_function: TD3BRewardFunction,
    tokenizer: SMILES_SPE_Tokenizer,
    buffer_size: Optional[int] = None,
) -> Any:
    """
    Create TD3B MCTS instance with standardized configuration.

    Args:
        args: Training arguments (``no_mcts``, ``buffer_size``, ``alpha`` are
            read defensively via getattr/hasattr).
        policy_model: Diffusion policy model
        reward_function: TD3B reward function
        tokenizer: Peptide tokenizer
        buffer_size: Optional buffer size (uses args.buffer_size if None)

    Returns:
        TD3B_MCTS instance, or None when MCTS is disabled via ``--no_mcts``.
    """
    if hasattr(args, "no_mcts") and args.no_mcts:
        logger.info("MCTS disabled (--no_mcts flag)")
        return None

    # Get mask index using the standardized helper.
    mask_index = get_mask_index(tokenizer)

    # Use provided buffer_size or fall back to args.
    if buffer_size is None:
        buffer_size = getattr(args, "buffer_size", 50)

    # Resolve alpha once. Fix: the log statement below previously read
    # `args.alpha` directly and raised AttributeError when args lacked it,
    # even though the construction used getattr with a default.
    alpha = getattr(args, "alpha", 0.1)

    mcts = create_td3b_mcts(
        args=args,
        diffusion_model=policy_model,
        td3b_reward_function=reward_function,
        alpha=alpha,
        mask_index=mask_index,
        buffer_size=buffer_size,
        tokenizer=tokenizer,
    )

    logger.info("Created TD3B MCTS (buffer_size=%s, alpha=%s)", buffer_size, alpha)
    return mcts
479
+
480
+
481
def create_reward_function(
    affinity_predictor,
    directional_oracle,
    target_direction: float,
    target_protein_tokens: torch.Tensor,
    tokenizer: SMILES_SPE_Tokenizer,
    device: torch.device,
    min_affinity_threshold: float = 0.0,
    use_confidence_weighting: bool = True,
    temperature: float = 0.1,
) -> TD3BRewardFunction:
    """
    Build a TD3B reward function with the standard parameter wiring.

    Args:
        affinity_predictor: Binding affinity prediction model.
        directional_oracle: Directional prediction oracle.
        target_direction: Target direction (1.0 agonist, -1.0 antagonist).
        target_protein_tokens: Tokenized protein target.
        tokenizer: Peptide tokenizer.
        device: Compute device.
        min_affinity_threshold: Minimum affinity for allosteric control.
        use_confidence_weighting: Whether to use confidence weighting.
        temperature: Temperature for sigmoid gating.

    Returns:
        A configured TD3BRewardFunction.
    """
    reward_kwargs = dict(
        affinity_predictor=affinity_predictor,
        directional_oracle=directional_oracle,
        target_direction=target_direction,
        target_protein_tokens=target_protein_tokens,
        peptide_tokenizer=tokenizer,
        device=device,
        min_affinity_threshold=min_affinity_threshold,
        use_confidence_weighting=use_confidence_weighting,
        temperature=temperature,
    )
    reward_func = TD3BRewardFunction(**reward_kwargs)

    logger.info(
        "Created TD3B reward function (d*=%s, threshold=%s)", target_direction, min_affinity_threshold
    )
    return reward_func
531
+
532
+
533
def log_gpu_memory(stage: str = "") -> None:
    """
    Log current GPU memory usage (no-op when CUDA is unavailable).

    Args:
        stage: Optional stage description appended to the log line.
    """
    if not torch.cuda.is_available():
        return

    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib  # GB
    reserved = torch.cuda.memory_reserved() / gib  # GB
    stage_str = f" [{stage}]" if stage else ""
    logger.info(
        "GPU Memory%s: %.2fGB allocated, %.2fGB reserved",
        stage_str,
        allocated,
        reserved,
    )
553
+
554
+
555
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Count total and trainable parameters of a model.

    Args:
        model: PyTorch model.

    Returns:
        Tuple of (total_params, trainable_params).

    Example:
        >>> total, trainable = count_parameters(model)
        >>> print(f"Total: {total:,}, Trainable: {trainable:,}")
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable
inference.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ TD3B Inference Script
4
+ Generate directional binders for target proteins using a finetuned TD3B model.
5
+
6
+ Usage:
7
+ python inference.py \
8
+ --ckpt_path checkpoints/td3b.ckpt \
9
+ --val_csv data/test.csv \
10
+ --save_path results/ \
11
+ --seed 42
12
+ """
13
+ import argparse
14
+ import os
15
+ import sys
16
+ import logging
17
+ from typing import Dict, List, Tuple
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+ import torch
22
+
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ if ROOT_DIR not in sys.path:
25
+ sys.path.insert(0, ROOT_DIR)
26
+
27
+ from diffusion import Diffusion
28
+ from configs.finetune_config import (
29
+ DiffusionConfig, RoFormerConfig, NoiseConfig,
30
+ TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,
31
+ )
32
+ from finetune_utils import load_tokenizer, create_reward_function
33
+ from td3b.direction_oracle import DirectionalOracle
34
+ from td3b.td3b_scoring import create_td3b_reward_function
35
+ from utils.app import PeptideAnalyzer
36
+
37
+ logger = logging.getLogger(__name__)
38
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
39
+
40
# ─── Defaults ─────────────────────────────────────────────────────────────────
# Reference hyperparameter defaults for inference.
# NOTE(review): argparse in main() declares its own defaults and this table is
# not referenced anywhere in this script — confirm it is consumed elsewhere.
DEFAULTS = dict(
    seq_length=200,              # maximum binder token length
    sampling_eps=1e-3,           # smallest diffusion timestep
    total_num_steps=128,         # reverse-diffusion steps
    hidden_dim=768,              # backbone hidden size
    num_layers=8,
    num_heads=8,
    alpha=0.1,                   # resampling temperature
    min_affinity_threshold=0.0,  # soft binding-affinity gate threshold
    sigmoid_temperature=0.1,
    num_pool=32,                 # candidate pool size per target/direction
    val_samples_per_target=8,    # samples kept after weighted resampling
)
54
+
55
+
56
def load_model(ckpt_path: str, device: torch.device):
    """Load a finetuned TD3B model from checkpoint.

    Rebuilds the Diffusion model from the config stored in the checkpoint
    (falling back to the standard 768/8/8 RoFormer geometry) and loads the
    weights non-strictly.

    Returns:
        (model, tokenizer) with the model in eval mode on ``device``.
    """
    # weights_only=False: the checkpoint carries a config dict, not just tensors.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Accept {'model_state_dict': ...}, {'state_dict': ...}, or a raw state_dict.
    state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
    config = ckpt.get("config") or {}  # assumes config, when present, is a dict — TODO confirm

    tokenizer = load_tokenizer(ROOT_DIR)

    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=config.get("hidden_dim", 768),
            n_layers=config.get("num_layers", 8),
            n_heads=config.get("num_heads", 8),
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=1e-3),
        sampling=SamplingConfig(steps=128, sampling_eps=1e-3),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=3e-4),
        mcts=MCTSConfig(),
    )

    model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)
    # strict=False: tolerate auxiliary keys (e.g. optimizer or oracle weights).
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.tokenizer = tokenizer
    return model, tokenizer
83
+
84
+
85
def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int, eps: float = 1e-5):
    """Sample token sequences from the diffusion model.

    Runs ``num_steps`` reverse-diffusion steps from the model prior, then one
    extra denoising pass to clear any positions still holding the mask token.

    Args:
        model: Diffusion model exposing ``sample_prior``, ``single_reverse_step``,
            ``single_noise_removal``, ``mask_index`` and ``device``.
        batch_size: Number of sequences to draw.
        seq_length: Token length of each sequence.
        num_steps: Number of reverse steps.
        eps: Smallest timestep (integration endpoint).

    Returns:
        LongTensor of token ids — presumably [batch_size, seq_length]; shape
        follows ``sample_prior`` (confirm against the Diffusion class).
    """
    x = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)
    # Timesteps descend from t=1 to t=eps on a (num_steps + 1)-point grid.
    timesteps = torch.linspace(1, eps, num_steps + 1, device=model.device)
    dt = torch.tensor((1 - eps) / num_steps, device=model.device)

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_reverse_step(x, t=t, dt=dt)
        x = x.to(model.device)

    # Remove remaining masks with a final denoising call at the last interior step.
    mask_pos = (x == model.mask_index)
    if mask_pos.any():
        t = timesteps[-2] * torch.ones(x.shape[0], 1, device=model.device)
        _, x = model.single_noise_removal(x, t=t, dt=dt)
        x = x.to(model.device)

    return x
104
+
105
+
106
def score_sequences(reward_model, sequences: List[str]):
    """Score sequences with the TD3B reward function.

    Supports reward models returning either a plain array of rewards or a
    ``(rewards, info)`` tuple whose info dict may carry ``affinities``,
    ``directions`` and ``confidences``.

    Returns:
        Tuple of numpy arrays: (rewards, affinities, directions, confidences).
        Missing fields default to rewards / zeros / ones respectively.
    """
    result = reward_model(sequences)

    if not isinstance(result, tuple):
        # Bare rewards: mirror them as affinities; neutral direction/confidence.
        rewards = np.asarray(result)
        return rewards, rewards, np.zeros_like(rewards), np.ones_like(rewards)

    raw_rewards, info = result
    rewards = np.asarray(raw_rewards)
    affinities = np.asarray(info.get("affinities", rewards))
    directions = np.asarray(info.get("directions", np.zeros_like(rewards)))
    confidences = np.asarray(info.get("confidences", np.ones_like(rewards)))
    return rewards, affinities, directions, confidences
119
+
120
+
121
def main():
    """CLI entry point: for each target in --val_csv and each direction
    (agonist/antagonist), sample a pool of candidates, score them with the
    TD3B reward, resample a subset by softmax(reward / alpha), keep only
    valid peptides, and write a results CSV plus a summary log."""
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True, help="CSV with Target_Sequence, Ligand_Sequence, label columns")
    parser.add_argument("--save_path", type=str, default="results", help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=32, help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int, default=8, help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=0.1, help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()

    # Setup: device fallback, seeding, output directory.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)

    analyzer = PeptideAnalyzer()

    # Load model
    logger.info(f"Loading model from {args.ckpt_path}")
    model, tokenizer = load_model(args.ckpt_path, device)

    # Load targets; missing optional columns default to "".
    logger.info(f"Loading targets from {args.val_csv}")
    df = pd.read_csv(args.val_csv)
    targets = []
    for _, row in df.iterrows():
        targets.append({
            "target_seq": row["Target_Sequence"],
            "target_uid": row.get("Target_UniProt_ID", ""),
            "binder_seq": row.get("Ligand_Sequence", ""),
            "label": row.get("label", ""),
            # Match generated length to the reference ligand SMILES, capped at 200.
            "seq_length": min(len(row.get("Ligand_SMILES", "x" * 200)), 200),
        })

    # Resolve oracle checkpoints (CLI overrides beat repo-local defaults).
    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(ROOT_DIR, "checkpoints", "pretrained.ckpt")

    records = []

    for tidx, target in enumerate(targets):
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info(f"[{tidx+1}/{len(targets)}] Target {target['target_uid']} direction={d_name}")

            # Create reward function.
            # NOTE(review): these keyword arguments (base_path, target_protein_seq,
            # emb_model, ...) do not match the create_reward_function signature in
            # finetune_utils (affinity_predictor, directional_oracle, ...). If they
            # diverge, the resulting TypeError is swallowed by the except below and
            # every target is silently skipped. The unused import
            # create_td3b_reward_function may be the intended factory — confirm.
            try:
                reward_model = create_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction="agonist" if d_star > 0 else "antagonist",
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning(f"Failed to create reward for {target['target_uid']}: {e}")
                continue

            # Generate pool of candidates at the target-matched length.
            target_length = target.get("seq_length", 200)
            x_pool = sample_sequences(model, args.num_pool, target_length, 128)
            sequences = tokenizer.batch_decode(x_pool)

            # Check validity (decodable as a peptide).
            valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])

            # Score all candidates; direction accuracy is relative to d*.
            gated_rewards, affinities, directions, confidences = score_sequences(reward_model, sequences)
            direction_accuracy = ((directions > 0.5).astype(float) if d_star > 0
                                  else (directions < 0.5).astype(float))

            # Weighted resampling (Algorithm 2): softmax over finite rewards.
            finite = np.isfinite(gated_rewards)
            if finite.any():
                rewards_t = torch.as_tensor(gated_rewards[finite], device=device)
                alpha = max(args.resample_alpha, 1e-6)  # guard against division by zero
                weights = torch.softmax(rewards_t / alpha, dim=0)
                idx = torch.multinomial(weights, num_samples=args.val_samples_per_target, replacement=True)
                valid_idx = np.where(finite)[0]
                chosen = valid_idx[idx.cpu().numpy()]
            else:
                # No finite rewards: fall back to the first few candidates.
                chosen = np.arange(min(args.val_samples_per_target, len(sequences)))

            # Save only VALID resampled samples.
            for i in chosen:
                is_valid = bool(valid_mask[i]) if valid_mask.size else False
                if not is_valid:
                    continue  # Skip invalid samples

                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })

    # Save results
    out_df = pd.DataFrame(records)
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)

    # Print summary, split by target direction.
    if len(out_df) > 0:
        dp = out_df[out_df["target_direction"] == 1.0]
        dm = out_df[out_df["target_direction"] == -1.0]
        logger.info(f"\n{'='*60}")
        logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
        logger.info(f"  Aff(d*=+1)   = {dp['affinity'].mean():.2f}" if len(dp) else "  No agonist samples")
        logger.info(f"  Aff(d*=-1)   = {dm['affinity'].mean():.2f}" if len(dm) else "  No antagonist samples")
        logger.info(f"  DA(d*=+1)    = {dp['direction_accuracy'].mean():.3f}" if len(dp) else "")
        logger.info(f"  DA(d*=-1)    = {dm['direction_accuracy'].mean():.3f}" if len(dm) else "")
        logger.info(f"  Gated Reward = {out_df['gated_reward'].mean():.2f}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No valid samples generated.")


if __name__ == "__main__":
    main()
launch_multi_target.sh ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Multi-Target TD3B Training Launch Script
# Trains TD3B on multiple protein targets with random sampling strategy

# ============================================================================
# Configuration
# ============================================================================

# Paths — update these to your local paths
BASE_PATH="/path/to/TD3B"
PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
TRAIN_CSV="${BASE_PATH}/data/train.csv"
VAL_CSV="${BASE_PATH}/data/test.csv"  # Optional: create validation split

# Run configuration
RUN_NAME="multi_target_td3b"  # Timestamp will be added automatically
DEVICE="cuda:0"
# Multi-target sampling
TARGETS_PER_MCTS=2            # Number of targets sampled per MCTS round (K)
RESAMPLE_TARGETS_EVERY=1      # Resample targets every N epochs

# Training hyperparameters
NUM_EPOCHS=200
LEARNING_RATE=3e-4
TRAIN_BATCH_SIZE=1            # Small batch size to prevent OOM
GRADIENT_ACCUMULATION_STEPS=32  # Effective batch size = 1 * 32 = 32
RESAMPLE_EVERY=10             # Run MCTS every N epochs
SAVE_EVERY=20
VALIDATE_EVERY=20
RESET_TREE_EVERY=50

# MCTS hyperparameters (aligned with v1, but can reduce for multi-target)
NUM_ITER=20                   # MCTS iterations per resample (v1 default: 50, reduced for multi-target)
NUM_CHILDREN=16               # Children per MCTS expansion
BUFFER_SIZE=50                # Pareto buffer size (v1 default: 50)
REPLAY_BUFFER_SIZE=1000       # Recommended range: 500-5000 (0 disables replay)
REPLAY_BUFFER_STRATEGY="fifo" # fifo or random
ALPHA=0.1                     # Temperature for importance weighting
EXPLORATION=1.0               # UCB exploration constant

# TD3B hyperparameters (aligned with v1 defaults)
CONTRASTIVE_WEIGHT=0.1        # v1 default: 0.1
CONTRASTIVE_MARGIN=1.0
KL_BETA=0.1                   # v1 default: 0.1
MIN_AFFINITY_THRESHOLD=0.0    # CRITICAL: minimum affinity for allosteric control
SIGMOID_TEMPERATURE=0.1

# Validation
VAL_SAMPLES_PER_TARGET=20     # Number of sequences per target during validation

# Directional oracle (GPCR classifier)
ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
ORACLE_TR2D2_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
ORACLE_TOKENIZER_VOCAB="${BASE_PATH}/tokenizer/new_vocab.txt"
ORACLE_TOKENIZER_SPLITS="${BASE_PATH}/tokenizer/new_splits.txt"
ORACLE_ESM_NAME="facebook/esm2_t33_650M_UR50D"
ORACLE_ESM_CACHE_DIR=""       # Optional: set to a cache dir path
ORACLE_ESM_LOCAL_FILES_ONLY=0 # Set to 1 to avoid network access
ORACLE_MAX_LIGAND_LENGTH=768
ORACLE_MAX_PROTEIN_LENGTH=1024
ORACLE_D_MODEL=256
ORACLE_N_HEADS=4
ORACLE_N_SELF_ATTN_LAYERS=1
ORACLE_N_BMCA_LAYERS=2
ORACLE_DROPOUT=0.3

# Optional oracle flags are appended only when configured above.
EXTRA_ORACLE_ARGS=""
if [ -n "$ORACLE_ESM_CACHE_DIR" ]; then
    EXTRA_ORACLE_ARGS="$EXTRA_ORACLE_ARGS --direction_oracle_esm_cache_dir $ORACLE_ESM_CACHE_DIR"
fi
if [ "$ORACLE_ESM_LOCAL_FILES_ONLY" -eq 1 ]; then
    EXTRA_ORACLE_ARGS="$EXTRA_ORACLE_ARGS --direction_oracle_esm_local_files_only"
fi

# W&B (optional)
WANDB_PROJECT="tr2d2-multi-target"
WANDB_ENTITY="phos_zj"

# ============================================================================
# Launch Training
# ============================================================================

cd ${BASE_PATH}

echo "============================================================================"
echo "Multi-Target TD3B Training"
echo "============================================================================"
echo "Configuration:"
echo "  - Targets per MCTS: ${TARGETS_PER_MCTS}"
echo "  - Training batch size: ${TRAIN_BATCH_SIZE}"
echo "  - Gradient accumulation: ${GRADIENT_ACCUMULATION_STEPS}"
echo "  - Effective batch size: $((TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS))"
echo "  - Epochs: ${NUM_EPOCHS}"
echo "  - MCTS iterations: ${NUM_ITER}"
echo "  - MCTS children: ${NUM_CHILDREN}"
echo "  - Buffer size: ${BUFFER_SIZE}"
echo "  - Replay buffer size: ${REPLAY_BUFFER_SIZE} (${REPLAY_BUFFER_STRATEGY})"
echo "============================================================================"
echo ""

# Build command
CMD="python finetune_multi_target.py \
    --base_path ${BASE_PATH} \
    --train_csv ${TRAIN_CSV} \
    --pretrained_checkpoint ${PRETRAINED_CHECKPOINT} \
    --run_name ${RUN_NAME} \
    --device ${DEVICE} \
    \
    --targets_per_mcts ${TARGETS_PER_MCTS} \
    --resample_targets_every ${RESAMPLE_TARGETS_EVERY} \
    \
    --num_epochs ${NUM_EPOCHS} \
    --learning_rate ${LEARNING_RATE} \
    --train_batch_size ${TRAIN_BATCH_SIZE} \
    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
    --resample_every_n_step ${RESAMPLE_EVERY} \
    --save_every_n_epochs ${SAVE_EVERY} \
    --validate_every_n_epochs ${VALIDATE_EVERY} \
    --reset_every_n_step ${RESET_TREE_EVERY} \
    \
    --num_iter ${NUM_ITER} \
    --num_children ${NUM_CHILDREN} \
    --buffer_size ${BUFFER_SIZE} \
    --replay_buffer_size ${REPLAY_BUFFER_SIZE} \
    --replay_buffer_strategy ${REPLAY_BUFFER_STRATEGY} \
    --alpha ${ALPHA} \
    --exploration ${EXPLORATION} \
    \
    --contrastive_weight ${CONTRASTIVE_WEIGHT} \
    --contrastive_margin ${CONTRASTIVE_MARGIN} \
    --kl_beta ${KL_BETA} \
    --min_affinity_threshold ${MIN_AFFINITY_THRESHOLD} \
    --sigmoid_temperature ${SIGMOID_TEMPERATURE} \
    \
    --direction_oracle_ckpt ${ORACLE_CKPT} \
    --direction_oracle_tr2d2_checkpoint ${ORACLE_TR2D2_CHECKPOINT} \
    --direction_oracle_tokenizer_vocab ${ORACLE_TOKENIZER_VOCAB} \
    --direction_oracle_tokenizer_splits ${ORACLE_TOKENIZER_SPLITS} \
    --direction_oracle_esm_name ${ORACLE_ESM_NAME} \
    --direction_oracle_max_ligand_length ${ORACLE_MAX_LIGAND_LENGTH} \
    --direction_oracle_max_protein_length ${ORACLE_MAX_PROTEIN_LENGTH} \
    --direction_oracle_d_model ${ORACLE_D_MODEL} \
    --direction_oracle_n_heads ${ORACLE_N_HEADS} \
    --direction_oracle_n_self_attn_layers ${ORACLE_N_SELF_ATTN_LAYERS} \
    --direction_oracle_n_bmca_layers ${ORACLE_N_BMCA_LAYERS} \
    --direction_oracle_dropout ${ORACLE_DROPOUT} \
    ${EXTRA_ORACLE_ARGS} \
    \
    --val_samples_per_target ${VAL_SAMPLES_PER_TARGET} \
    \
    --grad_clip \
    --gradnorm_clip 1.0 \
    --wandb_project ${WANDB_PROJECT}"

# Add validation CSV if it exists
if [ -f "${VAL_CSV}" ]; then
    CMD="${CMD} --val_csv ${VAL_CSV}"
    echo "Validation CSV: ${VAL_CSV}"
else
    echo "No validation CSV found (${VAL_CSV})"
    echo "Skipping validation during training"
fi

# Add W&B entity if specified
if [ -n "${WANDB_ENTITY}" ]; then
    CMD="${CMD} --wandb_entity ${WANDB_ENTITY}"
fi

echo ""
echo "Launching training..."
echo ""

# Execute
eval $CMD
noise_schedule.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ torch._C._jit_set_profiling_mode(False)
7
+ torch._C._jit_set_profiling_executor(False)
8
+ torch._C._jit_override_can_fuse_on_cpu(True)
9
+ torch._C._jit_override_can_fuse_on_gpu(True)
10
+
11
def get_noise(config, dtype=torch.float32):
    """Factory mapping ``config.noise.type`` to a Noise schedule instance.

    Args:
        config: object with a ``noise`` namespace exposing ``type`` and,
            for 'geometric'/'linear', ``sigma_min`` / ``sigma_max``.
        dtype: tensor dtype, used only by the 'linear' schedule.

    Returns:
        A Noise subclass instance.

    Raises:
        ValueError: for an unrecognized noise type.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
    elif noise_type == 'loglinear':
        return LogLinearNoise()
    elif noise_type == 'cosine':
        return CosineNoise()
    elif noise_type == 'cosinesqr':
        return CosineSqrNoise()
    elif noise_type == 'linear':
        return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
    elif noise_type == 'logpoly':
        # Fix: LogPolyNoise is defined in this module but was previously
        # unreachable from the factory.
        return LogPolyNoise()
    else:
        raise ValueError(f'{noise_type} is not a valid noise')
24
+
25
+
26
def binary_discretization(z):
    """Straight-through sign discretization.

    Forward pass yields sign(z); the backward pass sees the L2-normalized
    soft value, so gradients flow through the normalization.
    """
    hard = torch.sign(z)
    soft = z / torch.norm(z, dim=-1, keepdim=True)
    # Straight-through estimator: value of `hard`, gradient of `soft`.
    return soft + (hard - soft).detach()
30
+
31
+
32
class Noise(abc.ABC, nn.Module):
    """
    Abstract base for noise schedules.

    Subclasses provide ``total_noise(t)`` (accumulated noise up to t) and
    ``rate_noise(t)`` (its instantaneous rate); ``forward`` returns both.
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)
39
+
40
+
41
class CosineNoise(Noise):
    """Cosine schedule: total noise is -log(eps + (1 - eps) * cos(pi*t/2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Time derivative of total_noise.
        half_pi = torch.pi / 2
        numer = (1 - self.eps) * torch.sin(t * half_pi)
        denom = (1 - self.eps) * torch.cos(t * half_pi) + self.eps
        return half_pi * numer / denom

    def total_noise(self, t):
        damped_cos = self.eps + (1 - self.eps) * torch.cos(t * torch.pi / 2)
        return -torch.log(damped_cos)
55
+
56
+
57
class CosineSqrNoise(Noise):
    """Squared-cosine schedule: total noise is -log(eps + (1 - eps) * cos^2(pi*t/2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Time derivative of total_noise; sin(pi*t) arises from the chain rule
        # on cos^2(pi*t/2).
        half_pi = torch.pi / 2
        numer = (1 - self.eps) * torch.sin(t * torch.pi)
        denom = (1 - self.eps) * (torch.cos(t * half_pi) ** 2) + self.eps
        return half_pi * numer / denom

    def total_noise(self, t):
        damped_cos_sq = self.eps + (1 - self.eps) * (torch.cos(t * torch.pi / 2) ** 2)
        return -torch.log(damped_cos_sq)
72
+
73
+
74
class Linear(Noise):
    """Linear schedule: total noise runs from sigma_min at t=0 to sigma_max at t=1."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t=None):
        # Fix: the base class calls rate_noise(t), but this method previously
        # took no `t` and raised TypeError from Noise.forward. The rate of a
        # linear schedule is constant, so `t` is accepted and ignored
        # (optional, to stay compatible with any zero-argument callers).
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        # Warp a uniform t so that sampled sigmas follow the importance
        # distribution (same transform family as LogLinearNoise).
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)
92
+
93
+
94
class GeometricNoise(Noise):
    """Geometric interpolation between sigma_min (t=0) and sigma_max (t=1)."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # 1.0 * ... forces a floating-point tensor even for int endpoints.
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        lo, hi = self.sigmas[0], self.sigmas[1]
        # d/dt of lo^(1-t) * hi^t.
        return lo ** (1 - t) * hi ** t * (hi.log() - lo.log())

    def total_noise(self, t):
        lo, hi = self.sigmas[0], self.sigmas[1]
        return lo ** (1 - t) * hi ** t
105
+
106
+
107
class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built so that 1 - exp(-total_noise(t)) interpolates between 0 and ~1 as
    t goes from 0 to 1. Total noise is -log(1 - (1 - eps) * t), so the
    masking probability is (1 - eps) * t.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Time derivative of total_noise.
        scale = 1 - self.eps
        return scale / (1 - scale * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        # Warp a uniform t through the schedule so sigmas are sampled per the
        # importance distribution, then map back to a time value.
        f_T = torch.log1p(-torch.exp(-self.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * f_T + (1 - t) * f_0))
        return -torch.expm1(-sigma_t) / (1 - self.eps)
+ return t
133
+
134
class LogPolyNoise(Noise):
    """
    Log Polynomial noise schedule for slower masking of peptide bond tokens
    (cubic variant: total noise is -log(1 - (1 - eps) * t**3)).
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # derivative of -log(1-t^w)
        # NOTE(review): the exact derivative of total_noise would be
        # 3*(1-eps)*t**2 / (1 - (1-eps)*t**3); the numerator here is
        # (3*t**2 - eps), which differs (and is negative at t=0) —
        # confirm this deviation is intentional.
        return ((3 * (t**2)) - self.eps) / (1 - (1 - self.eps) * (t**3))

    def total_noise(self, t):
        # -log(1-t^w)
        return -torch.log1p(-(1 - self.eps) * (t**3))
peptide_mcts.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import random as rd
6
+ from utils.app import PeptideAnalyzer
7
+ from utils.timer import StepTimer
8
+ from scoring.scoring_functions import ScoringFunctions
9
+
10
+ import noise_schedule
11
+
12
+ ### for peptide multi-objective ###
13
def dominates(a, b):
    """Return True iff score vector ``a`` Pareto-dominates ``b``.

    ``a`` dominates ``b`` when it is at least as good in every
    objective and strictly better in at least one (maximization).
    """
    a_arr = np.asarray(a)
    b_arr = np.asarray(b)
    no_worse_anywhere = np.all(a_arr >= b_arr)
    strictly_better_somewhere = np.any(a_arr > b_arr)
    return no_worse_anywhere and strictly_better_somewhere
16
+
17
def dominated_by(a, b):
    """Return True iff score vector ``a`` is Pareto-dominated by ``b``.

    Equivalent to ``dominates(b, a)``: ``b`` is no worse in every
    objective and strictly better in at least one.
    """
    a_arr = np.asarray(a)
    b_arr = np.asarray(b)
    return np.all(b_arr >= a_arr) and np.any(b_arr > a_arr)
19
+
20
+
21
def updateParetoFront(paretoFront, node, scoreVector, totalSize=None, eps=1e-12):
    """
    Maintain a non-dominated set (Pareto front) of (node -> scoreVector).

    - Accept 'node' iff it is NOT dominated by any node in the set.
    - Remove any incumbent nodes that ARE dominated by 'node'.
    - If totalSize is given and the archive exceeds it, drop the item
      with the smallest sum(scoreVector) as a simple tie-breaker.

    Args:
        paretoFront (dict): {node: scoreVector}
        node: candidate node (used as dict key)
        scoreVector (array-like): candidate scores (to be maximized)
        totalSize (int|None): optional max size for the archive
        eps (float): tolerance for the dominance comparisons

    Returns:
        dict: updated paretoFront
    """
    s = np.asarray(scoreVector, dtype=float)

    def _dominates(a, b):
        # a >= b in all coords and > in at least one (with tolerance eps)
        return np.all(a >= b - eps) and np.any(a > b + eps)

    # Reject if the candidate is dominated by any node already in the set.
    for v in paretoFront.values():
        if _dominates(np.asarray(v, dtype=float), s):
            return paretoFront  # no change

    # Keep only incumbents that the candidate does not dominate.
    # (Removed: dead commented-out handling for exact duplicates and the
    # unused `equal` helper — duplicates are simply re-inserted as before.)
    survivors = {}
    for k, v in paretoFront.items():
        v_arr = np.asarray(v, dtype=float)
        if _dominates(s, v_arr):
            continue  # drop dominated incumbent
        survivors[k] = v_arr

    # Insert the candidate.
    survivors[node] = s

    # Enforce the optional capacity limit: drop the entry with the
    # smallest objective sum as a simple tie-breaker.
    if totalSize is not None and totalSize > 0 and len(survivors) > totalSize:
        keys = list(survivors.keys())
        sums = np.array([np.sum(survivors[k]) for k in keys])
        del survivors[keys[int(np.argmin(sums))]]

    return survivors
83
+
84
+ ### BEGINNING OF NODE CLASS ###
85
+
86
class Node:
    """
    Node class: a partially unmasked sequence in the MCTS tree.

    Attributes:
        parentNode: Node object at the previous time step (None for root).
        childNodes: list of Node objects generated by sampling distinct unmasking steps.
        totalReward: vector of cumulative rewards for all K objectives.
        visits: number of times the node has been visited by an iteration.
        timestep: time step at which the sequence was sampled (0..args.total_num_steps).
        tokens: dict with 'seqs' (token tensor) and 'attention_mask'.
    """
    def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
        self.args = args
        self.parentNode = parentNode
        # Fresh list per node (avoids a shared mutable default).
        self.childNodes = [] if childNodes is None else childNodes

        self.log_rnd = log_rnd  # accumulated log RN-weight of the path up to this node

        #self.log_p0 = 0 # stores the log probabiltiy of the unmasking step from the previous iteration
        self.log_policy_step = log_policy_step  # log-prob of the unmasking step under the current policy
        self.log_pretrained_step = log_pretrained_step  # log-prob of the same step under the pretrained model

        # Initialize total rewards to the reward of the rolled-out unmasked sequence,
        # or to zeros when no reward is known yet.
        if totalReward is not None:
            self.totalReward = totalReward  # potential reward of the node based on generated children
        else:
            self.totalReward = np.zeros(self.args.num_obj)

        # Set initial visits to 1.
        self.visits = 1

        # Set timestep (value between 0 and num_steps).
        self.timestep = timestep

        # Dict with 'seqs' as token array and 'attention_mask'.
        self.tokens = tokens

    def selectNode(self):
        """
        Select a child to move to, sampling uniformly from the Pareto front
        of the children's vector-valued select scores.

        Returns:
            (Node, int): selected node and its expand status; returns
            (self, status) unchanged when this node is not a legal non-leaf.
        """
        # extract the status of the current node
        nodeStatus = self.getExpandStatus()

        # if the node is a legal non-leaf node
        if (nodeStatus == 3):
            # Pareto front over the select scores of the eligible children.
            paretoFront = {}

            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # only append child if it is legal leaf node (expandable) or legal non-leaf node
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    paretoFront = updateParetoFront(paretoFront, childNode, selectScore)

            # NOTE(review): rd.choice raises IndexError if no child is
            # selectable (empty Pareto front) — confirm callers guarantee
            # at least one eligible child when status == 3.
            selected = rd.choice(list(paretoFront.keys()))

            # return selected child node and status
            return selected, selected.getExpandStatus()

        # if node is not valid non-leaf node
        return self, nodeStatus

    def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
        """
        Add a child node one timestep after this node.

        Args:
            tokens: dict with 'seqs' and 'attention_mask' for the child sequence.
            log_rnd: log RN-weight of the path up to the added child node.
            log_policy_step: scalar log-prob of sampling the step under the policy.
            log_pretrained_step: scalar log-prob of sampling the step under the pretrained model.
            totalReward: initial reward vector for the child.

        Returns:
            Node: the newly created child (also appended to self.childNodes).
        """
        child = Node(args=self.args,
                     tokens=tokens,
                     log_rnd = log_rnd,
                     log_policy_step=log_policy_step,
                     log_pretrained_step=log_pretrained_step,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep+1)

        self.childNodes.append(child)
        return child

    def update_logrnd(self, log_policy_step, log_rnd):
        # Refresh the cached per-step policy log-prob and path log RN-weight
        # (recomputed under the current policy during selection).
        self.log_policy_step = log_policy_step
        self.log_rnd = log_rnd

    def updateNode(self, rewards):
        """
        Update the cumulative rewards vector with the reward vector from a
        descendant leaf node, and increment the visit count.
        """
        self.visits += 1

        self.totalReward += rewards  # elementwise add over the K objectives

    def calcSelectScore(self):
        """
        Calculate the vector-valued select score from the cumulative rewards
        and visit counts: mean reward per objective plus a UCT-like
        exploration bonus weighted by the step's policy log-prob.
        """
        scaling = 0.1  # scaling of the exploration term in the select score

        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits

        # Exploration bonus scaled by the step's policy log-prob and the
        # parent's visit count (UCT-style).
        return normRewards + (scaling * self.log_policy_step.detach().cpu().item() * np.sqrt(self.parentNode.visits) / self.visits)

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with child nodes)
        """
        if self.timestep == self.args.total_num_steps:
            return 1
        elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
            return 2
        return 3
211
+
212
+ ### END OF NODE CLASS ###
213
+
214
+ ### BEGINNING OF MCTS CLASS ###
215
+
216
class MCTS:
    """
    Monte-Carlo tree search over partial unmasking trajectories of a
    discrete diffusion (MDLM) policy, collecting high-reward peptide
    sequences into a Pareto-filtered replay buffer.
    """
    def __init__(
        self,
        args,
        config,
        policy_model,
        pretrained,
        score_func_names=None,
        prot_seqs=None,
        rootNode=None,
        reward_func=None,
        num_obj=None,
    ):
        self.timer = StepTimer(policy_model.device)

        self.device = policy_model.device

        self.args = args
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = args.time_conditioning

        # Resolve the number of objectives: explicit arg, then the reward
        # function's own num_obj, then the number of score function names.
        if score_func_names is None:
            score_func_names = []
        if num_obj is None:
            num_obj = getattr(reward_func, "num_obj", None)
        self.num_obj = num_obj if num_obj is not None else len(score_func_names)

        # Root starts as a fully masked sequence of length args.seq_length.
        self.mask_index = policy_model.mask_index
        masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
        masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
        if rootNode is None:
            self.rootNode = Node(self.args, tokens = masked_tokens,
                        log_rnd=torch.zeros((), device=self.device),
                        log_policy_step=torch.zeros((), device=self.device),
                        log_pretrained_step=torch.zeros((), device=self.device),
                        totalReward=np.zeros(self.num_obj), timestep=0)
        else:
            self.rootNode = rootNode # stores the root node of the tree

        # Replay buffer of dicts:
        # "x_final": final unmasked token tensor
        # "log_rnd": log RN-weight of the trajectory (incl. reward term)
        # "final_reward": scalarized reward
        # "score_vector": per-objective scores
        # "seq": decoded string sequence
        self.buffer = [] # List[Dict[str, Any]]

        self.buffer_size = args.buffer_size

        self.num_steps = args.total_num_steps
        #self.num_sequences = args.num_sequences

        # pretrained model
        self.pretrained = pretrained

        # the policy model that we want to finetune
        self.policy_model = policy_model
        #self.tokenizer = policy_model.tokenizer
        self.device = policy_model.device

        self.sequence_length = args.seq_length

        self.num_iter = args.num_iter

        self.num_children = args.num_children

        # Scoring functions: build from names unless a reward callable is given.
        if reward_func is None:
            self.rewardFunc = ScoringFunctions(score_func_names, prot_seqs, device=args.device)
        else:
            self.rewardFunc = reward_func

        self.iter_num = 0

        self.reward_log = [] # stores scalarized total rewards
        self.logrnd_log = []
        # per-objective logs, appended once per expansion
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        # Both models are used in inference mode during search.
        self.policy_model.eval()
        self.pretrained.eval()

        # for peptides: validity checking and decoding
        self.analyzer = PeptideAnalyzer()
        self.tokenizer = policy_model.tokenizer


    def reset(self, resetTree):
        """Clear the buffer and logs; optionally rebuild a fresh root node."""
        self.iter_num = 0
        self.buffer = []
        self.reward_log = []
        self.logrnd_log = []

        # reset logs for each objective
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        # add option to continue with the same tree
        if resetTree:
            masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
            masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
            self.rootNode = Node(self.args, tokens = masked_tokens,
                        log_rnd=torch.zeros((), device=self.device),
                        log_policy_step=torch.zeros((), device=self.device),
                        log_pretrained_step=torch.zeros((), device=self.device),
                        totalReward=np.zeros(self.num_obj), timestep=0)

    def forward(self, resetTree=False):
        """
        Run num_iter select/expand iterations and return the consolidated
        buffer: (x_final, log_rnd, final_rewards, score_vectors, sequences).
        """
        self.reset(resetTree)

        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            # traverse the tree from the root node until a leaf node
            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards, score_vectors, sequences = self.consolidateBuffer()
        # return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        return final_x, log_rnd, final_rewards, score_vectors, sequences

    # new updateBuffer
    def _debug_buffer_decision(self, sv, reason, extra=None):
        # Debug print of buffer accept/reject decisions.
        if extra is None: extra = {}
        print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
              f"buf_len={len(self.buffer)} extra={extra}")

    def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
        """
        Insert finished trajectories into the Pareto-filtered buffer.

        Returns:
            (Tensor, np.ndarray): stacked per-trajectory log RN-weights
            (reward-tilted) and scalarized rewards, over all B inputs.
        """
        B = x_final.shape[0]
        traj_log_rnds, scalar_rewards = [], []

        for i in range(B):
            sv = np.asarray(score_vectors[i], dtype=float)

            # determine how to scalarize the multi-objective rewards
            # NOTE(review): the "normalized" and "weighted" branches are
            # unimplemented (pass), leaving scalar_reward unset on the first
            # iteration for those settings — confirm only the default
            # (sum) scalarization is used, or implement these branches.
            if self.args.scalarization == "normalized":
                pass
            elif self.args.scalarization == "weighted":
                pass
            else:
                scalar_reward = float(np.sum(sv))

            traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) # scale down by alpha

            item = {
                "x_final": x_final[i].clone(), # clone?
                "log_rnd": traj_log_rnd.clone(),
                "final_reward": scalar_reward,
                "score_vector": sv.copy(),
                "seq": childSequences[i],
            }

            # Drop if dominated by any existing
            if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
                # for debugging
                self._debug_buffer_decision(sv, "rejected_dominated")
                continue

            # Remove any existing that this candidate dominates
            keep = []
            for bi in self.buffer:
                if not dominates(sv, bi["score_vector"]):
                    keep.append(bi)
            self.buffer = keep

            # Insert with capacity rule
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
            else:
                # tie-breaker: replace the worst by a simple heuristic (min sum)
                worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
                self.buffer[worst_i] = item

            # for debugging
            self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

            traj_log_rnds.append(traj_log_rnd)
            scalar_rewards.append(scalar_reward)

        traj_log_rnds = torch.stack(traj_log_rnds, dim=0) if traj_log_rnds else torch.empty(0)
        scalar_rewards = np.asarray(scalar_rewards, dtype=float)
        return traj_log_rnds, scalar_rewards

    def consolidateBuffer(self):
        """
        Return the buffer contents as stacked tensors/arrays:
        x_final (B, L), log_rnd (B,), final_rewards (B,),
        score_vectors (B, K), sequences (list of str).
        """
        x_final = []
        log_rnd = []
        final_rewards = []
        score_vectors = []
        sequences = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])
            score_vectors.append(item["score_vector"])
            sequences.append(item["seq"])

        x_final = torch.stack(x_final, dim=0) # (B, L)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (B)
        final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
        score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)

        return x_final, log_rnd, final_rewards, score_vectors, sequences


    def isPathEnd(self, path, maxDepth):
        """
        Check whether the last element of `path` is completely unmasked
        (end of trajectory) or the path has reached maxDepth.
        """
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """
        Traverse the tree from the root until reaching a legal leaf node,
        recomputing each step's log-prob (and the accumulated log RN-weight)
        under the current policy along the way.
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            currNode, nodeStatus = currNode.selectNode()

            if currNode.parentNode is not None:
                # recompute the step log-prob under the current policy
                child_tokens = currNode.tokens['seqs'].to(self.device)
                attn_mask = currNode.tokens['attention_mask'].to(self.device)
                parent = currNode.parentNode
                parent_tokens = parent.tokens['seqs'].to(self.device)
                t = torch.ones(1, device = self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
                                                                                      child_tokens,
                                                                                      t=t, dt=dt)
                updated_log_rnd += updated_log_policy_step

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd) # update log_rnd

            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5):
        """
        Sample num_children unmasking steps from the policy, roll each child
        out to a fully unmasked sequence, score valid peptides, update the
        buffer, attach children to parentNode, and backpropagate rewards.
        """

        num_children = self.num_children
        # initialize child rewards that will be added to total rewards


        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of parent node to 0
        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        # initialize x and attn_mask
        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd # stores the log_rnd up to parent node

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device = self.device)

        # sample M child sequences and compute their log probabilities
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                _, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        # log RN-weight of the first step: (num_children, 1)
        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children

        traj_log_rnd = child_log_rnd # initialize log_rnd for entire rolled out trajectory

        # rollout under the policy and compute the log ratio at each step
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device = self.device)

                with torch.no_grad():
                    _, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                # add the rollout step
                traj_log_rnd += log_pretrained_step - log_policy_step

                x_rollout = x_next


        # if mask token remains, fully unmask
        mask_positions = (x_rollout == self.mask_index) # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item() # true if mask remains
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step

            x_rollout = x_next

        # stores the string sequences for reward evaluation
        with self.timer.section("expand.decode"):
            childSequences = self.tokenizer.batch_decode(x_rollout)

        ## FOR PEPTIDES ONLY ##
        valid_x_children = []
        valid_x_final = []
        validSequences = []
        valid_traj_log_rnd = []

        with self.timer.section("expand.filter_is_peptide"):
            for i in range(num_children):
                # string sequence
                childSeq = childSequences[i]

                # check if the peptide is valid
                if self.analyzer.is_peptide(childSeq):
                    valid_x_children.append(x_children[i])
                    valid_x_final.append(x_rollout[i])
                    validSequences.append(childSeq)
                    valid_traj_log_rnd.append(traj_log_rnd[i])
                else:
                    # invalid peptide: still attach as a zero-reward child
                    childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
                    parentNode.addChildNode(tokens=childTokens,
                                            log_rnd=child_log_rnd[i],
                                            log_policy_step=child_log_policy_step[i],
                                            log_pretrained_step=child_log_pretrained_step[i],
                                            totalReward=np.zeros(self.num_obj))

        del traj_log_rnd

        # Per-objective logs, in the order emitted by the scoring functions.
        log_targets = [
            self.affinity1_log,
            self.sol_log,
            self.hemo_log,
            self.nf_log,
            self.permeability_log,
        ]

        if len(validSequences) != 0:
            # add scores to log
            with self.timer.section("expand.scoring_functions"):
                score_vectors = np.asarray(self.rewardFunc(input_seqs=validSequences))

            if score_vectors.ndim == 1:
                score_vectors = score_vectors[:, None]

            average_scores = score_vectors.T
            num_scores = average_scores.shape[0]
            score_len = average_scores.shape[1]

            for idx, log_list in enumerate(log_targets):
                if idx < num_scores:
                    log_list.append(average_scores[idx])
                else:
                    log_list.append(np.zeros(score_len, dtype=np.float32))
        else:
            # set the values added to log as 0s if there are no valid sequences
            empty = np.zeros(self.num_children, dtype=np.float32)
            for log_list in log_targets:
                log_list.append(empty)

        # convert to tensor
        if len(valid_x_final) == 0:
            # log and bail out gracefully for this expansion
            self.valid_fraction_log.append(0.0)
            return

        valid_x_final = torch.stack(valid_x_final, dim=0)
        valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
        # update buffer and get rewards
        # NOTE(review): childSequences (all children) is passed here while
        # valid_x_final/score_vectors are indexed over the *valid* subset —
        # the stored "seq" may not match the stored tokens when some
        # children were invalid; presumably this should be validSequences.
        with self.timer.section("expand.update_buffer"):
            traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, childSequences)

        allChildReward = np.zeros_like(score_vectors[0])

        for i in range(len(score_vectors)):
            reward = score_vectors[i]

            # add to all child reward vector for backprop
            allChildReward += reward # (num_objectives,)

            # create node for sequence and add to the children node of parent
            childTokens = {'seqs': valid_x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[i],
                                    log_policy_step=child_log_policy_step[i],
                                    log_pretrained_step=child_log_pretrained_step[i],
                                    totalReward=reward)

        ### END OF FOR PEPTIDES ONLY ###

        valid_fraction = len(validSequences) / num_children
        self.valid_fraction_log.append(valid_fraction)

        # debugging
        print(f"[EXPAND] iter={self.iter_num} parent_t={parentNode.timestep} "
              f"num_children={num_children} valid={len(validSequences)} any_mask={any_mask_global}")
        if score_vectors is not None:
            print(f"[SCORES] min={np.min(score_vectors,0)} max={np.max(score_vectors,0)} "
                  f"nan_any={np.isnan(score_vectors).any()}")
        # end debugging

        self.reward_log.append(scalar_rewards)
        self.logrnd_log.append(traj_log_rnds.detach().cpu().numpy())

        allChildReward = allChildReward / len(validSequences) # normalize by number of valid children
        # backpropagate all child rewards
        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)


    def backprop(self, node, allChildReward):
        """Propagate the averaged child reward up the path to the root."""
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
roformer.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RoFormerConfig, RoFormerForMaskedLM
2
+ import torch.nn as nn
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ import torch
5
+
6
class Roformer(nn.Module):
    """Thin wrapper around a HuggingFace RoFormerForMaskedLM.

    Builds the RoFormer configuration from the project config and the
    tokenizer's vocab size, and exposes freeze/unfreeze, forward,
    save and load helpers.
    """

    def __init__(self, config, tokenizer, device=None):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        # Default to CUDA when available unless a device is supplied.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        """Disable gradients for every parameter of the underlying model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        """Re-enable gradients for every parameter of the underlying model."""
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        """Unfreeze the query/key projections of the final `n` encoder layers.

        FIX: the layer count was hard-coded to 8, which silently unfroze the
        wrong layers whenever config.roformer.n_layers != 8; derive it from
        the actual encoder instead.
        """
        layers = self.model.roformer.encoder.layer
        num_layers = len(layers)

        for i, layer in enumerate(layers):
            # finetune only the final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for param in layer.attention.self.query.parameters():
                    param.requires_grad = True
                # unfreeze key weights
                for param in layer.attention.self.key.parameters():
                    param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        """Run the masked-LM head and return raw logits.

        Args:
            input_ids: token id tensor, moved to self.device.
            attn_mask: attention mask tensor, moved to self.device.

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).
        """
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        outputs = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return outputs.logits

    def save_model(self, save_dir):
        """Save both model weights and tokenizer files to `save_dir`."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        """Construct a Roformer and replace its weights from `save_dir`.

        NOTE: from_pretrained returns the model on CPU; callers relying on
        self.device should move it afterwards if needed.
        """
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
scoring/functions/binding.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os, torch
3
+ import numpy as np
4
+ import torch
5
+ import pandas as pd
6
+ import torch.nn as nn
7
+ import esm
8
+ from transformers import AutoModelForMaskedLM
9
+
10
+
11
+ def _sanitize_token_ids(input_ids: torch.Tensor, vocab_size: int, unk_id: int) -> torch.Tensor:
12
+ if vocab_size <= 0 or input_ids.numel() == 0:
13
+ return input_ids
14
+ if torch.any(input_ids >= vocab_size) or torch.any(input_ids < 0):
15
+ # Replace out-of-range IDs with UNK to avoid embedding OOB.
16
+ unk = torch.tensor(unk_id, device=input_ids.device, dtype=input_ids.dtype)
17
+ input_ids = torch.where((input_ids >= vocab_size) | (input_ids < 0), unk, input_ids)
18
+ return input_ids
19
+
20
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model predicting protein–ligand binding affinity.

    Takes protein (ESM) embeddings and ligand (SMILES transformer)
    embeddings, runs alternating cross-attention layers, pools, and emits
    both a regression output (affinity) and 3-class binding-strength logits.
    """
    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Define binding thresholds (pKd/pKi/pIC50 units)
        self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
        self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM

        # Project both modalities to the same hidden dimension
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Cross attention blocks with layer norm; note norm1/norm2 and the
        # FFN are shared between the protein and SMILES passes of each layer.
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Shared trunk over the concatenated pooled representations
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Regression head: scalar affinity
        self.regression_head = nn.Linear(hidden_dim, 1)

        # Classification head (3 classes: tight, medium, weak binding)
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Accepts either a tensor (returns a long tensor of classes) or a
        scalar (returns an int).
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            # tight entries stay 0 (the zeros_like default)
            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0 # tight binding
            elif affinity < self.weak_threshold:
                return 2 # weak binding
            else:
                return 1 # medium binding

    def forward(self, protein_emb, smiles_emb):
        """Return (regression_output, classification_logits).

        NOTE(review): nn.MultiheadAttention defaults to (seq, batch, embed)
        inputs, and the transposes below are commented out; the dim=0 mean
        then pools over whichever axis comes first. Confirm the callers'
        embedding layout matches this (e.g. batch-less (seq, embed) inputs).
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        #protein = protein.transpose(0, 1)
        #smiles = smiles.transpose(0, 1)

        # Cross attention layers
        for layer in self.cross_attention_layers:
            # Protein attending to SMILES
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES attending to protein (uses the already-updated protein)
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Get sequence-level representations by mean-pooling over dim 0
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        # Concatenate both representations
        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        # Shared features
        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits
+ return regression_output, classification_logits
128
+
129
class BindingAffinity:
    """Single-target binding-affinity scorer.

    Pre-computes the ESM-2 embedding of one fixed protein target at
    construction time, then scores peptide SMILES sequences against it with
    the ImprovedBindingPredictor checkpoint.
    """

    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            prot_seq: amino-acid sequence of the fixed protein target.
            tokenizer: peptide (SMILES) tokenizer compatible with PeptideCLM.
            base_path: root directory containing classifier checkpoints.
            device: torch device; auto-detected when None.
            emb_model: optional pre-loaded PeptideCLM roformer encoder.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # peptide embeddings
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer
        # Fall back to a vocab lookup when the tokenizer does not expose
        # unk_token_id directly (0 as a last resort).
        self.unk_id = getattr(self.pep_tokenizer, "unk_token_id", None)
        if self.unk_id is None:
            self.unk_id = self.pep_tokenizer.vocab.get(self.pep_tokenizer.unk_token, 0)
        # Probe the encoder for vocab size / max length; the attribute layout
        # differs between the custom Roformer wrapper and a plain HF module.
        self.pep_vocab_size = None
        self.max_pep_len = None
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            self.pep_vocab_size = self.pep_model.model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "roformer"):
            self.pep_vocab_size = self.pep_model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "get_input_embeddings"):
            self.pep_vocab_size = self.pep_model.get_input_embeddings().num_embeddings
            self.max_pep_len = getattr(self.pep_model.config, "max_position_embeddings", None)

        self.model = ImprovedBindingPredictor().to(self.device)
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # acceptable only because the checkpoint ships with this repo.
        checkpoint = torch.load(f'{base_path}/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # load ESM-2 model
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()  # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])  # final-layer representations
            prot_emb = results["representations"][33]

        # Mean-pool the per-residue embedding into a single (1, esm_dim) row,
        # cached for every subsequent scoring call.
        self.prot_emb = prot_emb[0].to(self.device)
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)

    def forward(self, input_seqs):
        """Score each peptide SMILES in input_seqs against the fixed target.

        Returns:
            list[float]: one predicted affinity per input sequence.
        """
        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                # Truncate only when the encoder advertises a max length.
                pep_tokens = self.pep_tokenizer(
                    seq,
                    return_tensors='pt',
                    padding=True,
                    truncation=self.max_pep_len is not None,
                    max_length=self.max_pep_len,
                )

                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
                # Clamp out-of-vocabulary ids to UNK before the embedding lookup.
                pep_tokens["input_ids"] = _sanitize_token_ids(
                    pep_tokens["input_ids"], int(self.pep_vocab_size or 0), int(self.unk_id)
                )

                with torch.no_grad():
                    # Check if using custom Roformer wrapper or standard model
                    if hasattr(self.pep_model, 'model'):
                        # Custom roformer.Roformer wrapper - get hidden states from inner model
                        emb = self.pep_model.model.roformer(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
                    else:
                        # Standard AutoModelForMaskedLM
                        emb = self.pep_model(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
            return scores

    def __call__(self, input_seqs: list):
        """Alias for forward()."""
        return self.forward(input_seqs)
224
+
225
+
226
class MultiTargetBindingAffinity:
    """
    Binding affinity predictor that can handle multiple protein targets dynamically.

    Unlike BindingAffinity which pre-computes a single target's embedding,
    this class can switch between different protein targets on-the-fly.
    Per-target ESM-2 embeddings are memoized in self.prot_emb_cache.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Initialize multi-target binding affinity predictor.

        Args:
            tokenizer: Peptide tokenizer
            base_path: Base path for model files
            device: Device for computation (default: auto-detect)
            emb_model: Optional pre-loaded embedding model
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # Peptide embeddings
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer
        # Fall back to a vocab lookup when unk_token_id is not exposed.
        self.unk_id = getattr(self.pep_tokenizer, "unk_token_id", None)
        if self.unk_id is None:
            self.unk_id = self.pep_tokenizer.vocab.get(self.pep_tokenizer.unk_token, 0)
        # Probe the encoder for vocab size / max length; layout differs
        # between the custom Roformer wrapper and a plain HF module.
        self.pep_vocab_size = None
        self.max_pep_len = None
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            self.pep_vocab_size = self.pep_model.model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "roformer"):
            self.pep_vocab_size = self.pep_model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "get_input_embeddings"):
            self.pep_vocab_size = self.pep_model.get_input_embeddings().num_embeddings
            self.max_pep_len = getattr(self.pep_model.config, "max_position_embeddings", None)

        # Binding affinity prediction model
        self.model = ImprovedBindingPredictor().to(self.device)
        # NOTE(review): weights_only=False unpickles arbitrary objects;
        # acceptable only for the checkpoint shipped with this repo.
        checkpoint = torch.load(f'{base_path}/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Protein (ESM) model
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        # Cache for protein embeddings (target_seq -> embedding)
        self.prot_emb_cache = {}

    def get_protein_embedding(self, prot_seq: str):
        """
        Get protein embedding, using cache if available.

        Args:
            prot_seq: Protein amino acid sequence

        Returns:
            Mean-pooled (1, esm_dim) protein embedding tensor
        """
        # Check cache first
        if prot_seq in self.prot_emb_cache:
            return self.prot_emb_cache[prot_seq]

        # Compute embedding
        data = [("target", prot_seq)]
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)

        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])
            prot_emb = results["representations"][33]

        # Mean-pool per-residue embeddings into a single row vector.
        prot_emb = prot_emb[0].to(self.device)
        prot_emb = torch.mean(prot_emb, dim=0, keepdim=True)

        # Cache for future use
        self.prot_emb_cache[prot_seq] = prot_emb

        return prot_emb

    def forward(self, input_seqs, prot_seq: str):
        """
        Predict binding affinity for peptide-protein pairs.

        Args:
            input_seqs: List of peptide sequences
            prot_seq: Protein target sequence

        Returns:
            List of binding affinity scores
        """
        # Get protein embedding (cached if previously computed)
        prot_emb = self.get_protein_embedding(prot_seq)

        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                # Truncate only when the encoder advertises a max length.
                pep_tokens = self.pep_tokenizer(
                    seq,
                    return_tensors='pt',
                    padding=True,
                    truncation=self.max_pep_len is not None,
                    max_length=self.max_pep_len,
                )
                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
                # Clamp out-of-vocabulary ids to UNK before the embedding lookup.
                pep_tokens["input_ids"] = _sanitize_token_ids(
                    pep_tokens["input_ids"], int(self.pep_vocab_size or 0), int(self.unk_id)
                )

                with torch.no_grad():
                    # Check if using custom Roformer wrapper or standard model
                    if hasattr(self.pep_model, 'model'):
                        # Custom roformer.Roformer wrapper - get hidden states from inner model
                        emb = self.pep_model.model.roformer(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
                    else:
                        # Standard AutoModelForMaskedLM
                        emb = self.pep_model(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(prot_emb, pep_emb)
                scores.append(score.item())

            return scores

    def forward_from_probs(
        self,
        token_probs: torch.Tensor,
        attention_mask: torch.Tensor,
        prot_seq: str,
    ) -> torch.Tensor:
        """
        Differentiable binding affinity from token probabilities.

        Instead of discrete ids, consumes a (batch, seq, vocab) probability
        tensor and forms soft input embeddings as probs @ embedding_matrix,
        so gradients flow back into the token distribution.
        """
        # Accept an unbatched (seq, vocab) tensor for convenience.
        if token_probs.dim() == 2:
            token_probs = token_probs.unsqueeze(0)
        token_probs = token_probs.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Locate the word-embedding matrix for whichever wrapper is in use.
        roformer = None
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            roformer = self.pep_model.model.roformer
            emb_weight = roformer.embeddings.word_embeddings.weight
        elif hasattr(self.pep_model, "roformer"):
            roformer = self.pep_model.roformer
            emb_weight = roformer.embeddings.word_embeddings.weight
        else:
            emb_weight = self.pep_model.get_input_embeddings().weight

        if token_probs.size(-1) != emb_weight.size(0):
            raise ValueError(
                f"Token vocab mismatch: probs={token_probs.size(-1)} vs model={emb_weight.size(0)}"
            )

        # Soft embeddings: expected embedding under the token distribution.
        inputs_embeds = token_probs @ emb_weight
        if roformer is not None:
            outputs = roformer(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state
        else:
            outputs = self.pep_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = outputs.hidden_states[-1]

        # Masked mean pooling over the sequence dimension (padding excluded).
        mask = attention_mask.to(hidden.dtype).unsqueeze(-1)
        pep_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

        # Broadcast the cached target embedding across the peptide batch.
        prot_emb = self.get_protein_embedding(prot_seq).to(self.device)
        prot_emb = prot_emb.expand(pep_emb.size(0), -1).unsqueeze(0)
        pep_emb = pep_emb.unsqueeze(0)

        score, _ = self.model.forward(prot_emb, pep_emb)
        return score.squeeze(-1)

    def __call__(self, input_seqs: list, prot_seq: str):
        """
        Predict binding affinity for peptide-protein pairs.

        Args:
            input_seqs: List of peptide sequences
            prot_seq: Protein target sequence

        Returns:
            List of binding affinity scores
        """
        return self.forward(input_seqs, prot_seq)

    def clear_cache(self):
        """Clear the protein embedding cache to free memory."""
        self.prot_emb_cache = {}
439
+
440
+
441
class TargetSpecificBindingAffinity:
    """
    Adapter that fixes one protein target onto a MultiTargetBindingAffinity.

    Exposes the single-target BindingAffinity interface (peptides in, scores
    out) while delegating all computation to the shared multi-target
    predictor, so callers never need to pass the protein sequence.
    """

    def __init__(self, multi_target_predictor: MultiTargetBindingAffinity, prot_seq: str):
        """
        Args:
            multi_target_predictor: shared predictor doing the actual scoring.
            prot_seq: protein target sequence this wrapper is bound to.
        """
        self.predictor = multi_target_predictor
        self.prot_seq = prot_seq

    def forward(self, input_seqs):
        """Score the peptides in `input_seqs` against the bound target."""
        return self.predictor.forward(input_seqs, self.prot_seq)

    def __call__(self, input_seqs: list):
        """Alias for forward(); mirrors the BindingAffinity call signature."""
        return self.forward(input_seqs)
scoring/functions/classifiers/hemolysis-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
scoring/functions/classifiers/nonfouling-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
scoring/functions/classifiers/permeability-xgboost.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
3
+ size 45249160
scoring/functions/classifiers/solubility-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
scoring/functions/hemolysis.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Hemolysis:
    """XGBoost hemolysis scorer over PeptideCLM sequence embeddings.

    Scores are 1 - P(hemolytic), i.e. higher means safer.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: peptide (SMILES) tokenizer compatible with the encoder.
            base_path: root directory containing the classifier checkpoint.
            device: torch device; auto-detected when None.
            emb_model: optional pre-loaded PeptideCLM roformer encoder.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json')
        # Bug fix: move the encoder to the *resolved* self.device. The original
        # called .to(device) with the raw argument (possibly None), leaving the
        # model on CPU while generate_embeddings sends tokens to self.device.
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled last-hidden-state embeddings, shape (n, hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return 1 - P(hemolytic) per sequence as an np.ndarray of shape (n,)."""
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize NaNs / infs before handing features to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # Return the probability of it being NOT hemolytic.
        return scores - probs

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
53
+
54
def unittest():
    """Smoke test for the Hemolysis scorer.

    NOTE(review): Hemolysis() is invoked without the required
    tokenizer/base_path arguments, so this raises TypeError as written —
    confirm the intended invocation before running.
    """
    hemo = Hemolysis()
    # One cyclic peptide expressed as SMILES.
    seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(hemo.tokenizer.vocab_size)
    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
scoring/functions/nonfouling.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit import Chem, rdBase, DataStructs
10
+
11
+
12
+ rdBase.DisableLog('rdApp.error')
13
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+
17
class Nonfouling:
    """XGBoost nonfouling scorer over PeptideCLM sequence embeddings.

    Scores are P(nonfouling), i.e. higher is better.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: peptide (SMILES) tokenizer compatible with the encoder.
            base_path: root directory containing the classifier checkpoint.
            device: torch device; auto-detected when None.
            emb_model: optional pre-loaded PeptideCLM roformer encoder.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json')
        # Bug fix: use the resolved self.device — the raw `device` argument may
        # be None, which would leave the encoder on CPU while inputs go to GPU.
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled last-hidden-state embeddings, shape (n, hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return P(nonfouling) per sequence as an np.ndarray of shape (n,)."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize NaNs / infs before handing features to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)

        # Probability of the peptide being nonfouling (comment fixed: it
        # previously said "not hemolytic", copied from the hemolysis scorer).
        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
56
+
57
def unittest():
    """Smoke test for the Nonfouling scorer.

    NOTE(review): Nonfouling() is invoked without the required
    tokenizer/base_path arguments, so this raises TypeError as written —
    confirm the intended invocation before running.
    """
    nf = Nonfouling()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
scoring/functions/permeability.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit.Chem import Descriptors, rdMolDescriptors
10
+ from rdkit import Chem, rdBase, DataStructs
11
+ from rdkit.Chem import AllChem
12
+ from typing import List
13
+ from transformers import AutoModelForMaskedLM
14
+
15
+
16
+ rdBase.DisableLog('rdApp.error')
17
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
18
+ warnings.filterwarnings("ignore", category=UserWarning)
19
+ warnings.filterwarnings("ignore", category=FutureWarning)
20
+
21
def fingerprints_from_smiles(smiles: List, size=2048):
    """Create ECFP fingerprints for a list of SMILES strings, with validity check.

    Invalid SMILES produce an all-zero fingerprint row.

    Args:
        smiles: list of SMILES strings.
        size: number of bits per fingerprint.

    Returns:
        (fps, valid_mask): fps is an (n, size) array; valid_mask is a list of
        0/1 flags marking which inputs parsed successfully.
    """
    fps = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    if not fps:
        # Bug fix: np.concatenate raises on an empty list; return an
        # empty (0, size) array instead.
        return np.zeros((0, size)), valid_mask
    fps = np.concatenate(fps, axis=0)
    return fps, valid_mask
33
+
34
+
35
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create an ECFP (Morgan) fingerprint of a molecule.

    Args:
        molecule: RDKit Mol object (must not be None).
        radius: Morgan radius (3 corresponds to ECFP6).
        size: number of bits in the fingerprint.
        hashed: use the hashed count fingerprint instead of the bit vector.

    Returns:
        np.ndarray of shape (1, size).
    """
    if hashed:
        fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    # ConvertToNumpyArray resizes/fills fp_np in place from the RDKit vector.
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)
44
+
45
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full RDKit descriptor list (plus Lipinski counts) for a molecule.

    Args:
        mol: RDKit molecule. May be None — every descriptor then falls back
            to missingVal via the except branch.
        missingVal: value recorded when a descriptor computation fails.

    Returns:
        (values, names): parallel lists of descriptor values and names.
    """
    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except Exception:
            # Bug fix: narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. Individual descriptors can fail
            # on odd molecules; record the sentinel and continue.
            val = missingVal
        values.append(val)
        names.append(nm)

    # Extra peptide-relevant counts not included in Descriptors._descList.
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except Exception:
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
69
+
70
def get_pep_dps_from_smi(smi):
    """Compute the RDKit descriptor vector for a single SMILES string.

    NOTE(review): Chem.MolFromSmiles returns None for invalid SMILES rather
    than raising, so the except branch rarely fires; a None mol then makes
    every descriptor fall back to getMolDescriptors' missingVal.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except:
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
79
+
80
+
81
def get_pep_dps(smi_list):
    """Stack per-SMILES descriptor vectors into an (n, 213) matrix."""
    if not smi_list:
        # Keep the column count stable even for empty input.
        return np.zeros((0, 213))
    rows = [get_pep_dps_from_smi(smi) for smi in smi_list]
    return np.array(rows)
85
+
86
def check_smi_validity(smiles: list):
    """Split SMILES into RDKit-parseable ones, returned with their indices.

    Returns:
        (valid_smi, valid_idx): the parseable SMILES strings and the indices
        they occupied in the input list.
    """
    valid_smi = []
    valid_idx = []
    for idx, smi in enumerate(smiles):
        try:
            # Empty / falsy entries are treated as invalid without parsing.
            mol = Chem.MolFromSmiles(smi) if smi else None
        except Exception:
            # Unparseable input: treat as invalid and move on.
            continue
        if mol:
            valid_smi.append(smi)
            valid_idx.append(idx)
    return valid_smi, valid_idx
98
+
99
class Permeability:
    """XGBoost membrane-permeability scorer over PeptideCLM embeddings.

    Optionally augments the embeddings with ECFP fingerprints and RDKit
    descriptors (both disabled by default in get_features).
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: peptide (SMILES) tokenizer compatible with the encoder.
            base_path: root directory containing the classifier checkpoint.
            device: torch device; auto-detected when None.
            emb_model: optional pre-loaded PeptideCLM roformer encoder.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            # Bug fix: move to the resolved self.device — the raw `device`
            # argument may be None, which left the model on CPU while
            # generate_embeddings sends tokens to self.device.
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled last-hidden-state embeddings, shape (n, hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Assemble the feature matrix: [fingerprints | descriptors | embeddings].

        Args:
            input_seqs: peptide SMILES strings.
            dps: include RDKit descriptor columns.
            fps: include ECFP fingerprint columns.
        """
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            # Bug fix: use numpy (not torch) placeholders so np.concatenate
            # below operates on a homogeneous list of arrays.
            fingerprints = np.empty((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = np.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)
        features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
        return features

    def get_scores(self, input_seqs: list):
        """Predicted permeability per sequence; -10 sentinel rows for empty input."""
        scores = -10 * np.ones(len(input_seqs))
        # Bug fix: with no sequences the embedding array is 1-D and
        # np.concatenate in get_features raised; return the sentinel early.
        if not input_seqs:
            return scores

        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize NaNs / infs before handing features to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
162
+
163
def unittest():
    """Smoke test for the Permeability scorer.

    NOTE(review): Permeability() is invoked without the required
    tokenizer/base_path arguments, so this raises TypeError as written —
    confirm the intended invocation before running.
    """
    permeability = Permeability()
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
scoring/functions/solubility.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
class Solubility:
    """XGBoost solubility scorer driven by PeptideCLM sequence embeddings."""

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: peptide (SMILES) tokenizer compatible with the encoder.
            base_path: root directory containing the classifier checkpoint.
            device: torch device; auto-detected when None.
            emb_model: optional pre-loaded PeptideCLM roformer encoder.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        model_file = f'{base_path}/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json'
        self.predictor = xgb.Booster(model_file=model_file)

        if emb_model is None:
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled encoder embedding per sequence, shape (n, hidden)."""
        pooled = []
        for seq in sequences:
            batch = self.tokenizer(seq, return_tensors='pt')
            batch = {name: tensor.to(self.device) for name, tensor in batch.items()}
            with torch.no_grad():
                hidden = self.emb_model(**batch).last_hidden_state
                # Average over the token dimension to get a single vector.
                pooled.append(hidden.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Predict a solubility score for each input sequence."""
        feats = self.generate_embeddings(input_seqs)
        if len(feats) == 0:
            # No embeddings: fall back to all-zero scores.
            return np.zeros(len(input_seqs))

        # Sanitize NaNs / infs before handing features to XGBoost.
        feats = np.nan_to_num(feats, nan=0.)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return self.predictor.predict(xgb.DMatrix(feats))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
55
+
56
def unittest():
    """Smoke test for the Solubility scorer.

    NOTE(review): Solubility() is invoked without the required
    tokenizer/base_path arguments, so this raises TypeError as written —
    confirm the intended invocation before running.
    """
    solubility = Solubility()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
scoring/scoring_functions.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
4
+ from transformers import AutoModelForMaskedLM
5
+ from scoring.functions.binding import BindingAffinity
6
+ from scoring.functions.permeability import Permeability
7
+ from scoring.functions.solubility import Solubility
8
+ from scoring.functions.hemolysis import Hemolysis
9
+ from scoring.functions.nonfouling import Nonfouling
10
+
11
+ base_path = 'To Be Added'
12
+
13
def resolve_device(requested):
    """Resolve a user-supplied device spec to a usable torch.device.

    Falls back to CPU whenever the request is unparseable or refers to an
    unavailable CUDA device; out-of-range CUDA indices are clamped to cuda:0.

    Args:
        requested: None, "auto", a device string, or a torch.device.

    Returns:
        torch.device that is safe to use on this machine.
    """
    # None / "auto": pick the first GPU when available, else CPU.
    if requested is None or str(requested).lower() == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda:0")
        return torch.device("cpu")

    try:
        device = torch.device(requested)
    except Exception:
        # Unparseable spec: degrade gracefully instead of crashing.
        return torch.device("cpu")

    if device.type != "cuda":
        return device

    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        return torch.device("cpu")

    # Bug fix: `index` can never be None after the fallback below, so the
    # redundant `index is None` test was removed.
    index = device.index if device.index is not None else 0
    if index < 0 or index >= torch.cuda.device_count():
        return torch.device("cuda:0")

    return torch.device(f"cuda:{index}")
35
+
36
class ScoringFunctions:
    """Bundle of peptide scoring functions evaluated on generated sequences."""

    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Class for generating score vectors given generated sequences.

        Args:
            score_func_names: list of scoring-function names to evaluate
                (keys of self.all_funcs); None means no scoring, i.e. only
                unmasking based on validity of peptide bonds.
            prot_seqs: list of target protein sequences — one per requested
                binding_affinity function.
            device: requested compute device, resolved via resolve_device.
        """
        device = resolve_device(device)
        # Shared PeptideCLM encoder reused by every property scorer.
        emb_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all'
        ).roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tr2d2-pep/tokenizer/new_vocab.txt',
                                         f'{base_path}/tr2d2-pep/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        if score_func_names is None:
            # Just do unmasking based on validity of peptide bonds.
            self.score_func_names = []
        else:
            self.score_func_names = score_func_names

        # Binding affinities
        self.target_protein = prot_seqs

        # Bug fix: test membership on the resolved self.score_func_names —
        # the original tested `in score_func_names`, which raises TypeError
        # when the caller passes None. Leftover debug prints removed.
        names = self.score_func_names
        if ('binding_affinity1' in names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif ('binding_affinity1' in names) and ('binding_affinity2' in names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Evaluate every configured scoring function on the sequences.

        Returns:
            np.ndarray of shape (num_sequences, num_functions), float32.
        """
        scores = [self.all_funcs[name](input_seqs=input_seqs)
                  for name in self.score_func_names]

        # Transpose so rows correspond to sequences, columns to functions.
        return np.asarray(scores, dtype=np.float32).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Minimal packaging script for the td3b distribution.
from setuptools import setup, find_packages

setup(
    name="td3b",
    version="0.1.0",
    description="TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation",
    # Auto-discover every package directory containing an __init__.py.
    packages=find_packages(),
    python_requires=">=3.10",
)
td3b/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B: Transition-Directed Discrete Diffusion for Binders
3
+ A module extending TR2-D2 with directional allosteric control.
4
+ """
5
+
6
+ from .direction_oracle import DirectionalOracle
7
+ from .td3b_scoring import TD3BRewardFunction, TD3BConfidenceWeighting, create_td3b_reward_function
8
+ from .td3b_losses import ContrastiveLoss, InfoNCELoss, TD3BTotalLoss, extract_embeddings_from_mdlm
9
+ from .td3b_mcts import TD3B_MCTS, create_td3b_mcts
10
+ from .td3b_finetune import td3b_finetune, add_td3b_sampling_to_model
11
+ from .data_utils import TD3BDataset, load_td3b_data
12
+
13
+ __all__ = [
14
+ 'DirectionalOracle',
15
+ 'TD3BRewardFunction',
16
+ 'TD3BConfidenceWeighting',
17
+ 'create_td3b_reward_function',
18
+ 'ContrastiveLoss',
19
+ 'InfoNCELoss',
20
+ 'TD3BTotalLoss',
21
+ 'extract_embeddings_from_mdlm',
22
+ 'TD3B_MCTS',
23
+ 'create_td3b_mcts',
24
+ 'td3b_finetune',
25
+ 'add_td3b_sampling_to_model',
26
+ 'TD3BDataset',
27
+ 'load_td3b_data',
28
+ ]
29
+
30
+ __version__ = '0.1.0'
td3b/data_utils.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B Data Utilities
3
+ Handles loading and preprocessing of TD3B_data.csv for both oracle training and finetuning.
4
+ """
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from typing import Dict, List, Optional, Tuple
11
+ import sys
12
+
13
+ try:
14
+ from rdkit import Chem
15
+ except ImportError: # pragma: no cover - rdkit may be optional in some setups
16
+ Chem = None
17
+
18
+ sys.path.append('..')
19
+
20
+ AA_SET = set("ACDEFGHIKLMNPQRSTVWY")
21
+
22
+
23
def is_amino_acid_sequence(seq: str) -> bool:
    """Return True when *seq* is a non-empty string of canonical amino acids.

    Leading/trailing whitespace is stripped and case is ignored before the
    check; non-string or empty inputs are rejected.
    """
    if not isinstance(seq, str) or not seq:
        return False
    normalized = seq.strip().upper()
    # Every residue must belong to the 20-letter canonical alphabet.
    return AA_SET.issuperset(normalized)
28
+
29
+
30
def aa_sequence_to_smiles(seq: str) -> Optional[str]:
    """Convert an amino-acid sequence to a canonical isomeric SMILES string.

    Returns None when RDKit is unavailable, the input is not a valid
    amino-acid sequence, or RDKit fails to parse it.
    """
    if Chem is None:
        return None
    if not is_amino_acid_sequence(seq):
        return None
    try:
        mol = Chem.MolFromSequence(seq)
    except Exception:
        return None
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=True)
40
+
41
+
42
def peptide_seq_to_smiles(seq: str) -> str:
    """Best-effort SMILES conversion: fall back to the raw sequence on failure."""
    converted = aa_sequence_to_smiles(seq)
    if converted is None:
        return seq
    return converted
45
+
46
+
47
def smiles_token_length(smiles: str, tokenizer) -> int:
    """Number of tokens the tokenizer produces for *smiles*.

    Falls back to the plain character count when no tokenizer is supplied.
    """
    if tokenizer is None:
        return len(smiles)
    encoded = tokenizer(smiles, return_tensors="pt")["input_ids"][0]
    return int(encoded.numel())
52
+
53
+
54
class TD3BDataset(Dataset):
    """
    Dataset for TD3B that loads peptide-protein pairs with directional labels.

    Supports both:
    1. Oracle training: uses all pairs for training f_φ
    2. Finetuning: provides target proteins for conditioning during RL

    Expected CSV columns (inferred from usage below — confirm against
    TD3B_data.csv): Ligand_Sequence, Target_Sequence, label, Action,
    Target_UniProt_ID, Ligand_UniProt_ID.
    """

    def __init__(
        self,
        data_path: str,
        mode: str = 'oracle',  # 'oracle' or 'finetune'
        peptide_tokenizer=None,
        protein_tokenizer=None,
        max_peptide_length: int = 200,
        max_protein_length: int = 1000,
        target_protein_id: Optional[str] = None,  # For finetuning mode
        convert_peptide_to_smiles: bool = True,
    ):
        """
        Args:
            data_path: Path to TD3B_data.csv
            mode: 'oracle' for training f_φ, 'finetune' for RL conditioning
            peptide_tokenizer: Tokenizer for peptide sequences (SMILES)
            protein_tokenizer: Tokenizer for protein sequences (ESM-2)
            max_peptide_length: Maximum peptide sequence length (tokens)
            max_protein_length: Maximum protein sequence length (residues)
            target_protein_id: UniProt ID for target protein (finetuning mode)
            convert_peptide_to_smiles: if True, __getitem__ converts peptides
                to SMILES before tokenization.
        """
        self.mode = mode
        self.data_path = data_path
        self.peptide_tokenizer = peptide_tokenizer
        self.protein_tokenizer = protein_tokenizer
        self.max_peptide_length = max_peptide_length
        self.max_protein_length = max_protein_length
        self.convert_peptide_to_smiles = convert_peptide_to_smiles

        # Load data
        self.data = pd.read_csv(data_path)
        print(f"Loaded {len(self.data)} peptide-protein pairs from {data_path}")

        # Filter by target protein if in finetune mode
        if mode == 'finetune' and target_protein_id is not None:
            self.data = self.data[self.data['Target_UniProt_ID'] == target_protein_id]
            print(f"Filtered to {len(self.data)} pairs for target {target_protein_id}")

        # Process labels: agonist -> +1, antagonist -> -1, neutral -> 0
        self.label_map = {
            'agonist': 1.0,
            'antagonist': -1.0,
            'neutral': 0.0,
        }

        # Convert action descriptions to numerical labels.
        # NOTE(review): labels outside label_map silently become NaN here —
        # verify the CSV only ever contains the three mapped values.
        self.data['numeric_label'] = self.data['label'].map(self.label_map)

        # Assign confidence based on action description
        # (assumes 'Action' is a non-null string column — .lower() would
        # raise on NaN; confirm against the CSV).
        self.data['confidence'] = self.data['Action'].apply(self._action_to_confidence)

    def _action_to_confidence(self, action: str) -> float:
        """
        Convert action description to confidence score.

        Full agonist/antagonist: 1.0
        Partial/Weak: 0.7
        Others: 0.5
        """
        action_lower = action.lower()

        if 'full' in action_lower:
            return 1.0
        elif 'partial' in action_lower or 'weak' in action_lower:
            return 0.7
        elif 'slows' in action_lower or 'modulator' in action_lower:
            return 0.5
        else:
            return 0.8  # Default for unspecified agonist/antagonist

    def __len__(self) -> int:
        """Number of peptide-protein pairs after any finetune filtering."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Return one pair as a dict of raw strings, token tensors and labels."""
        row = self.data.iloc[idx]

        # Get sequences
        peptide_seq = row['Ligand_Sequence']
        protein_seq = row['Target_Sequence']
        peptide_smiles = self._peptide_to_smiles(peptide_seq)
        peptide_smiles_length = smiles_token_length(peptide_smiles, self.peptide_tokenizer)

        # Tokenize (placeholder - actual tokenization depends on mode)
        if self.peptide_tokenizer is not None:
            peptide_tokens = self._tokenize_peptide(peptide_smiles)
        else:
            # No tokenizer: emit an all-zero placeholder of fixed length.
            peptide_tokens = torch.zeros(self.max_peptide_length, dtype=torch.long)

        if self.protein_tokenizer is not None:
            protein_tokens = self._tokenize_protein(protein_seq)
        else:
            protein_tokens = self._tokenize_protein_placeholder(protein_seq)

        # Get label and confidence
        label = torch.tensor(row['numeric_label'], dtype=torch.float32)
        confidence = torch.tensor(row['confidence'], dtype=torch.float32)

        return {
            'peptide_seq': peptide_seq,
            'peptide_smiles': peptide_smiles,
            'peptide_smiles_length': peptide_smiles_length,
            'protein_seq': protein_seq,
            'peptide_tokens': peptide_tokens,
            'protein_tokens': protein_tokens,
            'label': label,
            'confidence': confidence,
            'target_id': row['Target_UniProt_ID'],
            'ligand_id': row['Ligand_UniProt_ID'],
            'action': row['Action']
        }

    def _peptide_to_smiles(self, peptide_seq: str) -> str:
        """Convert a peptide to SMILES unless conversion is disabled."""
        if not self.convert_peptide_to_smiles:
            return peptide_seq
        return peptide_seq_to_smiles(peptide_seq)

    def _tokenize_peptide(self, peptide_seq: str) -> torch.Tensor:
        """Tokenize peptide sequence using provided tokenizer.

        Pads/truncates to ``self.max_peptide_length`` and returns a 1-D
        input_ids tensor.
        """
        tokens = self.peptide_tokenizer(
            peptide_seq,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_peptide_length,
            truncation=True
        )['input_ids'].squeeze(0)
        return tokens

    def _tokenize_protein_placeholder(self, protein_seq: str) -> torch.Tensor:
        """
        Placeholder protein tokenizer (character-level).

        NOTE: Replace with ESM-2 tokenizer in production:
            from esm import pretrained
            _, alphabet = pretrained.esm2_t33_650M_UR50D()
            batch_converter = alphabet.get_batch_converter()
            _, _, tokens = batch_converter([("protein", protein_seq)])
        """
        # Amino acid to index mapping: 1..20 for residues, 0 = PAD, 21 = UNK.
        aa_to_idx = {aa: i+1 for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
        aa_to_idx['<PAD>'] = 0
        aa_to_idx['<UNK>'] = 21

        # Convert to indices
        indices = [aa_to_idx.get(aa, aa_to_idx['<UNK>']) for aa in protein_seq]

        # Pad or truncate to the fixed protein length.
        if len(indices) > self.max_protein_length:
            indices = indices[:self.max_protein_length]
        else:
            indices += [0] * (self.max_protein_length - len(indices))

        return torch.tensor(indices, dtype=torch.long)

    def _tokenize_protein(self, protein_seq: str) -> torch.Tensor:
        """Tokenize protein using ESM-2 tokenizer if available.

        Currently always delegates to the placeholder tokenizer.
        """
        if self.protein_tokenizer is None:
            return self._tokenize_protein_placeholder(protein_seq)

        # Use ESM-2 tokenizer
        # TODO: Implement when ESM-2 is integrated
        return self._tokenize_protein_placeholder(protein_seq)

    def get_target_proteins(self) -> Dict[str, str]:
        """
        Get dictionary of unique target proteins.

        Returns:
            dict: {UniProt_ID: Sequence}
        """
        unique_targets = self.data.drop_duplicates(subset=['Target_UniProt_ID'])
        return dict(zip(unique_targets['Target_UniProt_ID'], unique_targets['Target_Sequence']))

    def get_ligands_for_target(self, target_id: str) -> List[Dict]:
        """
        Get all ligands (peptides) for a specific target protein.

        Args:
            target_id: Target protein UniProt ID

        Returns:
            List of dicts with ligand info
        """
        target_data = self.data[self.data['Target_UniProt_ID'] == target_id]

        ligands = []
        for _, row in target_data.iterrows():
            ligands.append({
                'sequence': row['Ligand_Sequence'],
                'uniprot_id': row['Ligand_UniProt_ID'],
                'label': row['numeric_label'],
                'confidence': row['confidence'],
                'action': row['Action']
            })

        return ligands
258
+
259
+
260
def load_td3b_data(
    data_path: str,
    mode: str = 'oracle',
    target_protein_id: Optional[str] = None
) -> Tuple[pd.DataFrame, Dict]:
    """
    Load and summarize TD3B data.

    Args:
        data_path: Path to TD3B_data.csv
        mode: 'oracle' or 'finetune'
        target_protein_id: Filter by target protein (finetuning mode)

    Returns:
        data: Filtered DataFrame
        stats: Dictionary of statistics
    """
    frame = pd.read_csv(data_path)

    # Restrict to a single target when finetuning against one protein.
    if mode == 'finetune' and target_protein_id is not None:
        frame = frame[frame['Target_UniProt_ID'] == target_protein_id]

    # Summary statistics over the (possibly filtered) pairs.
    labels = frame['label']
    stats = {
        'total_pairs': len(frame),
        'unique_targets': frame['Target_UniProt_ID'].nunique(),
        'unique_ligands': frame['Ligand_UniProt_ID'].nunique(),
        'agonist_count': (labels == 'agonist').sum(),
        'antagonist_count': (labels == 'antagonist').sum(),
        'action_distribution': frame['Action'].value_counts().to_dict()
    }

    return frame, stats
294
+
295
+
296
def create_target_dataset_for_finetuning(
    data_path: str,
    target_protein_id: str,
    desired_direction: str = 'agonist'
) -> Dict:
    """
    Create a dataset for TD3B finetuning focused on a specific target.

    Args:
        data_path: Path to TD3B_data.csv
        target_protein_id: Target protein UniProt ID
        desired_direction: 'agonist' or 'antagonist'

    Returns:
        dict with target protein info and example ligands for the desired
        direction plus the opposite direction (for contrastive learning).

    Raises:
        ValueError: if desired_direction is not 'agonist'/'antagonist', or
            if the target has no rows in the CSV.
    """
    # Validate up front. The previous implementation routed this through an
    # identity dict ({'agonist': 'agonist', ...}) whose only effect was to
    # raise an opaque KeyError for any other value — after the file was read.
    if desired_direction not in ('agonist', 'antagonist'):
        raise ValueError(
            f"desired_direction must be 'agonist' or 'antagonist', got {desired_direction!r}"
        )

    data = pd.read_csv(data_path)

    # Get target protein info
    target_data = data[data['Target_UniProt_ID'] == target_protein_id]

    if len(target_data) == 0:
        raise ValueError(f"No data found for target {target_protein_id}")

    # Get protein sequence (should be same for all rows)
    protein_seq = target_data.iloc[0]['Target_Sequence']

    # Ligands labeled with the desired direction.
    direction_ligands = target_data[target_data['label'] == desired_direction]

    # Also get opposite direction for contrastive learning
    opposite_direction = 'antagonist' if desired_direction == 'agonist' else 'agonist'
    opposite_ligands = target_data[target_data['label'] == opposite_direction]

    return {
        'target_protein_id': target_protein_id,
        'target_protein_seq': protein_seq,
        'desired_direction': desired_direction,
        'n_desired_examples': len(direction_ligands),
        'n_opposite_examples': len(opposite_ligands),
        'desired_ligands': direction_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records'),
        'opposite_ligands': opposite_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records')
    }
340
+
341
+
342
if __name__ == "__main__":
    # Example usage / smoke test: expects TD3B_data.csv one directory up.
    data_path = "../TD3B_data.csv"

    print("=" * 80)
    print("TD3B Data Loading Example")
    print("=" * 80)

    # Load and summarize data
    data, stats = load_td3b_data(data_path, mode='oracle')

    print("\nDataset Statistics:")
    for key, value in stats.items():
        print(f" {key}: {value}")

    # Create dataset for oracle training
    print("\n" + "=" * 80)
    print("Oracle Training Dataset")
    print("=" * 80)

    dataset = TD3BDataset(data_path, mode='oracle')
    print(f"Dataset size: {len(dataset)}")

    # Sample first item to sanity-check tokenization and labels.
    sample = dataset[0]
    print(f"\nSample item:")
    print(f" Target: {sample['target_id']}")
    print(f" Ligand: {sample['ligand_id']}")
    print(f" Label: {sample['label'].item()}")
    print(f" Confidence: {sample['confidence'].item()}")
    print(f" Action: {sample['action']}")

    # Create finetuning dataset for a specific target
    print("\n" + "=" * 80)
    print("Finetuning Dataset Example")
    print("=" * 80)

    # Get first target from the oracle dataset as a demo target.
    targets = dataset.get_target_proteins()
    first_target_id = list(targets.keys())[0]

    finetune_info = create_target_dataset_for_finetuning(
        data_path,
        first_target_id,
        desired_direction='agonist'
    )

    print(f"\nTarget: {finetune_info['target_protein_id']}")
    print(f"Desired direction: {finetune_info['desired_direction']}")
    print(f"Number of agonist examples: {finetune_info['n_desired_examples']}")
    print(f"Number of antagonist examples: {finetune_info['n_opposite_examples']}")
td3b/direction_oracle.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GPCR Agonist Classifier - TR2-D2 Inference Script
4
+ """
5
+
6
+ import argparse
7
+ import logging
8
+ import os
9
+ import sys
10
+ from types import SimpleNamespace
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import EsmModel, EsmTokenizer
17
+
18
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
+ if PROJECT_ROOT not in sys.path:
20
+ sys.path.insert(0, PROJECT_ROOT)
21
+
22
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
23
+ from roformer import Roformer
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
def resolve_device(requested: Optional[str]) -> torch.device:
    """Resolve a user-requested device string to a usable torch.device.

    Args:
        requested: a device spec ("cpu", "cuda", "cuda:1", ...), "auto",
            or None. None/"auto" pick cuda:0 when CUDA is available,
            otherwise CPU.

    Returns:
        A torch.device valid on this machine. Invalid or unavailable
        requests fall back to CPU (or cuda:0 for an out-of-range CUDA
        index), with a warning logged.
    """
    # Auto-detection: prefer the first CUDA device when one is visible.
    if requested is None or str(requested).lower() == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda:0")
        return torch.device("cpu")

    try:
        device = torch.device(requested)
    except Exception as exc:
        logger.warning("Invalid device '%s': %s. Falling back to CPU.", requested, exc)
        return torch.device("cpu")

    # Non-CUDA devices (cpu, mps, ...) need no further validation here.
    if device.type != "cuda":
        return device

    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        logger.warning("CUDA requested but not available; falling back to CPU")
        return torch.device("cpu")

    # A bare "cuda" request has index None; default to device 0. From here
    # index is always an int, so the original `index is None` re-check was
    # dead code and has been removed.
    index = device.index if device.index is not None else 0
    count = torch.cuda.device_count()
    if not 0 <= index < count:
        logger.warning(
            "CUDA device %s requested but only %d visible; using cuda:0",
            index,
            count
        )
        return torch.device("cuda:0")

    return torch.device(f"cuda:{index}")
57
+
58
+ # -------------------------
59
+ # Peptide to SMILES
60
+ # -------------------------
61
def peptide_to_smiles(seq: str) -> str:
    """Convert a peptide amino-acid sequence to SMILES via RDKit.

    The sequence is stripped and upper-cased before conversion.

    Raises:
        ValueError: if RDKit cannot parse the sequence.
    """
    from rdkit import Chem

    cleaned = seq.strip().upper()
    mol = Chem.MolFromSequence(cleaned)
    if mol is None:
        raise ValueError(f"RDKit failed to convert peptide '{cleaned}' to SMILES")
    return Chem.MolToSmiles(mol)
68
+
69
+ # -------------------------
70
+ # Self-Attention Block
71
+ # -------------------------
72
class SelfAttentionBlock(nn.Module):
    """Residual self-attention layer with post-norm (attn -> dropout -> add -> LN).

    Submodule attribute names are kept stable so saved state dicts remain
    loadable.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply self-attention over x (batch, seq, d_model); True in
        key_padding_mask marks positions to ignore."""
        attended, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        return self.norm(x + self.dropout(attended))
83
+
84
+ # -------------------------
85
+ # Cross-Attention Module
86
+ # -------------------------
87
class BiMultiHeadCrossAttention(nn.Module):
    """Bidirectional cross-attention between protein and ligand streams.

    Each stream attends over the other, with a residual connection and
    LayerNorm per stream. Attribute names are kept stable for checkpoint
    compatibility.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.prot_to_lig = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
        self.lig_to_prot = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
        self.prot_ln = nn.LayerNorm(d_model)
        self.lig_ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, prot_h, lig_h, prot_kpm=None, lig_kpm=None):
        """Return updated (protein, ligand) hidden states; *_kpm masks use
        True for padded positions."""
        # Protein queries attend over ligand keys/values.
        prot_attended, _ = self.prot_to_lig(prot_h, lig_h, lig_h, key_padding_mask=lig_kpm)
        prot_updated = self.prot_ln(prot_h + self.dropout(prot_attended))

        # Ligand queries attend over protein keys/values.
        lig_attended, _ = self.lig_to_prot(lig_h, prot_h, prot_h, key_padding_mask=prot_kpm)
        lig_updated = self.lig_ln(lig_h + self.dropout(lig_attended))

        return prot_updated, lig_updated
104
+
105
+ # -------------------------
106
+ # TR2-D2 Encoder Wrapper
107
+ # -------------------------
108
class TR2D2RoFormerEncoder(nn.Module):
    """Frozen TR2-D2 RoFormer wrapper used as the ligand encoder.

    Optionally loads a (filtered) TR2-D2 checkpoint into the underlying
    RoFormer, then freezes every parameter; forward() returns per-token
    hidden states.
    """

    def __init__(self, config, tokenizer, checkpoint_path=None, device="cpu"):
        super().__init__()
        self.device = device
        self.encoder = Roformer(config, tokenizer, device=device)

        if checkpoint_path:
            print(f" Loading TR2-D2 checkpoint...")
            # weights_only=False: the checkpoint may contain pickled
            # non-tensor objects — only load trusted checkpoints.
            ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
            # Lightning-style checkpoints nest weights under "state_dict".
            state_dict = ckpt.get("state_dict", ckpt)
            # Keep encoder-related keys only, stripping wrapper prefixes so
            # they line up with self.encoder.model's parameter names.
            roformer_state = {
                k.replace("model.", "").replace("backbone.", ""): v
                for k, v in state_dict.items()
                if "roformer" in k or "encoder" in k or "backbone" in k
            }
            # strict=False: partial overlap with the checkpoint is expected.
            self.encoder.model.load_state_dict(roformer_state, strict=False)
            print(" TR2-D2 checkpoint loaded")

        # Freeze: the encoder is used for feature extraction only.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

    def forward(self, input_ids, attention_mask, inputs_embeds=None):
        """Return last-layer hidden states for ligand tokens.

        Exactly one of input_ids / inputs_embeds is used. The embeddings
        path runs WITHOUT torch.no_grad() so gradients can flow back into
        the provided embeddings; the token-id path is wrapped in no_grad.
        """
        if attention_mask is None:
            raise ValueError("attention_mask is required for ligand encoding.")
        attention_mask = attention_mask.to(self.device)
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.to(self.device)
            out = self.encoder.model.roformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask
            )
        else:
            input_ids = input_ids.to(self.device)
            with torch.no_grad():
                out = self.encoder.model.roformer(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
        return out.last_hidden_state
148
+
149
+ # -------------------------
150
+ # Full GPCR Model
151
+ # -------------------------
152
class ESM_TR2D2_GPCRClassifier(nn.Module):
    """
    GPCR Agonist Classifier with TR2-D2

    Architecture:
        1. ESM2 (protein) + TR2-D2 RoFormer (ligand) — both frozen
        2. Projections to common dimension
        3. Self-Attention (1 layer each)
        4. BiDirectional Cross-Attention (2 stacked layers)
        5. Masked Average Pooling
        6. MLP Classifier (2 output logits)
    """
    def __init__(
        self,
        esm_name,
        tr2d2_config,
        lig_tokenizer,
        tr2d2_checkpoint=None,
        d_model=256,
        n_heads=4,
        n_self_attn_layers=1,
        n_bmca_layers=2,
        dropout=0.3,
        device="cuda",
        esm_cache_dir=None,
        esm_local_files_only=False
    ):
        super().__init__()
        self.device = device

        # Frozen encoders: only the projections, attention layers and the
        # classifier head are trainable.
        print("Loading ESM2 protein encoder...")
        self.esm = EsmModel.from_pretrained(
            esm_name,
            cache_dir=esm_cache_dir,
            local_files_only=esm_local_files_only
        )
        for p in self.esm.parameters():
            p.requires_grad = False
        self.esm.eval()

        print("Loading TR2-D2 ligand encoder...")
        self.ligand_encoder = TR2D2RoFormerEncoder(
            tr2d2_config, lig_tokenizer, tr2d2_checkpoint, device
        )

        esm_dim = self.esm.config.hidden_size
        lig_dim = tr2d2_config.roformer.hidden_size

        # Project both encoder outputs into a shared d_model space.
        self.prot_proj = nn.Linear(esm_dim, d_model)
        self.lig_proj = nn.Linear(lig_dim, d_model)

        # Per-stream self-attention stacks.
        self.prot_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])
        self.lig_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])

        # Stacked bidirectional cross-attention between the two streams.
        self.bmca_layers = nn.ModuleList([
            BiMultiHeadCrossAttention(d_model, n_heads, dropout)
            for _ in range(n_bmca_layers)
        ])

        # MLP head over concatenated pooled representations; 2 logits
        # (presumably agonist vs. not — confirm against training labels).
        self.classifier = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )

    def forward(self, prot_tokens, lig_tokens, lig_inputs_embeds=None):
        """Classify a (protein, ligand) pair; returns (batch, 2) logits.

        prot_tokens / lig_tokens are dicts with "input_ids" and
        "attention_mask"; lig_inputs_embeds optionally replaces ligand
        token ids with differentiable embeddings.
        """
        # key_padding_mask convention: True marks PAD positions to ignore.
        prot_kpm = prot_tokens["attention_mask"].eq(0)
        lig_kpm = lig_tokens["attention_mask"].eq(0)

        # Protein features are always computed without gradients (frozen).
        with torch.no_grad():
            prot_out = self.esm(**prot_tokens).last_hidden_state

        lig_out = self.ligand_encoder(
            lig_tokens["input_ids"],
            lig_tokens["attention_mask"],
            inputs_embeds=lig_inputs_embeds
        )

        prot_h = self.prot_proj(prot_out)
        lig_h = self.lig_proj(lig_out)

        # Self-attention within each stream.
        for self_attn in self.prot_self_attn_layers:
            prot_h = self_attn(prot_h, key_padding_mask=prot_kpm)
        for self_attn in self.lig_self_attn_layers:
            lig_h = self_attn(lig_h, key_padding_mask=lig_kpm)

        # Cross-attention (stacked n_bmca_layers times).
        for bmca in self.bmca_layers:
            prot_h, lig_h = bmca(prot_h, lig_h, prot_kpm=prot_kpm, lig_kpm=lig_kpm)

        # Masked average pooling: average only over non-pad positions;
        # clamp avoids division by zero for fully padded rows.
        prot_mask = prot_tokens["attention_mask"].unsqueeze(-1)
        lig_mask = lig_tokens["attention_mask"].unsqueeze(-1)

        prot_repr = (prot_h * prot_mask).sum(dim=1) / prot_mask.sum(dim=1).clamp(min=1)
        lig_repr = (lig_h * lig_mask).sum(dim=1) / lig_mask.sum(dim=1).clamp(min=1)

        return self.classifier(torch.cat([prot_repr, lig_repr], dim=-1))
262
+
263
+ # -------------------------
264
+ # Tokenization
265
+ # -------------------------
266
def create_tr2d2_config(vocab_size):
    """Build the minimal TR2-D2 RoFormer config namespace for *vocab_size*.

    The fixed sizes must match the pretrained TR2-D2 checkpoint.
    """
    roformer_cfg = SimpleNamespace(
        vocab_size=vocab_size,
        hidden_size=768,
        n_layers=8,
        n_heads=8,
        max_position_embeddings=1035
    )
    return SimpleNamespace(roformer=roformer_cfg)
276
+
277
+
278
def _load_state_dict_flexible(model: nn.Module, state_dict: Dict, strict: bool = True) -> None:
    """Load *state_dict* into *model*, tolerating key mismatches.

    First attempts a normal (by default strict) load; if that raises, retries
    with only the keys that exist in the model and strict=False, logging what
    was filtered out and what remained missing/unexpected.
    """
    try:
        model.load_state_dict(state_dict, strict=strict)
        return
    except RuntimeError as exc:
        # Strict load failed (missing/unexpected keys or shape mismatch):
        # fall back to loading only the intersection of keys, non-strictly.
        model_keys = set(model.state_dict().keys())
        filtered = {k: v for k, v in state_dict.items() if k in model_keys}
        logger.warning("Strict load failed: %s", exc)
        logger.warning(
            "Retrying with filtered keys (%d/%d) and strict=False",
            len(filtered),
            len(state_dict)
        )
        incompatible = model.load_state_dict(filtered, strict=False)
        if incompatible.missing_keys:
            logger.warning("Missing keys (first 10): %s", incompatible.missing_keys[:10])
        if incompatible.unexpected_keys:
            logger.warning("Unexpected keys (first 10): %s", incompatible.unexpected_keys[:10])
296
+
297
def tokenize_protein(seq, tokenizer, device):
    """Tokenize a protein sequence (truncated to 1024 tokens) and move every
    returned tensor onto *device*."""
    encoded = tokenizer(
        seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=True
    )
    return {key: tensor.to(device) for key, tensor in encoded.items()}
307
+
308
def tokenize_ligand(smiles, tokenizer, max_len, device):
    """Tokenize a ligand SMILES string and right-pad to exactly *max_len*.

    Args:
        smiles: SMILES string of the ligand.
        tokenizer: tokenizer exposing __call__ and pad_token_id.
        max_len: fixed output length (sequences are truncated or padded to it).
        device: target device for the returned tensors.

    Returns:
        dict with "input_ids" and "attention_mask", each of shape (1, max_len).
    """
    enc = tokenizer(
        smiles,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=True
    )
    ids = enc["input_ids"].squeeze(0)
    att = enc["attention_mask"].squeeze(0)

    pad = max_len - ids.numel()
    if pad > 0:
        # Preserve the tokenizer's integer dtypes. The previous code padded
        # the attention mask with torch.zeros(pad) (float32), and torch.cat's
        # type promotion silently converted the whole mask to float.
        ids = torch.cat([ids, torch.full((pad,), tokenizer.pad_token_id, dtype=ids.dtype)])
        att = torch.cat([att, torch.zeros(pad, dtype=att.dtype)])

    return {
        "input_ids": ids.unsqueeze(0).to(device),
        "attention_mask": att.unsqueeze(0).to(device)
    }
328
+
329
+ # -------------------------
330
+ # Training-Compatible Oracle Wrapper
331
+ # -------------------------
332
+ class DirectionalOracle(nn.Module):
333
+ """
334
+ Batch-capable oracle wrapper with TD3B-compatible predict_with_confidence().
335
+
336
+ This class is intended for training integration where peptide/protein tokens
337
+ are provided directly (batched) and the oracle runs in inference-only mode.
338
+ """
339
def __init__(
    self,
    model_ckpt: str,
    tr2d2_checkpoint: str,
    tokenizer_vocab: str,
    tokenizer_splits: str,
    esm_name: str = "facebook/esm2_t33_650M_UR50D",
    d_model: int = 256,
    n_heads: int = 4,
    n_self_attn_layers: int = 1,
    n_bmca_layers: int = 2,
    dropout: float = 0.3,
    max_ligand_length: int = 768,
    max_protein_length: int = 1024,
    device: Optional[str] = None,
    esm_cache_dir: Optional[str] = None,
    esm_local_files_only: bool = False
):
    """Build the frozen, inference-only oracle.

    Args:
        model_ckpt: path to the trained classifier weights.
        tr2d2_checkpoint: path to the TR2-D2 ligand-encoder checkpoint.
        tokenizer_vocab / tokenizer_splits: SMILES SPE tokenizer files.
        esm_name: HuggingFace id of the ESM-2 protein encoder.
        d_model..dropout: classifier hyperparameters (must match training).
        max_ligand_length / max_protein_length: hard truncation limits.
        device: device spec or None/"auto" for auto-detection.
    """
    super().__init__()

    if isinstance(device, torch.device):
        device = str(device)
    self.device = resolve_device(device)

    self.max_ligand_length = max_ligand_length
    self.max_protein_length = max_protein_length
    # Truncation warnings are emitted once per input kind, then suppressed.
    self._warned_ligand_truncation = False
    self._warned_protein_truncation = False

    self.lig_tokenizer = SMILES_SPE_Tokenizer(tokenizer_vocab, tokenizer_splits)
    self.prot_tokenizer = EsmTokenizer.from_pretrained(
        esm_name,
        cache_dir=esm_cache_dir,
        local_files_only=esm_local_files_only
    )

    tr2d2_cfg = create_tr2d2_config(self.lig_tokenizer.vocab_size)
    self.model = ESM_TR2D2_GPCRClassifier(
        esm_name=esm_name,
        tr2d2_config=tr2d2_cfg,
        lig_tokenizer=self.lig_tokenizer,
        tr2d2_checkpoint=tr2d2_checkpoint,
        d_model=d_model,
        n_heads=n_heads,
        n_self_attn_layers=n_self_attn_layers,
        n_bmca_layers=n_bmca_layers,
        dropout=dropout,
        device=self.device,
        esm_cache_dir=esm_cache_dir,
        esm_local_files_only=esm_local_files_only
    )

    # weights_only=False: checkpoint may hold pickled objects — load only
    # trusted files. Unwrap trainer-style {"model_state_dict": ...} dicts.
    state_dict = torch.load(model_ckpt, map_location=self.device, weights_only=False)
    if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    _load_state_dict_flexible(self.model, state_dict, strict=True)
    self.model.to(self.device).eval()

    # Inference-only: freeze everything.
    for param in self.model.parameters():
        param.requires_grad = False

    # Cache pad ids; fall back to 0 when a tokenizer defines none.
    self._lig_pad_token_id = self.lig_tokenizer.pad_token_id
    if self._lig_pad_token_id is None:
        self._lig_pad_token_id = 0
    self._prot_pad_token_id = self.prot_tokenizer.pad_token_id
    if self._prot_pad_token_id is None:
        self._prot_pad_token_id = 0
406
+
407
def encode_protein(self, protein_seq: str) -> torch.Tensor:
    """Tokenize a protein sequence and return its input_ids on self.device.

    Only input_ids are returned; the attention mask is reconstructed later
    from pad tokens in _normalize_token_dict.
    """
    tokens = self.prot_tokenizer(
        protein_seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=self.max_protein_length,
        add_special_tokens=True
    )
    return tokens["input_ids"].to(self.device)
417
+
418
def _normalize_token_dict(
    self,
    tokens: torch.Tensor,
    pad_token_id: int,
    max_length: int,
    warned_attr: str
) -> Dict[str, torch.Tensor]:
    """Normalize token input into a batched {input_ids, attention_mask} dict.

    Accepts either a dict (with optional attention_mask) or a raw 1-D/2-D
    id tensor; a missing mask is derived as (ids != pad_token_id). Inputs
    longer than *max_length* are truncated, warning once per *warned_attr*
    flag. All tensors are moved to self.device.
    """
    if isinstance(tokens, dict):
        input_ids = tokens.get("input_ids")
        if input_ids is None:
            raise ValueError("Token dict must include input_ids.")
        attention_mask = tokens.get("attention_mask")
        input_ids = input_ids.to(self.device)
        if attention_mask is None:
            # Derive the mask from pad positions when none is supplied.
            attention_mask = (input_ids != pad_token_id).long()
        else:
            attention_mask = attention_mask.to(self.device)
    else:
        input_ids = tokens
        if input_ids.dim() == 1:
            # Promote a single sequence to a batch of one.
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.to(self.device)
        attention_mask = (input_ids != pad_token_id).long()

    if max_length is not None and input_ids.size(1) > max_length:
        # Warn only on the first truncation for this input kind.
        if not getattr(self, warned_attr):
            logger.warning(
                "Truncating input from length %d to max_length=%d",
                input_ids.size(1),
                max_length
            )
            setattr(self, warned_attr, True)
        input_ids = input_ids[:, :max_length]
        attention_mask = attention_mask[:, :max_length]

    return {"input_ids": input_ids, "attention_mask": attention_mask}
454
+
455
+ def _normalize_prob_inputs(
456
+ self,
457
+ probs: torch.Tensor,
458
+ attention_mask: Optional[torch.Tensor],
459
+ max_length: int,
460
+ warned_attr: str,
461
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
462
+ if probs.dim() == 2:
463
+ probs = probs.unsqueeze(0)
464
+ probs = probs.to(self.device)
465
+ if attention_mask is None:
466
+ attention_mask = torch.ones(
467
+ probs.size(0), probs.size(1), device=self.device, dtype=torch.long
468
+ )
469
+ else:
470
+ if attention_mask.dim() == 1:
471
+ attention_mask = attention_mask.unsqueeze(0)
472
+ attention_mask = attention_mask.to(self.device).long()
473
+
474
+ if max_length is not None and probs.size(1) > max_length:
475
+ if not getattr(self, warned_attr):
476
+ logger.warning(
477
+ "Truncating input from length %d to max_length=%d",
478
+ probs.size(1),
479
+ max_length
480
+ )
481
+ setattr(self, warned_attr, True)
482
+ probs = probs[:, :max_length]
483
+ attention_mask = attention_mask[:, :max_length]
484
+
485
+ return probs, attention_mask
486
+
487
@torch.no_grad()
def predict_with_confidence(
    self,
    peptide_tokens: torch.Tensor,
    protein_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the oracle on tokenized inputs and return (p_agonist, confidence).

    A single protein is broadcast across a peptide batch; otherwise batch
    sizes must match. ``p_agonist`` is the softmax probability of class 1,
    and ``confidence`` is the larger of the two class probabilities.
    """
    peptide = self._normalize_token_dict(
        peptide_tokens,
        self._lig_pad_token_id,
        self.max_ligand_length,
        "_warned_ligand_truncation"
    )
    protein = self._normalize_token_dict(
        protein_tokens,
        self._prot_pad_token_id,
        self.max_protein_length,
        "_warned_protein_truncation"
    )

    n_pep = peptide["input_ids"].size(0)
    n_prot = protein["input_ids"].size(0)
    if n_prot == 1 and n_pep > 1:
        # One protein, many peptides: broadcast the protein across the batch.
        protein = {key: value.expand(n_pep, -1) for key, value in protein.items()}
    elif n_prot != n_pep:
        raise ValueError(
            f"Batch size mismatch: peptide_tokens={n_pep}, protein_tokens={n_prot}"
        )

    class_probs = F.softmax(self.model(protein, peptide), dim=-1)
    p_agonist = class_probs[:, 1]
    confidence = torch.max(class_probs, dim=-1).values
    return p_agonist, confidence
520
+
521
def predict_from_probs(
    self,
    ligand_probs: torch.Tensor,
    protein_tokens: torch.Tensor,
    ligand_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Differentiable agonist prediction from ligand token *probabilities*.

    Instead of discrete ids, the ligand is given as a distribution over the
    vocabulary; the expected embedding ``probs @ E`` is fed to the oracle via
    ``lig_inputs_embeds`` so gradients can flow back into ``ligand_probs``.
    Note: not wrapped in ``torch.no_grad`` — presumably intentional so this
    can be used inside a training objective (TODO confirm with callers).

    Args:
        ligand_probs: (B, L, V) or unbatched (L, V) probability tensor over
            the ligand vocabulary.
        protein_tokens: Protein ids, tensor or HF-style dict.
        ligand_attention_mask: Optional mask; defaults to all-ones.

    Returns:
        Tensor of agonist probabilities (softmax class 1), shape (B,).
    """
    lig_probs, lig_attention = self._normalize_prob_inputs(
        ligand_probs,
        ligand_attention_mask,
        self.max_ligand_length,
        "_warned_ligand_truncation",
    )
    prot_tokens = self._normalize_token_dict(
        protein_tokens,
        self._prot_pad_token_id,
        self.max_protein_length,
        "_warned_protein_truncation"
    )

    lig_batch = lig_probs.size(0)
    prot_batch = prot_tokens["input_ids"].size(0)
    if prot_batch == 1 and lig_batch > 1:
        # Broadcast a single protein across the ligand batch.
        prot_tokens = {k: v.expand(lig_batch, -1) for k, v in prot_tokens.items()}
    elif prot_batch != lig_batch:
        raise ValueError(
            f"Batch size mismatch: ligand_probs={lig_batch}, protein_tokens={prot_batch}"
        )

    # Word-embedding matrix (vocab, dim) of the oracle's ligand RoFormer
    # encoder; the soft input is the probability-weighted average embedding.
    emb_weight = self.model.ligand_encoder.encoder.model.roformer.embeddings.word_embeddings.weight
    if lig_probs.size(-1) != emb_weight.size(0):
        raise ValueError(
            f"Ligand vocab mismatch: probs={lig_probs.size(-1)} vs oracle={emb_weight.size(0)}"
        )
    lig_inputs_embeds = lig_probs @ emb_weight
    # Placeholder ids: presumably ignored by the model when
    # ``lig_inputs_embeds`` is provided — TODO confirm against the model API.
    lig_input_ids = torch.zeros(
        lig_probs.size(0), lig_probs.size(1), device=lig_probs.device, dtype=torch.long
    )
    lig_tokens = {"input_ids": lig_input_ids, "attention_mask": lig_attention}
    logits = self.model(prot_tokens, lig_tokens, lig_inputs_embeds=lig_inputs_embeds)
    probs = F.softmax(logits, dim=-1)
    return probs[:, 1]
562
+
563
# -------------------------
# Prediction
# -------------------------
@torch.no_grad()
def predict(model, prot_tok, lig_tok, protein_seq, peptide_seq, device, threshold=0.5,
            max_ligand_length=768):
    """Predict agonist activity for a single protein/peptide pair.

    Args:
        model: Trained classifier returning (1, 2) logits for
            (non-agonist, agonist).
        prot_tok: Protein (ESM) tokenizer.
        lig_tok: Ligand SMILES tokenizer.
        protein_seq: GPCR protein sequence.
        peptide_seq: Ligand peptide sequence (converted to SMILES internally).
        device: Torch device for the tokenized inputs.
        threshold: Decision threshold on the agonist probability.
        max_ligand_length: Maximum ligand token length; must match training.
            Default 768 preserves the previously hard-coded value, but it is
            now a parameter so other model configurations can reuse this
            helper.

    Returns:
        dict with keys: smiles, non_agonist_prob, agonist_prob, prediction,
        confidence
    """
    # Convert peptide to SMILES
    smiles = peptide_to_smiles(peptide_seq)

    # Tokenize (ligand length was hard-coded to 768 with a "FIXED: 768 not
    # 256!" note; parameterized here with the same default)
    prot_tokens = tokenize_protein(protein_seq, prot_tok, device)
    lig_tokens = tokenize_ligand(smiles, lig_tok, max_ligand_length, device)

    # Forward pass: class 0 = non-agonist, class 1 = agonist
    logits = model(prot_tokens, lig_tokens)
    probs = F.softmax(logits, dim=-1).squeeze(0)

    p_non_agonist = probs[0].item()
    p_agonist = probs[1].item()
    prediction = "agonist" if p_agonist >= threshold else "non-agonist"

    return {
        "smiles": smiles,
        "non_agonist_prob": p_non_agonist,
        "agonist_prob": p_agonist,
        "prediction": prediction,
        "confidence": max(p_non_agonist, p_agonist)
    }
596
+
597
# -------------------------
# MAIN
# -------------------------
def main():
    """CLI entry point: load tokenizers and model, then classify one
    protein/peptide pair as agonist vs non-agonist and print the result."""
    parser = argparse.ArgumentParser(
        description="GPCR Agonist Classifier - TR2-D2 Inference"
    )
    parser.add_argument("--model_ckpt", required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--tr2d2_checkpoint", required=True,
                        help="Path to TR2-D2 pretrained checkpoint")
    parser.add_argument("--tokenizer_vocab", required=True,
                        help="Path to tokenizer vocabulary")
    parser.add_argument("--tokenizer_splits", required=True,
                        help="Path to tokenizer splits")
    parser.add_argument("--protein_seq", required=True,
                        help="GPCR protein sequence")
    parser.add_argument("--ligand_peptide", required=True,
                        help="Ligand peptide sequence")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Classification threshold (default: 0.5)")
    # Architecture hyperparameters: these must mirror the values used when the
    # checkpoint was trained, otherwise state-dict loading will fail.
    parser.add_argument("--d_model", type=int, default=256,
                        help="Hidden dimension (must match training)")
    parser.add_argument("--n_heads", type=int, default=4,
                        help="Number of attention heads (must match training)")
    parser.add_argument("--n_self_attn_layers", type=int, default=1,
                        help="Number of self-attention layers (must match training)")
    parser.add_argument("--n_bmca_layers", type=int, default=2,
                        help="Number of cross-attention layers (must match training)")
    parser.add_argument("--dropout", type=float, default=0.3,
                        help="Dropout rate (must match training)")
    parser.add_argument("--device", default=None,
                        help="Device (cuda/cpu, default: auto)")
    parser.add_argument("--esm_name", default="facebook/esm2_t33_650M_UR50D",
                        help="ESM model name or local path")
    parser.add_argument("--esm_cache_dir", default=None,
                        help="Optional cache directory for ESM model")
    parser.add_argument("--esm_local_files_only", action="store_true",
                        help="Load ESM from local cache only (no network)")

    args = parser.parse_args()

    # Device
    device = resolve_device(args.device)

    print(f"Device: {device}")
    print("")

    # Load tokenizers
    print("Loading tokenizers...")
    prot_tok = EsmTokenizer.from_pretrained(
        args.esm_name,
        cache_dir=args.esm_cache_dir,
        local_files_only=args.esm_local_files_only
    )
    lig_tok = SMILES_SPE_Tokenizer(args.tokenizer_vocab, args.tokenizer_splits)
    print(f" Vocab size: {lig_tok.vocab_size}")
    print("")

    # Create config
    tr2d2_cfg = create_tr2d2_config(lig_tok.vocab_size)

    # Load model
    print("Loading model...")
    model = ESM_TR2D2_GPCRClassifier(
        esm_name=args.esm_name,
        tr2d2_config=tr2d2_cfg,
        lig_tokenizer=lig_tok,
        tr2d2_checkpoint=args.tr2d2_checkpoint,
        d_model=args.d_model,
        n_heads=args.n_heads,
        n_self_attn_layers=args.n_self_attn_layers,
        n_bmca_layers=args.n_bmca_layers,
        dropout=args.dropout,
        device=device,
        esm_cache_dir=args.esm_cache_dir,
        esm_local_files_only=args.esm_local_files_only
    )

    # Load trained weights
    print(" Loading trained weights...")
    state_dict = torch.load(args.model_ckpt, map_location=device)
    _load_state_dict_flexible(model, state_dict, strict=True)
    model.to(device).eval()
    print(" Model ready.")
    print("")

    # Predict
    print("Running inference...")
    result = predict(
        model, prot_tok, lig_tok,
        args.protein_seq, args.ligand_peptide,
        device, args.threshold
    )

    # Display results
    print("")
    print("=" * 70)
    print("RESULTS")
    print("=" * 70)
    print(f"Protein: {args.protein_seq[:50]}{'...' if len(args.protein_seq) > 50 else ''}")
    print(f"Ligand: {args.ligand_peptide}")
    print(f"SMILES: {result['smiles']}")
    print("")
    print(f"Non-agonist probability: {result['non_agonist_prob']:.4f}")
    print(f"Agonist probability: {result['agonist_prob']:.4f}")
    print("")
    print(f"Prediction (threshold={args.threshold}): {result['prediction'].upper()}")
    print(f"Confidence: {result['confidence']:.4f}")
    print("=" * 70)

if __name__ == "__main__":
    main()
td3b/td3b_finetune.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B Finetuning Loop
3
+ Extends TR2-D2 training with contrastive loss and directional rewards.
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ import wandb
9
+ import os
10
+ from finetune_utils import loss_wdce
11
+ from .td3b_losses import TD3BTotalLoss, extract_embeddings_from_mdlm
12
+ from tqdm import tqdm
13
+ import pandas as pd
14
+ from plotting import plot_data_with_distribution_seaborn, plot_data
15
+
16
+
17
def td3b_finetune(
    args,
    cfg,
    policy_model,
    reward_model,
    mcts=None,
    pretrained=None,
    filename=None,
    prot_name=None,
    eps=1e-5,
    # TD3B-specific arguments
    contrastive_weight=0.1,
    contrastive_margin=1.0,
    contrastive_type='margin',
    embedding_pool_method='mean',
    kl_beta=0.1
):
    """
    TD3B finetuning with combined WDCE + contrastive loss + KL regularization.

    Args:
        args: Configuration arguments
        cfg: Hydra config (currently unused inside this function)
        policy_model: Policy model (MDLM)
        reward_model: Reward scoring functions (TD3BRewardFunction)
        mcts: TD3B_MCTS instance
        pretrained: Pretrained model (for no-MCTS mode)
        filename: Output filename (currently unused inside this function)
        prot_name: Target protein name
        eps: Small epsilon
        contrastive_weight: λ for contrastive loss
        contrastive_margin: Margin for margin-based contrastive loss
        contrastive_type: 'margin' or 'infonce'
        embedding_pool_method: 'mean', 'max', or 'cls'
        kl_beta: β coefficient for KL divergence regularization
    Returns:
        batch_losses: List of training losses
    """
    base_path = args.base_path
    # NOTE(review): dt is computed here but never used below — confirm whether
    # it can be removed.
    dt = (1 - eps) / args.total_num_steps

    if args.no_mcts:
        assert pretrained is not None, "pretrained model is required for no mcts"
    else:
        assert mcts is not None, "mcts is required for mcts"

    # Create reference model (frozen copy of policy model at start of training)
    # Cannot use copy.deepcopy() due to unpicklable objects (file handles, etc.)
    # Instead, create a new model instance and load CLONED state dict
    print("[TD3B] Creating reference model for KL regularization...")

    # Import Diffusion class
    from diffusion import Diffusion

    # Create new instance with same config
    reference_model = Diffusion(
        config=policy_model.config,
        tokenizer=policy_model.tokenizer,
        mode="eval",
        device=policy_model.device if hasattr(policy_model, 'device') else args.device
    )

    # Get the device from policy model
    device = policy_model.device if hasattr(policy_model, 'device') else args.device
    if device is None:
        device = next(policy_model.parameters()).device

    # IMPORTANT: Clone the state dict to create independent tensors
    # This ensures no memory sharing between policy and reference model
    state_dict_copy = {
        key: value.clone().detach()
        for key, value in policy_model.state_dict().items()
    }
    reference_model.load_state_dict(state_dict_copy)

    # Move reference model to same device as policy model
    reference_model = reference_model.to(device)

    # Freeze and set to eval mode
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False

    print(f"[TD3B] Reference model frozen with {sum(p.numel() for p in reference_model.parameters())} parameters")
    print(f"[TD3B] Reference model on device: {device}")

    # Verify no parameter sharing
    policy_params = {id(p) for p in policy_model.parameters()}
    ref_params = {id(p) for p in reference_model.parameters()}
    assert len(policy_params.intersection(ref_params)) == 0, \
        "ERROR: Reference model shares parameters with policy model!"
    print("[TD3B] ✓ Verified: No parameter sharing between policy and reference model")

    # Initialize TD3B total loss
    td3b_loss_fn = TD3BTotalLoss(
        contrastive_weight=contrastive_weight,
        contrastive_margin=contrastive_margin,
        contrastive_type=contrastive_type,
        kl_beta=kl_beta,
        reference_model=reference_model
    )

    # Set model to train mode
    policy_model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)

    # Record metrics
    batch_losses = []
    batch_wdce_losses = []
    batch_contrastive_losses = []
    batch_kl_losses = []

    # Initialize saved trajectories (cache reused between resampling steps)
    x_saved, log_rnd_saved, final_rewards_saved = None, None, None
    directional_labels_saved, confidences_saved = None, None

    # Logs
    valid_fraction_log = []
    affinity_log = []
    gated_reward_log = []
    confidence_log = []
    direction_prediction_log = []  # Oracle predictions f_φ ∈ [0, 1]
    consistency_reward_log = []  # d* × (f_φ - 0.5)

    ### Fine-Tuning Loop ###
    pbar = tqdm(range(args.num_epochs))

    for epoch in pbar:
        # NOTE(review): rewards/losses are reset each epoch but never appended
        # to — confirm whether they are leftovers from an earlier version.
        rewards = []
        losses = []

        policy_model.train()

        with torch.no_grad():
            if x_saved is None or epoch % args.resample_every_n_step == 0:
                # Generate trajectories
                if args.no_mcts:
                    # Direct sampling (not typical for TD3B, but keep for compatibility)
                    x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(
                        args, reward_model, pretrained
                    )
                    directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32)
                    confidences = torch.ones(x_final.size(0), dtype=torch.float32)
                else:
                    # TD3B MCTS forward pass
                    # For dual-direction mode, sample BOTH directions in the same batch
                    if hasattr(args, 'target_direction') and args.target_direction == 'both':
                        print(f"[Dual-direction] Epoch {epoch}: Sampling BOTH agonist and antagonist binders")

                        # Sample agonist binders (d* = +1)
                        reward_model.target_direction = 1.0
                        if epoch % args.reset_every_n_step == 0:
                            results_agonist = mcts.forward(resetTree=True)
                        else:
                            results_agonist = mcts.forward(resetTree=False)

                        # Sample antagonist binders (d* = -1)
                        reward_model.target_direction = -1.0
                        # Don't reset tree for antagonist to save computation
                        results_antagonist = mcts.forward(resetTree=False)

                        # Unpack both results
                        if len(results_agonist) == 7 and len(results_antagonist) == 7:
                            x_agonist, log_rnd_agonist, rewards_agonist, _, _, labels_agonist, conf_agonist = results_agonist
                            x_antagonist, log_rnd_antagonist, rewards_antagonist, _, _, labels_antagonist, conf_antagonist = results_antagonist

                            # Force labels to be correct (in case oracle is wrong)
                            labels_agonist = torch.ones(x_agonist.size(0), dtype=torch.float32) * 1.0  # +1 for agonist
                            labels_antagonist = torch.ones(x_antagonist.size(0), dtype=torch.float32) * -1.0  # -1 for antagonist

                            # Combine both directions into single batch
                            x_final = torch.cat([x_agonist, x_antagonist], dim=0)
                            log_rnd = torch.cat([log_rnd_agonist, log_rnd_antagonist], dim=0)
                            final_rewards = np.concatenate([rewards_agonist, rewards_antagonist], axis=0)
                            directional_labels = torch.cat([labels_agonist, labels_antagonist], dim=0)
                            confidences = torch.cat([
                                conf_agonist if isinstance(conf_agonist, torch.Tensor) else torch.tensor(conf_agonist),
                                conf_antagonist if isinstance(conf_antagonist, torch.Tensor) else torch.tensor(conf_antagonist)
                            ], dim=0)

                            print(f" → Combined batch: {x_agonist.size(0)} agonists + {x_antagonist.size(0)} antagonists = {x_final.size(0)} total")
                            print(f" → Directional labels: {torch.unique(directional_labels).tolist()} (DIVERSITY CONFIRMED!)")
                        else:
                            raise ValueError("Dual-direction mode requires 7-value return from MCTS")
                    else:
                        # Single-direction mode
                        if epoch % args.reset_every_n_step == 0:
                            results = mcts.forward(resetTree=True)
                        else:
                            results = mcts.forward(resetTree=False)

                        # Unpack results (TD3B version includes directional labels and confidences)
                        if len(results) == 7:
                            x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = results
                            # Convert numpy arrays to tensors immediately for consistency
                            if not isinstance(directional_labels, torch.Tensor):
                                directional_labels = torch.tensor(directional_labels, dtype=torch.float32)
                            if not isinstance(confidences, torch.Tensor):
                                confidences = torch.tensor(confidences, dtype=torch.float32)
                        else:
                            # Fallback for compatibility with base MCTS
                            x_final, log_rnd, final_rewards, score_vectors, sequences = results
                            directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32)
                            confidences = torch.ones(x_final.size(0), dtype=torch.float32)

                # Save for next iteration
                x_saved = x_final
                log_rnd_saved = log_rnd
                final_rewards_saved = final_rewards
                directional_labels_saved = directional_labels
                confidences_saved = confidences
            else:
                # Reuse cached trajectories
                x_final = x_saved
                log_rnd = log_rnd_saved
                final_rewards = final_rewards_saved
                directional_labels = directional_labels_saved
                confidences = confidences_saved

        # Compute WDCE loss
        wdce_loss = loss_wdce(
            policy_model,
            log_rnd,
            x_final,
            num_replicates=args.wdce_num_replicates,
            centering=args.centering
        )

        # Compute KL divergence loss
        # Use a random masking and forward pass for KL computation
        mask_index = policy_model.mask_index
        device = x_final.device

        # Sample random noise level
        lamda = torch.rand(x_final.shape[0], device=device)  # (B,)
        sigma_kl = -torch.log1p(-(1 - eps) * lamda)

        # Apply random masking
        masked_index = torch.rand(*x_final.shape, device=device) < lamda[..., None]  # (B, L)
        perturbed_batch = torch.where(masked_index, mask_index, x_final)
        attn_mask_kl = torch.ones_like(perturbed_batch).to(device)

        # Compute KL loss
        kl_loss = td3b_loss_fn.compute_kl_loss(
            policy_model,
            perturbed_batch,
            attn_mask_kl,
            sigma_kl
        )

        # Extract embeddings for contrastive loss
        # Only compute if we have directional labels
        if directional_labels is not None and len(torch.unique(directional_labels)) > 1:
            # Get device from backbone
            device = policy_model.backbone.device if hasattr(policy_model.backbone, 'device') else x_final.device

            embeddings = extract_embeddings_from_mdlm(
                policy_model,
                x_final.to(device),
                pool_method=embedding_pool_method
            )

            # Move directional labels to same device
            directional_labels = directional_labels.to(embeddings.device)

            # Enable debug mode for first 3 epochs or if loss was zero last epoch
            debug_mode = (epoch < 3) or (epoch > 0 and batch_contrastive_losses and batch_contrastive_losses[-1] < 1e-6)

            # Compute total TD3B loss
            total_loss, loss_dict = td3b_loss_fn.compute_loss(
                wdce_loss,
                embeddings,
                directional_labels,
                kl_loss=kl_loss,  # Pass KL loss
                debug=debug_mode  # Enable debugging when needed
            )
        else:
            # If no directional diversity, skip contrastive loss
            print(f"[WARNING] Epoch {epoch}: No directional diversity! Skipping contrastive loss.")
            print(f" Labels: {directional_labels.cpu().tolist() if directional_labels is not None else 'None'}")
            total_loss = wdce_loss + td3b_loss_fn.kl_beta * kl_loss
            loss_dict = {
                'total_loss': total_loss.item(),
                'wdce_loss': wdce_loss.item(),
                'contrastive_loss': 0.0,
                'kl_loss': kl_loss.item()
            }

        # Gradient descent
        total_loss.backward()

        # Gradient clipping
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)

        optim.step()
        optim.zero_grad()

        pbar.set_postfix(
            total_loss=loss_dict['total_loss'],
            wdce=loss_dict['wdce_loss'],
            ctr=loss_dict['contrastive_loss']
        )

        # Evaluation sampling
        x_eval, eval_metrics = policy_model.sample_finetuned_td3b(
            args,
            reward_model,
            batch_size=50,
            dataframe=False
        )

        # Extract metrics (TD3B-specific)
        affinity = eval_metrics.get('affinity', [0])
        gated_reward = eval_metrics.get('gated_reward', [0])
        confidence = eval_metrics.get('confidence', [1])
        valid_fraction = eval_metrics.get('valid_fraction', 0)

        # Extract direction predictions (f_φ ∈ [0, 1])
        direction_predictions = eval_metrics.get('direction_predictions', [0.5])

        # Compute consistency reward: d* × (f_φ - 0.5)
        # Get target direction d* from reward_model
        d_star = reward_model.target_direction  # +1 or -1
        consistency_rewards = [d_star * (f_phi - 0.5) for f_phi in direction_predictions]

        # Append to logs
        affinity_log.append(affinity)
        gated_reward_log.append(gated_reward)
        confidence_log.append(confidence)
        valid_fraction_log.append(valid_fraction)
        direction_prediction_log.append(direction_predictions)
        consistency_reward_log.append(consistency_rewards)

        batch_losses.append(loss_dict['total_loss'])
        batch_wdce_losses.append(loss_dict['wdce_loss'])
        batch_contrastive_losses.append(loss_dict['contrastive_loss'])
        batch_kl_losses.append(loss_dict.get('kl_loss', 0.0))

        # Compute search statistics
        # no_mcts path yields tensors (tensor reductions); MCTS path yields
        # numpy arrays (np reductions).
        if args.no_mcts:
            mean_reward_search = final_rewards.mean().item()
            min_reward_search = final_rewards.min().item()
            max_reward_search = final_rewards.max().item()
            median_reward_search = final_rewards.median().item()
        else:
            mean_reward_search = np.mean(final_rewards)
            min_reward_search = np.min(final_rewards)
            max_reward_search = np.max(final_rewards)
            median_reward_search = np.median(final_rewards)

        # Compute direction oracle and consistency reward statistics
        mean_direction = np.mean(direction_predictions) if len(direction_predictions) > 0 else 0.5
        std_direction = np.std(direction_predictions) if len(direction_predictions) > 0 else 0.0
        mean_consistency = np.mean(consistency_rewards) if len(consistency_rewards) > 0 else 0.0
        std_consistency = np.std(consistency_rewards) if len(consistency_rewards) > 0 else 0.0

        print(
            f"epoch {epoch} | "
            f"affinity {np.mean(affinity):.4f} | "
            f"gated_reward {np.mean(gated_reward):.4f} | "
            f"confidence {np.mean(confidence):.4f} | "
            f"valid_frac {valid_fraction:.4f} | "
            f"direction_oracle {mean_direction:.4f}±{std_direction:.4f} | "
            f"consistency_reward {mean_consistency:.4f}±{std_consistency:.4f} | "
            f"total_loss {loss_dict['total_loss']:.4f} | "
            f"wdce_loss {loss_dict['wdce_loss']:.4f} | "
            f"contrastive_loss {loss_dict['contrastive_loss']:.4f} | "
            f"kl_loss {loss_dict.get('kl_loss', 0.0):.4f}"
        )

        # W&B logging
        wandb.log({
            "epoch": epoch,
            "affinity": np.mean(affinity),
            "gated_reward": np.mean(gated_reward),
            "confidence": np.mean(confidence),
            "valid_fraction": valid_fraction,
            "direction_oracle/mean": mean_direction,
            "direction_oracle/std": std_direction,
            "consistency_reward/mean": mean_consistency,
            "consistency_reward/std": std_consistency,
            "total_loss": loss_dict['total_loss'],
            "wdce_loss": loss_dict['wdce_loss'],
            "contrastive_loss": loss_dict['contrastive_loss'],
            "kl_loss": loss_dict.get('kl_loss', 0.0),
            "mean_reward_search": mean_reward_search,
            "min_reward_search": min_reward_search,
            "max_reward_search": max_reward_search,
            "median_reward_search": median_reward_search
        })

        # Save checkpoint
        if (epoch + 1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
            torch.save(policy_model.state_dict(), model_path)
            print(f"model saved at epoch {epoch}")

    ### End of Fine-Tuning Loop ###

    wandb.finish()

    # Save logs and plots
    plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
    os.makedirs(plot_path, exist_ok=True)
    output_log_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/log_(unknown).csv'
    save_td3b_logs_to_file(
        valid_fraction_log,
        affinity_log,
        gated_reward_log,
        confidence_log,
        direction_prediction_log,
        consistency_reward_log,
        output_log_path
    )

    plot_data(valid_fraction_log,
              save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/valid_(unknown).png')

    plot_data_with_distribution_seaborn(
        log1=affinity_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/affinity_(unknown).png',
        label1=f"Average Affinity to {prot_name}",
        title=f"Average Affinity to {prot_name} Over Iterations"
    )

    plot_data_with_distribution_seaborn(
        log1=gated_reward_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/gated_reward_(unknown).png',
        label1="Average Gated Reward",
        title="Average Gated Reward Over Iterations"
    )

    plot_data_with_distribution_seaborn(
        log1=confidence_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/confidence_(unknown).png',
        label1="Average Confidence",
        title="Average Confidence Over Iterations"
    )

    # Final evaluation
    x_eval, eval_metrics, df = policy_model.sample_finetuned_td3b(
        args,
        reward_model,
        batch_size=200,
        dataframe=True
    )
    df.to_csv(f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/{prot_name}_generation_results.csv', index=False)

    return batch_losses
468
+
469
+
470
def save_td3b_logs_to_file(valid_fraction_log, affinity_log, gated_reward_log, confidence_log,
                           direction_prediction_log, consistency_reward_log, output_path):
    """Persist TD3B per-iteration training logs as a CSV file.

    Parameters:
        valid_fraction_log (list): Log of valid fractions over iterations.
        affinity_log (list): Log of binding affinity over iterations.
        gated_reward_log (list): Log of gated rewards over iterations.
        confidence_log (list): Log of confidence scores over iterations.
        direction_prediction_log (list): Log of direction oracle predictions over iterations.
        consistency_reward_log (list): Log of consistency rewards over iterations.
        output_path (str): Path to save the log CSV file.
    """
    # Ensure the destination directory exists before writing.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    n_iterations = len(valid_fraction_log)
    frame = pd.DataFrame({
        "Iteration": list(range(1, n_iterations + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Gated Reward": gated_reward_log,
        "Confidence": confidence_log,
        "Direction Oracle": direction_prediction_log,
        "Consistency Reward": consistency_reward_log
    })

    frame.to_csv(output_path, index=False)
    print(f"Logs saved to {output_path}")
502
+
503
+
504
# Add sampling method to diffusion model (monkey patch or extend)
def add_td3b_sampling_to_model(model):
    """
    Adds TD3B-specific sampling method to the model.
    This is a helper function to extend the existing model.

    Binds ``sample_finetuned_td3b`` onto ``model`` via descriptor protocol
    (``__get__``) and returns the same model instance.
    """
    def sample_finetuned_td3b(self, args, reward_model, batch_size=50, dataframe=False):
        """
        TD3B-specific sampling that returns directional metrics.

        Runs the reverse diffusion for ``args.total_num_steps`` steps, filters
        decoded sequences through PeptideAnalyzer, scores the valid ones with
        ``reward_model``, and resamples ``batch_size`` trajectories with
        softmax(reward / alpha) weights.

        Returns:
            (x_resampled, eval_metrics) or, when ``dataframe`` is True,
            (x_resampled, eval_metrics, df).
        """
        self.backbone.eval()
        self.noise.eval()

        if batch_size is None:
            batch_size = args.batch_size

        eps = getattr(args, "sampling_eps", 1e-5)
        num_steps = args.total_num_steps
        x_rollout = self.sample_prior(
            batch_size,
            args.seq_length).to(self.device, dtype=torch.long)

        # Evenly spaced times from 1 down to eps; one reverse step per interval.
        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = torch.tensor((1 - eps) / num_steps, device=self.device)

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
            log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
            x_rollout = x_next.to(self.device)

        # Final cleanup: if any mask tokens survive the rollout, run one
        # noise-removal step. NOTE(review): placement relative to the loop was
        # inferred from the use of the final `t`; confirm against the base
        # TR2-D2 sampler.
        mask_positions = (x_rollout == self.mask_index)
        if mask_positions.any().item():
            log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
            x_rollout = x_next.to(self.device)

        # Convert x to sequences to get valid ones
        from utils.app import PeptideAnalyzer
        analyzer = PeptideAnalyzer()
        sequences = self.tokenizer.batch_decode(x_rollout)
        valid_mask = torch.tensor([analyzer.is_peptide(seq) for seq in sequences], device=self.device)
        valid_sequences = [seq for seq, keep in zip(sequences, valid_mask.tolist()) if keep]
        valid_x_final = x_rollout[valid_mask] if valid_mask.any().item() else torch.empty(0, device=self.device)
        valid_fraction = len(valid_sequences) / batch_size

        if len(valid_sequences) > 0:
            result = reward_model(valid_sequences)
            # Reward model may return (rewards, info dict) or a bare array.
            if isinstance(result, tuple):
                total_rewards, info = result
                affinity = np.asarray(info.get('affinities', total_rewards))
                confidence = np.asarray(info.get('confidences', np.ones_like(affinity)))
                direction_predictions = np.asarray(info.get('directions', np.zeros_like(affinity)))
            else:
                total_rewards = np.asarray(result)
                if total_rewards.ndim > 1:
                    affinity = total_rewards[:, 0]
                else:
                    affinity = total_rewards
                confidence = np.ones_like(affinity)
                direction_predictions = np.zeros_like(affinity)

            # Reward-weighted resampling (temperature alpha, clamped > 0).
            rewards_t = torch.as_tensor(total_rewards, dtype=torch.float32, device=self.device)
            alpha = max(float(getattr(args, "alpha", 0.1)), 1e-6)
            weights = torch.softmax(rewards_t / alpha, dim=0)
            idx = torch.multinomial(weights, num_samples=batch_size, replacement=True)

            idx_np = idx.detach().cpu().numpy()
            x_resampled = valid_x_final[idx]
            sequences = [valid_sequences[i] for i in idx_np]
            total_rewards = total_rewards[idx_np]
            affinity = affinity[idx_np]
            confidence = confidence[idx_np]
            direction_predictions = direction_predictions[idx_np]
        else:
            # No valid peptides: return the raw rollout and empty metrics.
            x_resampled = x_rollout
            total_rewards = np.array([])
            affinity = np.array([])
            confidence = np.array([])
            direction_predictions = np.array([])

        eval_metrics = {
            'affinity': affinity,
            'gated_reward': total_rewards,
            'confidence': confidence,
            'direction_predictions': direction_predictions,
            'valid_fraction': valid_fraction
        }

        if dataframe:
            df = pd.DataFrame({
                'sequence': sequences if len(total_rewards) else [],
                'affinity': affinity,
                'gated_reward': total_rewards,
                'confidence': confidence
            })
            return x_resampled, eval_metrics, df
        else:
            return x_resampled, eval_metrics

    # Attach method to model
    model.sample_finetuned_td3b = sample_finetuned_td3b.__get__(model, type(model))
    return model
td3b/td3b_losses.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B Loss Functions
3
+ Implements contrastive loss for separating agonist/antagonist embeddings.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Optional, Tuple
10
+
11
+
12
class ContrastiveLoss(nn.Module):
    """
    Margin-based contrastive loss separating agonist and antagonist embeddings.

    Pairs that share a class label are pulled together (squared-distance term),
    while pairs with opposite labels are pushed at least ``margin`` apart
    (squared hinge term):

        L_ctr = (1 - y_ij) * 0.5 * d^2 + y_ij * 0.5 * max(0, margin - d)^2

    where d = ||emb_i - emb_j||_2 (or cosine distance) and y_ij = 1 for
    dissimilar pairs, 0 for similar ones.
    """

    def __init__(self, margin: float = 1.0, distance_metric: str = 'euclidean', adaptive_margin: bool = False):
        """
        Args:
            margin: Minimum separation enforced between dissimilar pairs (base margin).
            distance_metric: 'euclidean' or 'cosine'.
            adaptive_margin: When True, grow the margin from the observed
                dissimilar-pair distances instead of keeping it fixed.
        """
        super().__init__()
        self.base_margin = margin
        self.distance_metric = distance_metric
        self.adaptive_margin = adaptive_margin

    def compute_distance(self, emb1: torch.Tensor, emb2: torch.Tensor) -> torch.Tensor:
        """
        Element-wise distance between two aligned batches of embeddings.

        Args:
            emb1: (batch_size, embedding_dim)
            emb2: (batch_size, embedding_dim)
        Returns:
            distances: (batch_size,)
        """
        if self.distance_metric == 'euclidean':
            # L2 distance per row.
            return torch.norm(emb1 - emb2, p=2, dim=-1)
        if self.distance_metric == 'cosine':
            # Cosine distance = 1 - cosine similarity.
            return 1.0 - F.cosine_similarity(emb1, emb2, dim=-1)
        raise ValueError(f"Unknown distance metric: {self.distance_metric}")

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> torch.Tensor:
        """
        Contrastive loss over all ordered off-diagonal pairs in the batch.

        Args:
            embeddings: (batch_size, embedding_dim) sequence embeddings.
            labels: (batch_size,) directional labels in {-1, +1}
                (+1 = agonist, -1 = antagonist).
            confidences: optional (batch_size,) oracle confidence scores; any
                pair whose confidence product is <= 0 is excluded.
            debug: print detailed diagnostics when True.
        Returns:
            Scalar contrastive loss (0 for degenerate batches).
        """
        n = embeddings.size(0)
        if n < 2:
            if debug:
                print(f"[ContrastiveLoss DEBUG] Batch size {n} < 2, returning 0 loss")
            return torch.tensor(0.0, device=embeddings.device)

        if confidences is not None:
            if torch.is_tensor(confidences):
                confidences = confidences.to(embeddings.device)
            else:
                confidences = torch.as_tensor(confidences, device=embeddings.device)
            confidences = confidences.view(-1)
            if confidences.numel() != n:
                raise ValueError(
                    f"Confidences size {confidences.numel()} does not match batch size {n}"
                )

        # Full all-pairs distance matrix (B, B).
        if self.distance_metric == 'euclidean':
            dist_mat = torch.cdist(embeddings, embeddings, p=2)
        elif self.distance_metric == 'cosine':
            unit = F.normalize(embeddings, p=2, dim=-1)
            dist_mat = 1.0 - unit @ unit.T
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

        # Pair labels: product > 0 means same class, < 0 means different class.
        labels = labels.view(-1)
        sign_prod = labels.unsqueeze(1) * labels.unsqueeze(0)
        dissimilar = sign_prod < 0

        # Exclude self-pairs on the diagonal.
        diag = torch.eye(n, device=embeddings.device, dtype=torch.bool)
        pos_mask = (~dissimilar) & ~diag
        neg_mask = dissimilar & ~diag

        # Drop pairs whose confidence product is non-positive.
        conf_mask = None
        if confidences is not None:
            conf_mask = (confidences.unsqueeze(0) * confidences.unsqueeze(1)) > 0
            pos_mask = pos_mask & conf_mask
            neg_mask = neg_mask & conf_mask

        # Optionally stretch the margin to 150% of the mean dissimilar distance
        # so the hinge term always has room to optimize.
        if self.adaptive_margin and neg_mask.any():
            margin = max(self.base_margin, 1.5 * dist_mat[neg_mask].mean().item())
        else:
            margin = self.base_margin

        pos_count = pos_mask.sum()
        neg_count = neg_mask.sum()
        if (pos_count + neg_count).item() == 0:
            if debug:
                print("[ContrastiveLoss DEBUG] No valid pairs after filtering, returning 0 loss")
            return torch.tensor(0.0, device=embeddings.device)

        # Similar pairs: mean squared distance. Dissimilar pairs: mean squared
        # hinge on (margin - d). The 1e-8 guards the empty-set division.
        pos_loss = dist_mat[pos_mask].pow(2).sum() / (pos_count + 1e-8)
        neg_loss = torch.clamp(margin - dist_mat[neg_mask], min=0.0).pow(2).sum() / (neg_count + 1e-8)
        loss = pos_loss + neg_loss

        if debug:
            print(f"\n[ContrastiveLoss DEBUG]")
            print(f"  Batch size: {n}")
            print(f"  Labels: {labels.cpu().tolist()}")
            print(f"  Unique labels: {torch.unique(labels).cpu().tolist()}")
            print(f"  Embedding shape: {embeddings.shape}")
            print(f"  Embedding norm (mean): {embeddings.norm(dim=-1).mean().item():.4f}")
            print(f"  Embedding norm (std): {embeddings.norm(dim=-1).std().item():.4f}")
            valid_mask = pos_mask | neg_mask
            if valid_mask.any():
                valid_dists = dist_mat[valid_mask]
                print(f"  Distance stats (valid pairs): mean={valid_dists.mean().item():.4f} "
                      f"min={valid_dists.min().item():.4f} max={valid_dists.max().item():.4f}")
            if self.adaptive_margin and neg_mask.any():
                print(f"  Margin: {margin:.4f} (adaptive, base={self.base_margin})")
            else:
                print(f"  Margin: {margin:.4f} (fixed)")
            print(f"  Num similar pairs: {pos_count.item():.0f}")
            print(f"  Num dissimilar pairs: {neg_count.item():.0f}")
            if conf_mask is not None:
                print(f"  Confidence-passing pairs: {conf_mask.sum().item():.0f}")
            print(f"  Similar loss (mean): {pos_loss.item():.4f}")
            print(f"  Dissimilar loss (mean): {neg_loss.item():.4f}")
            print(f"  Total loss: {loss.item():.4f}")

            # Show which dissimilar pairs have margin violations
            margin_violations = (dist_mat < margin) & neg_mask
            if margin_violations.sum() > 0:
                print(f"  Margin violations: {margin_violations.sum().item():.0f} dissimilar pairs have distance < margin")
            else:
                print(f"  Margin violations: 0 (all dissimilar pairs are already separated)")

        return loss
192
+
193
+
194
class InfoNCELoss(nn.Module):
    """
    InfoNCE-style contrastive loss (as in SimCLR / CLIP) over direction classes.

    Each agonist anchor is pulled toward other agonists and pushed away from
    antagonists, and vice versa. Anchors without at least one positive and one
    negative partner are skipped.
    """

    def __init__(self, temperature: float = 0.1):
        """
        Args:
            temperature: Temperature parameter for softmax; smaller values
                sharpen the contrast between positives and negatives.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss.

        Args:
            embeddings: (batch_size, embedding_dim)
            labels: (batch_size,) in {-1, +1}
            confidences: (batch_size,) oracle confidence scores; pairs with
                product <= 0 are masked out
            debug: Unused (kept for API compatibility)
        Returns:
            loss: scalar
        """
        batch_size = embeddings.size(0)
        if confidences is not None:
            if not torch.is_tensor(confidences):
                confidences = torch.as_tensor(confidences, device=embeddings.device)
            else:
                confidences = confidences.to(embeddings.device)
            confidences = confidences.view(-1)
            if confidences.numel() != batch_size:
                raise ValueError(
                    f"Confidences size {confidences.numel()} does not match batch size {batch_size}"
                )
        if batch_size < 2:
            return torch.tensor(0.0, device=embeddings.device)

        # Temperature-scaled cosine similarities.
        embeddings = F.normalize(embeddings, p=2, dim=-1)  # (B, D)
        similarity = torch.matmul(embeddings, embeddings.T) / self.temperature  # (B, B)

        # Positive = same class, negative = different class.
        labels_expanded = labels.unsqueeze(1)  # (B, 1)
        label_product = labels_expanded * labels_expanded.T  # (B, B)
        positive_mask = (label_product > 0)
        negative_mask = (label_product < 0)

        # Remove self-similarity from the positives.
        positive_mask.fill_diagonal_(0)

        if confidences is not None:
            conf_product = confidences.unsqueeze(0) * confidences.unsqueeze(1)
            conf_mask = conf_product > 0
            positive_mask = positive_mask & conf_mask
            negative_mask = negative_mask & conf_mask

        # Per-anchor loss: -log( sum(exp(pos)) / sum(exp(pos) + exp(neg)) ).
        losses = []
        for i in range(batch_size):
            pos_sims = similarity[i][positive_mask[i]]  # (num_pos,)
            neg_sims = similarity[i][negative_mask[i]]  # (num_neg,)

            # Skip anchors without at least one positive and one negative.
            if pos_sims.numel() == 0 or neg_sims.numel() == 0:
                continue

            # BUG FIX: the previous code exponentiated the raw similarities
            # despite its "LogSumExp for numerical stability" comment; for
            # small temperatures exp(sim / tau) overflows to inf and the loss
            # becomes NaN. torch.logsumexp computes the same quantity stably:
            # -log(sum(exp(pos)) / sum(exp(all))) = logsumexp(all) - logsumexp(pos)
            all_sims = torch.cat([pos_sims, neg_sims])
            loss_i = torch.logsumexp(all_sims, dim=0) - torch.logsumexp(pos_sims, dim=0)
            losses.append(loss_i)

        if len(losses) == 0:
            return torch.tensor(0.0, device=embeddings.device)

        return torch.stack(losses).mean()
293
+
294
+
295
class TD3BTotalLoss:
    """
    Aggregate TD3B objective: L_total = L_WDCE + lambda * L_ctr + beta * L_KL.

    Components:
    - L_WDCE: weighted denoising cross-entropy (computed upstream, TR2-D2)
    - L_ctr: contrastive separation of agonist/antagonist embeddings
    - L_KL: KL divergence between the policy and a frozen reference model
    """

    def __init__(
        self,
        contrastive_weight: float = 0.1,
        contrastive_margin: float = 1.0,
        contrastive_type: str = 'margin',  # 'margin' or 'infonce'
        kl_beta: float = 0.1,  # β coefficient for KL divergence
        reference_model: Optional[nn.Module] = None,
        adaptive_margin: bool = True  # Enable adaptive margin by default
    ):
        """
        Args:
            contrastive_weight: λ multiplier on the contrastive term.
            contrastive_margin: Base margin for the margin-based loss.
            contrastive_type: 'margin' or 'infonce'.
            kl_beta: β multiplier on the KL regularizer.
            reference_model: Frozen copy of the pretrained model, or None to
                disable the KL term.
            adaptive_margin: Let the margin track observed dissimilar distances.
        """
        self.contrastive_weight = contrastive_weight
        self.kl_beta = kl_beta
        self.reference_model = reference_model

        # The reference must never be updated: switch to eval, freeze, verify.
        if self.reference_model is not None:
            self.reference_model.eval()
            for param in self.reference_model.parameters():
                param.requires_grad = False
            assert all(not p.requires_grad for p in self.reference_model.parameters()), \
                "ERROR: Reference model has parameters with requires_grad=True!"

        if contrastive_type == 'margin':
            self.contrastive_loss = ContrastiveLoss(
                margin=contrastive_margin,
                distance_metric='euclidean',
                adaptive_margin=adaptive_margin
            )
        elif contrastive_type == 'infonce':
            self.contrastive_loss = InfoNCELoss(temperature=0.1)
        else:
            raise ValueError(f"Unknown contrastive type: {contrastive_type}")

    def compute_kl_categorical(
        self,
        log_p: torch.Tensor,
        log_ref_p: torch.Tensor
    ) -> torch.Tensor:
        """
        Per-position KL divergence between categorical distributions.

        KL(P || Q) = sum_x P(x) * (log P(x) - log Q(x))

        Args:
            log_p: (B, L, Vocab) log-probabilities from the policy model.
            log_ref_p: (B, L, Vocab) log-probabilities from the reference model.
        Returns:
            (B, L) KL divergence summed over the vocabulary axis.
        """
        contrib = torch.exp(log_p) * (log_p - log_ref_p)  # (B, L, Vocab)
        # 0 * log(0) terms surface as NaN/Inf at -inf log-probs; they should
        # contribute zero, so zero out non-finite entries.
        contrib = torch.where(
            torch.isfinite(contrib),
            contrib,
            torch.zeros_like(contrib)
        )
        return contrib.sum(dim=-1)  # (B, L)

    def compute_kl_loss(
        self,
        policy_model: nn.Module,
        sequences: torch.Tensor,
        attn_mask: torch.Tensor,
        sigma: torch.Tensor
    ) -> torch.Tensor:
        """
        Masked-mean KL divergence between policy and frozen reference outputs.

        Args:
            policy_model: Current policy model (being trained).
            sequences: (B, L) input sequences.
            attn_mask: (B, L) attention mask (1 = real token).
            sigma: (B,) noise schedule values.
        Returns:
            Scalar KL loss (0 when no reference model is configured).
        """
        if self.reference_model is None:
            return torch.tensor(0.0, device=sequences.device)

        # Invariant: the reference stays in eval mode for its whole lifetime.
        assert not self.reference_model.training, \
            "ERROR: Reference model is in training mode! It should always be in eval mode."

        # Policy logits need gradients; the reference forward does not.
        # (Policy logits are recomputed here even if WDCE already ran a
        # forward pass — kept for interface simplicity.)
        policy_logits = policy_model(sequences, attn_mask=attn_mask, sigma=sigma)  # (B, L, Vocab)
        with torch.no_grad():
            ref_logits = self.reference_model(sequences, attn_mask=attn_mask, sigma=sigma)  # (B, L, Vocab)

        kl_per_position = self.compute_kl_categorical(
            F.log_softmax(policy_logits, dim=-1),
            F.log_softmax(ref_logits, dim=-1)
        )  # (B, L)

        # Average only over non-padding positions.
        valid = attn_mask.float()
        return (kl_per_position * valid).sum() / (valid.sum() + 1e-8)

    def compute_loss(
        self,
        wdce_loss: torch.Tensor,
        embeddings: torch.Tensor,
        directional_labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        kl_loss: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> Tuple[torch.Tensor, dict]:
        """
        Combine the precomputed WDCE loss with contrastive and KL terms.

        Args:
            wdce_loss: Precomputed WDCE loss (scalar tensor).
            embeddings: (batch_size, embedding_dim) sequence embeddings.
            directional_labels: (batch_size,) labels in {-1, +1}.
            confidences: (batch_size,) oracle confidences; pairs with
                product <= 0 are masked out in the contrastive term.
            kl_loss: Precomputed KL loss (optional; treated as 0 when None).
            debug: Forwarded to the contrastive loss for diagnostics.
        Returns:
            (total_loss, loss_dict) where loss_dict holds float components.
        """
        ctr_loss = self.contrastive_loss(
            embeddings,
            directional_labels,
            confidences=confidences,
            debug=debug
        )

        if kl_loss is None:
            kl_loss = torch.tensor(0.0, device=embeddings.device)

        # L_total = L_WDCE + λ * L_ctr + β * L_KL
        total_loss = wdce_loss + self.contrastive_weight * ctr_loss + self.kl_beta * kl_loss

        loss_dict = {
            'total_loss': total_loss.item(),
            'wdce_loss': wdce_loss.item(),
            'contrastive_loss': ctr_loss.item(),
            'kl_loss': kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss
        }

        return total_loss, loss_dict
478
+
479
+
480
def extract_embeddings_from_mdlm(
    model,
    sequences: torch.Tensor,
    pool_method: str = 'mean'
) -> torch.Tensor:
    """
    Pool per-token hidden states from the MDLM (Roformer) backbone into one
    embedding per sequence.

    Args:
        model: MDLM wrapper exposing ``model.backbone.model`` (a
            RoFormerForMaskedLM-style module accepting ``output_hidden_states``).
        sequences: (batch_size, seq_len) token ids; token id 0 is treated as
            padding when building the attention mask.
        pool_method: 'mean' (padding-aware), 'max', or 'cls' (first token).
    Returns:
        embeddings: (batch_size, hidden_dim)
    """
    # Padding positions (token id 0) are masked from attention and mean pooling.
    attn_mask = (sequences != 0).long()  # (B, L)

    # No torch.no_grad() here on purpose: the pooled embeddings feed the
    # contrastive loss, so gradients must flow through this forward pass.
    outputs = model.backbone.model(
        input_ids=sequences,
        attention_mask=attn_mask,
        output_hidden_states=True,
        return_dict=True
    )

    # hidden_states = (embedding output, layer 1, ..., layer N); use the last.
    token_states = outputs.hidden_states[-1]  # (B, L, D)

    if pool_method == 'mean':
        # Padding-aware mean over the sequence dimension.
        weights = attn_mask.float().unsqueeze(-1)  # (B, L, 1)
        return (token_states * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-8)
    if pool_method == 'max':
        # NOTE(review): max pooling also sees padding positions — confirm that
        # hidden states at padding tokens cannot dominate the maximum.
        return token_states.max(dim=1)[0]
    if pool_method == 'cls':
        # First-token (CLS-style) embedding.
        return token_states[:, 0, :]
    raise ValueError(f"Unknown pool method: {pool_method}")
td3b/td3b_mcts.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B-specific MCTS modifications.
3
+ Extends the base MCTS to support directional rewards and confidence weighting.
4
+ """
5
+
6
+ import numpy as np
7
+ import torch
8
+ from peptide_mcts import MCTS as BaseMCTS
9
+ from .td3b_scoring import TD3BRewardFunction, TD3BConfidenceWeighting
10
+
11
+
12
class TD3B_MCTS(BaseMCTS):
    """
    TD3B version of MCTS that:
    1. Uses gated directional rewards instead of multi-objective scalarization
    2. Stores directional labels and confidence scores in the buffer
    3. Applies confidence-weighted importance sampling
    """

    def __init__(
        self,
        args,
        diffusion_model,
        td3b_reward_function: TD3BRewardFunction,
        confidence_weighting: TD3BConfidenceWeighting,
        mask_index: int,
        buffer_size: int = 100,
        noise=None,
        tokenizer=None
    ):
        """
        Args:
            args: Configuration arguments
            diffusion_model: MDLM model for sampling
            td3b_reward_function: TD3BRewardFunction instance
            confidence_weighting: TD3BConfidenceWeighting instance
            mask_index: Token ID for masked positions
            buffer_size: Maximum buffer size
            noise: Noise schedule
            tokenizer: Peptide tokenizer
        """
        # Initialize base MCTS (will set self.rewardFunc later)
        # Note: base MCTS expects 'policy_model' not 'diffusion_model'
        # Create a minimal config object for base MCTS
        # NOTE(review): the base class presumably only reads config.noise from
        # this stub — confirm BaseMCTS does not access other config fields.
        class MinimalConfig:
            def __init__(self):
                self.noise = type('obj', (object,), {
                    'type': 'loglinear',
                    'sigma_min': 1e-4,
                    'sigma_max': 20
                })()
        config = MinimalConfig()

        super().__init__(
            args=args,
            config=config,
            policy_model=diffusion_model,
            pretrained=diffusion_model, # Use same model
            score_func_names=['affinity', 'gated_reward', 'placeholder1', 'placeholder2', 'placeholder3'] # 5 objectives
        )

        # Set TD3B-specific attributes
        self.td3b_reward_func = td3b_reward_function
        self.confidence_weighting = confidence_weighting
        self.mask_index = mask_index
        self.buffer_size = buffer_size
        self.noise = noise
        self.tokenizer = tokenizer if tokenizer is not None else diffusion_model.tokenizer

        # Override num_obj to ensure it's 5 (matching our padded rewards)
        self.num_obj = 5

        # Override rewardFunc for compatibility
        self.rewardFunc = self._td3b_reward_wrapper

    def _td3b_reward_wrapper(self, input_seqs):
        """
        Wrapper to make TD3BRewardFunction compatible with existing MCTS interface.
        Returns (N, 5) array to match base MCTS expectations.
        The 5 columns are: [affinity, gated_reward, 0, 0, 0] (padding last 3)
        """
        # NOTE(review): numpy is already imported at module scope; this local
        # import is redundant (but harmless).
        import numpy as np
        total_rewards, info = self.td3b_reward_func(input_seqs)
        # info contains: 'affinities', 'confidences', 'score_vectors'

        # Store confidences for later use (attach to self for access in updateBuffer)
        self._last_confidences = info['confidences']

        # Pad score_vectors from (N, 2) to (N, 5) to match base MCTS
        # Original columns: [affinity, gated_reward]
        # Padded to: [affinity, gated_reward, 0, 0, 0]
        score_vectors = info['score_vectors'] # (N, 2)
        padded = np.zeros((score_vectors.shape[0], 5))
        padded[:, :2] = score_vectors # Copy affinity and gated_reward

        return padded

    def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
        """
        TD3B version: stores directional labels and confidence scores.

        Args:
            x_final: (B, L) final sequence tokens
            log_rnd: (B,) log importance weights (trajectory-level)
            score_vectors: (B, K) score arrays
            childSequences: List of B SMILES strings
        Returns:
            traj_log_rnds: (B,) updated log importance weights
            scalar_rewards: (B,) scalar rewards
        """
        B = x_final.shape[0]
        traj_log_rnds, scalar_rewards = [], []

        # Get confidences from last reward computation
        # (set by _td3b_reward_wrapper; defaults to all-ones when absent)
        confidences = getattr(self, '_last_confidences', np.ones(B))

        for i in range(B):
            sv = np.asarray(score_vectors[i], dtype=float) # [affinity, gated_reward]
            confidence = confidences[i]

            # For TD3B, the "scalar reward" is the gated reward (second element)
            scalar_reward = float(sv[1]) # gated_reward = g_ψ · (d* · sigmoid(f_φ-0.5)/α)

            # Compute confidence-weighted importance weight
            # w(y) = κ(y) · exp(S_total / α)
            # In log space: log w(y) = log κ(y) + S_total / α
            # min_confidence floors the log to keep it finite at confidence 0.
            log_confidence = np.log(np.maximum(confidence, self.confidence_weighting.min_confidence))
            traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) + log_confidence

            # Infer directional label from oracle (sign of gated reward)
            # If gated_reward > 0, peptide is predicted as target direction
            # This is approximate; in practice you might want to query f_φ directly
            directional_label = np.sign(scalar_reward) if scalar_reward != 0 else 0.0

            item = {
                "x_final": x_final[i].clone(),
                "log_rnd": traj_log_rnd.clone() if isinstance(traj_log_rnd, torch.Tensor) else torch.tensor(traj_log_rnd),
                "final_reward": scalar_reward,
                "score_vector": sv.copy(),
                "seq": childSequences[i],
                # TD3B-specific additions
                "directional_label": directional_label,
                "confidence": confidence,
            }

            # Pareto dominance filtering (same as base class)
            from peptide_mcts import dominated_by, dominates

            # NOTE(review): a dominated candidate `continue`s before the appends
            # at the bottom of this loop, so the returned arrays can hold fewer
            # than B entries despite the documented (B,) shapes — confirm that
            # callers expect only the retained trajectories.
            if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
                self._debug_buffer_decision(sv, "rejected_dominated")
                continue

            # Remove dominated items
            keep = []
            for bi in self.buffer:
                if not dominates(sv, bi["score_vector"]):
                    keep.append(bi)
            self.buffer = keep

            # Insert with capacity constraint
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
            else:
                # Replace worst item
                worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
                self.buffer[worst_i] = item

            self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

            traj_log_rnds.append(traj_log_rnd)
            scalar_rewards.append(scalar_reward)

        # Stack per-trajectory weights; empty tensor when every item was rejected.
        traj_log_rnds = torch.stack([torch.tensor(x) if not isinstance(x, torch.Tensor) else x for x in traj_log_rnds], dim=0) if traj_log_rnds else torch.empty(0)
        scalar_rewards = np.asarray(scalar_rewards, dtype=float)
        return traj_log_rnds, scalar_rewards

    def forward(self, resetTree=False):
        """
        TD3B version of forward that returns 7 values.

        Returns:
            x_final: (N, L) sequence tokens
            log_rnd: (N,) log importance weights
            final_rewards: (N,) scalar rewards
            score_vectors: (N, K) score arrays
            sequences: List of N SMILES strings
            directional_labels: (N,) directional labels
            confidences: (N,) confidence scores
        """
        self.reset(resetTree)

        # Run select/expand iterations until the configured budget is spent;
        # select, expand, reset, timer and num_iter come from BaseMCTS.
        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            # traverse the tree form the root node until a leaf node
            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = self.consolidateBuffer()

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        return final_x, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences

    def consolidateBuffer(self):
        """
        TD3B version: includes directional labels and confidences.

        Returns:
            x_final: (N, L) sequence tokens
            log_rnd: (N,) log importance weights
            final_rewards: (N,) scalar rewards
            score_vectors: (N, K) score arrays
            sequences: List of N SMILES strings
            directional_labels: (N,) directional labels
            confidences: (N,) confidence scores
        """
        # Handle empty buffer case - return empty tensors/arrays
        if len(self.buffer) == 0:
            import logging
            logger = logging.getLogger(__name__)
            logger.warning("MCTS buffer is empty - no valid sequences found. Returning empty results.")

            # Return empty tensors/arrays with correct shapes
            # Use policy_model (set by base MCTS class) to get device
            device = self.policy_model.device if hasattr(self.policy_model, 'device') else 'cpu'
            return (
                torch.empty(0, 0, dtype=torch.long, device=device), # x_final: (0, 0)
                torch.empty(0, dtype=torch.float32, device=device), # log_rnd: (0,)
                np.empty(0, dtype=np.float32), # final_rewards: (0,)
                np.empty((0, 0), dtype=np.float32), # score_vectors: (0, 0)
                [], # sequences: empty list
                np.empty(0, dtype=np.float32), # directional_labels: (0,)
                np.empty(0, dtype=np.float32) # confidences: (0,)
            )

        x_final = []
        log_rnd = []
        final_rewards = []
        score_vectors = []
        sequences = []
        directional_labels = []
        confidences = []

        # Flatten buffer items into parallel columns; older buffer entries may
        # predate the TD3B fields, hence the .get defaults below.
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])
            score_vectors.append(item["score_vector"])
            sequences.append(item["seq"])
            directional_labels.append(item.get("directional_label", 0.0))
            confidences.append(item.get("confidence", 1.0))

        x_final = torch.stack(x_final, dim=0) # (N, L)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (N,)
        final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
        score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)
        directional_labels = np.array(directional_labels, dtype=np.float32)
        confidences = np.array(confidences, dtype=np.float32)

        return x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences
270
+
271
+
272
def create_td3b_mcts(
    args,
    diffusion_model,
    td3b_reward_function: TD3BRewardFunction,
    alpha: float = 0.1,
    **kwargs
) -> TD3B_MCTS:
    """
    Build a TD3B_MCTS instance with a default confidence-weighting module.

    Args:
        args: Configuration arguments.
        diffusion_model: MDLM model used for sampling.
        td3b_reward_function: Gated directional reward function.
        alpha: Temperature for the importance weighting.
        **kwargs: Forwarded to the TD3B_MCTS constructor.

    Returns:
        A configured TD3B_MCTS instance.
    """
    # Fixed confidence floor of 0.1 keeps log-confidence terms finite.
    weighting = TD3BConfidenceWeighting(alpha=alpha, min_confidence=0.1)

    return TD3B_MCTS(
        args=args,
        diffusion_model=diffusion_model,
        td3b_reward_function=td3b_reward_function,
        confidence_weighting=weighting,
        **kwargs
    )
td3b/td3b_scoring.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TD3B Scoring Functions
3
+ Implements gated allosteric reward combining affinity prediction and directional oracle.
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+ from typing import List, Tuple, Optional
10
+ from .direction_oracle import DirectionalOracle
11
+ from scoring.functions.binding import BindingAffinity
12
+
13
+
14
class TD3BRewardFunction:
    """
    Implements the TD3B gated total reward with sigmoid temperature scaling:
        S_total(y; d*, x) = g_psi(y, x) * sigma(d* * (f_phi(y, x) - 0.5) / tau)

    where:
        - g_psi(y, x): affinity predictor (BindingAffinity)
        - sigma: sigmoid function sigma(z) = 1 / (1 + exp(-z))
        - d* in {+1, -1}: target direction (agonist/antagonist)
        - f_phi(y, x): directional oracle (DirectionalOracle), outputs
          p(agonist) in [0, 1]
        - tau: temperature parameter (lower = sharper gating)
        - y: peptide sequence
        - x: target protein sequence

    Note: a placeholder oracle that outputs 0.5 makes (f_phi - 0.5) = 0, i.e.
    a neutral gate sigma(0) = 0.5, during initial training before a real
    oracle is available.

    Benefits of the sigmoid formulation:
    1. Gate value always in [0, 1] -> bounded gated rewards
    2. Temperature tau controls sharpness of selection
    3. Differentiable gating for smooth optimization
    4. Sharper discrimination between aligned and misaligned directions

    OLD FORMULA (replaced):
        S_total(y; d*, x) = g_psi(y, x) * (d* * f_phi(y, x))
    """

    def __init__(
        self,
        affinity_predictor: BindingAffinity,
        directional_oracle: DirectionalOracle,
        target_direction: float,  # +1 for agonist, -1 for antagonist
        target_protein_tokens: torch.Tensor,
        peptide_tokenizer,
        device: torch.device,
        min_affinity_threshold: float = 0.0,  # Minimum g_psi for allosteric control
        use_confidence_weighting: bool = True,
        temperature: float = 0.1  # Temperature for sigmoid sharpening
    ):
        """
        Args:
            affinity_predictor: Pretrained g_psi model (BindingAffinity)
            directional_oracle: Pretrained f_phi model (DirectionalOracle)
            target_direction: d* in {+1, -1} for agonist/antagonist
            target_protein_tokens: Tokenized target protein sequence
            peptide_tokenizer: Tokenizer for converting SMILES to tokens
            device: Computation device
            min_affinity_threshold: Only apply directional control if g_psi > threshold
            use_confidence_weighting: Whether to use kappa(y) for importance weights
            temperature: Temperature tau for sigmoid sharpening (lower = sharper);
                default 0.1 makes the gate sharper than a standard sigmoid
        """
        self.g_psi = affinity_predictor      # Affinity predictor
        self.f_phi = directional_oracle      # Directional oracle
        self.target_direction = target_direction  # d* in {+1, -1}
        self.protein_tokens = target_protein_tokens
        self.peptide_tokenizer = peptide_tokenizer
        self.device = device
        self.min_affinity_threshold = min_affinity_threshold
        self.use_confidence_weighting = use_confidence_weighting
        self.temperature = temperature       # tau for sigmoid temperature

    def compute_affinity(self, peptide_seqs: List[str]) -> np.ndarray:
        """
        Compute binding affinity g_psi(y, x).

        Args:
            peptide_seqs: List of peptide SMILES strings
        Returns:
            affinities: (N,) array of affinity scores
        """
        affinities = self.g_psi(peptide_seqs)  # Returns list of scores
        return np.array(affinities)

    def compute_direction(self, peptide_seqs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute directional bias f_phi(y, x) and confidence kappa(y).

        Args:
            peptide_seqs: List of peptide SMILES strings
        Returns:
            directions: (N,) tensor of directional biases, i.e. p(agonist) in [0, 1]
            confidences: (N,) tensor of confidence scores in [0, 1]
        """
        # Tokenize peptides in a single batch for speed
        peptide_tokens = None
        peptide_token_dict = None
        try:
            peptide_token_dict = self.peptide_tokenizer(
                peptide_seqs,
                return_tensors='pt',
                padding=True
            )
            peptide_token_dict = {k: v.to(self.device) for k, v in peptide_token_dict.items()}
            peptide_tokens = peptide_token_dict.get('input_ids')
        except Exception:
            # FIX: discard any partially-built dict from the failed batch path so
            # the no_grad branch below cannot prefer stale/undevice'd tensors over
            # the per-sequence fallback tokens assembled here.
            peptide_token_dict = None
            peptide_tokens_list = []
            for seq in peptide_seqs:
                tokens = self.peptide_tokenizer(seq, return_tensors='pt', padding=True)
                peptide_tokens_list.append(tokens['input_ids'].to(self.device))

            # Batch tokenization (simple stacking, assumes same length after padding)
            try:
                peptide_tokens = torch.cat(peptide_tokens_list, dim=0)  # (N, L)
            except Exception:
                # Fallback: pad to max length
                max_len = max(t.size(1) for t in peptide_tokens_list)
                peptide_tokens = torch.zeros(len(peptide_tokens_list), max_len, dtype=torch.long, device=self.device)
                for i, tokens in enumerate(peptide_tokens_list):
                    peptide_tokens[i, :tokens.size(1)] = tokens[0]

        # Expand protein tokens to batch size
        protein_tokens = self.protein_tokens.expand(len(peptide_seqs), -1)  # (N, L_prot)

        # Compute direction and confidence
        with torch.no_grad():
            if peptide_token_dict is not None and hasattr(self.f_phi, "_normalize_token_dict"):
                directions, confidences = self.f_phi.predict_with_confidence(
                    peptide_token_dict, protein_tokens
                )
            else:
                directions, confidences = self.f_phi.predict_with_confidence(
                    peptide_tokens, protein_tokens
                )

        return directions, confidences

    def compute_gated_reward(
        self,
        peptide_seqs: List[str]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute gated total reward with sigmoid temperature scaling.

        FORMULA:
            S_total = g_psi * sigma(d* * (f_phi - 0.5) / tau)

        Where:
            - g_psi: affinity score
            - sigma: sigmoid function
            - d*: target direction (+1 or -1)
            - f_phi: directional oracle prediction, p(agonist) in [0, 1],
              so the centered score d* * (f_phi - 0.5) lies in [-0.5, +0.5]
            - tau: temperature (lower = sharper distribution)

        Args:
            peptide_seqs: List of peptide SMILES strings
        Returns:
            total_rewards: (N,) array of gated total rewards
            affinities: (N,) array of affinity scores g_psi
            confidences: (N,) array of confidence scores kappa
            directions: (N,) array of directional predictions f_phi
        """
        # Compute affinity g_psi(y, x)
        affinities = self.compute_affinity(peptide_seqs)  # (N,)

        # Compute directional bias f_phi(y, x) and confidence kappa(y)
        directions, confidences = self.compute_direction(peptide_seqs)  # (N,), (N,)
        directions = directions.cpu().numpy()
        confidences = confidences.cpu().numpy()

        # Sigmoid-based gated reward with temperature scaling.
        # 0.5 is used as the threshold to make the gate balanced/symmetric.
        directional_score = self.target_direction * (directions - 0.5)  # (N,) in [-0.5, +0.5]

        # Apply temperature scaling (lower tau -> sharper sigmoid)
        scaled_score = directional_score / self.temperature  # (N,)

        # FIX: clip the logit so np.exp never overflows for very small tau;
        # beyond |60| the sigmoid already saturates to 0/1 in double precision,
        # so clipping does not change the effective gate value.
        scaled_score = np.clip(scaled_score, -60.0, 60.0)

        # Apply sigmoid to get a gate value in [0, 1]: sigma(x) = 1 / (1 + exp(-x))
        sigmoid_weight = 1.0 / (1.0 + np.exp(-scaled_score))  # (N,) in [0, 1]

        # Gate affinity with the sigmoid weight
        gated_rewards = affinities * sigmoid_weight  # (N,)

        # Optional: only apply directional control if affinity is high enough.
        # This implements the "allosteric control only for binders" principle.
        low_affinity_mask = affinities < self.min_affinity_threshold
        gated_rewards[low_affinity_mask] = affinities[low_affinity_mask] * 0.1  # Downweight

        return gated_rewards, affinities, confidences, directions

    def __call__(
        self,
        input_seqs: List[str]
    ) -> Tuple[np.ndarray, dict]:
        """
        Main interface for reward computation.

        Args:
            input_seqs: List of peptide SMILES strings
        Returns:
            rewards: (N,) array of total rewards
            info: dict with 'affinities', 'confidences', 'directions', 'score_vectors'
        """
        total_rewards, affinities, confidences, directions = self.compute_gated_reward(input_seqs)

        info = {
            'affinities': affinities,
            'confidences': confidences,
            'directions': directions,  # Direction predictions
            'score_vectors': np.stack([affinities, total_rewards], axis=1)  # (N, 2)
        }

        return total_rewards, info
223
+
224
+
225
class TD3BConfidenceWeighting:
    """Confidence-weighted importance sampling for TD3B.

    Importance weights are modulated by the oracle confidence kappa(y):

        w(y) = kappa(y) * exp(S_total(y) / alpha)

    which separates full agonists/antagonists (high kappa, |f_phi| ~ 1) from
    partial ones (medium kappa, |f_phi| ~ 0.5) and non-selective binders
    (low kappa, |f_phi| ~ 0).
    """

    def __init__(
        self,
        alpha: float = 0.1,       # Temperature for reward scaling
        min_confidence: float = 0.1  # Floor so weights never collapse to zero
    ):
        """
        Args:
            alpha: Temperature parameter for reward scaling.
            min_confidence: Minimum confidence threshold applied before weighting.
        """
        self.alpha = alpha
        self.min_confidence = min_confidence

    def compute_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """Return confidence-weighted importance weights.

        Args:
            rewards: (N,) array of total rewards S_total.
            confidences: (N,) array of confidence scores kappa in [0, 1].
        Returns:
            (N,) array of weights w(y) = kappa(y) * exp(S_total / alpha).
        """
        # Floor the confidences, then scale the exponentiated rewards by them.
        kappa = np.clip(confidences, self.min_confidence, None)
        return kappa * np.exp(rewards / self.alpha)

    def compute_log_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """Return log importance weights (numerically stable form).

        Args:
            rewards: (N,) array of total rewards.
            confidences: (N,) array of confidence scores.
        Returns:
            (N,) array of log w(y) = log kappa(y) + S_total / alpha.
        """
        kappa = np.clip(confidences, self.min_confidence, None)
        return np.log(kappa) + rewards / self.alpha
295
+
296
+
297
+ # Factory function for creating TD3B reward function
298
def create_td3b_reward_function(
    affinity_predictor: BindingAffinity,
    target_protein_seq: str,
    target_direction: str,  # 'agonist' or 'antagonist'
    peptide_tokenizer,
    device: torch.device,
    directional_oracle: Optional[DirectionalOracle] = None,
    directional_oracle_checkpoint: Optional[str] = None,
    base_path: Optional[str] = None,
    direction_oracle_tr2d2_checkpoint: Optional[str] = None,
    direction_oracle_tokenizer_vocab: Optional[str] = None,
    direction_oracle_tokenizer_splits: Optional[str] = None,
    direction_oracle_esm_name: str = "facebook/esm2_t33_650M_UR50D",
    direction_oracle_esm_cache_dir: Optional[str] = None,
    direction_oracle_esm_local_files_only: bool = False,
    direction_oracle_max_ligand_length: int = 768,
    direction_oracle_max_protein_length: int = 1024,
    direction_oracle_d_model: int = 256,
    direction_oracle_n_heads: int = 4,
    direction_oracle_n_self_attn_layers: int = 1,
    direction_oracle_n_bmca_layers: int = 2,
    direction_oracle_dropout: float = 0.3,
    **kwargs
) -> TD3BRewardFunction:
    """Factory for a TD3BRewardFunction, building the oracle if needed.

    Args:
        affinity_predictor: Pretrained binding affinity model.
        target_protein_seq: Target protein amino acid sequence.
        target_direction: 'agonist' (+1) or 'antagonist' (-1); any other
            string falls back to agonist (+1).
        peptide_tokenizer: Tokenizer for peptides.
        device: Computation device.
        directional_oracle: Preloaded DirectionalOracle instance (optional).
        directional_oracle_checkpoint: Oracle checkpoint path (used only when
            no instance is given).
        base_path: Base path under which default oracle assets are resolved.
        direction_oracle_tr2d2_checkpoint: TR2-D2 checkpoint for the ligand encoder.
        direction_oracle_tokenizer_vocab: SMILES tokenizer vocab path.
        direction_oracle_tokenizer_splits: SMILES tokenizer splits path.
        **kwargs: Forwarded to TD3BRewardFunction.

    Returns:
        A TD3BRewardFunction ready to score peptide sequences.
    """
    if directional_oracle is None:
        # Resolve default asset locations relative to the tr2d2-pep root.
        if base_path is None:
            base_path = "To Be Added"
        root = os.path.join(base_path, "tr2d2-pep")

        if directional_oracle_checkpoint is None:
            directional_oracle_checkpoint = os.path.join(root, "direction_oracle.pt")
        if direction_oracle_tr2d2_checkpoint is None:
            direction_oracle_tr2d2_checkpoint = os.path.join(
                root, "pretrained", "peptune-pretrained.ckpt"
            )
        if direction_oracle_tokenizer_vocab is None:
            direction_oracle_tokenizer_vocab = os.path.join(
                root, "tokenizer", "new_vocab.txt"
            )
        if direction_oracle_tokenizer_splits is None:
            direction_oracle_tokenizer_splits = os.path.join(
                root, "tokenizer", "new_splits.txt"
            )

        directional_oracle = DirectionalOracle(
            model_ckpt=directional_oracle_checkpoint,
            tr2d2_checkpoint=direction_oracle_tr2d2_checkpoint,
            tokenizer_vocab=direction_oracle_tokenizer_vocab,
            tokenizer_splits=direction_oracle_tokenizer_splits,
            esm_name=direction_oracle_esm_name,
            d_model=direction_oracle_d_model,
            n_heads=direction_oracle_n_heads,
            n_self_attn_layers=direction_oracle_n_self_attn_layers,
            n_bmca_layers=direction_oracle_n_bmca_layers,
            dropout=direction_oracle_dropout,
            max_ligand_length=direction_oracle_max_ligand_length,
            max_protein_length=direction_oracle_max_protein_length,
            device=device,
            esm_cache_dir=direction_oracle_esm_cache_dir,
            esm_local_files_only=direction_oracle_esm_local_files_only,
        )

    # Inference-only usage.
    directional_oracle.eval()

    # Pre-encode the (fixed) target protein once.
    protein_tokens = directional_oracle.encode_protein(target_protein_seq)

    # Map the direction string onto d* in {+1, -1}; unknown strings -> +1.
    d_star = -1.0 if target_direction.lower() == 'antagonist' else +1.0

    return TD3BRewardFunction(
        affinity_predictor=affinity_predictor,
        directional_oracle=directional_oracle,
        target_direction=d_star,
        target_protein_tokens=protein_tokens,
        peptide_tokenizer=peptide_tokenizer,
        device=device,
        **kwargs
    )
tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
def load_vocab(vocab_file):
    """Load a vocabulary file into an ordered token -> index mapping.

    Each line of the file holds one token; a token's id is its 0-based line
    number. The file is iterated lazily instead of materializing every line
    with ``readlines()``.

    Args:
        vocab_file: Path to a UTF-8 text file with one token per line.

    Returns:
        collections.OrderedDict mapping token string to integer index.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, token in enumerate(reader):
            # Strip only the trailing newline; other whitespace is significant.
            vocab[token.rstrip("\n")] = index
    return vocab
18
+
19
class Atomwise_Tokenizer(object):
    """Atom-level SMILES tokenizer.

    Splits a SMILES string into atom/bond/ring-closure tokens using a single
    precompiled regular expression.
    """

    def __init__(self):
        """Compile the atom-level tokenization pattern.

        The alternation matches, in priority order: short parenthesized
        groups (<= 4 chars inside), bracket atoms, two-letter halogens,
        single-letter atoms (aliphatic and aromatic), and SMILES punctuation
        (bonds, branches, ring bonds, stereo markers, digits).
        """
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Return the list of atom-level tokens found in ``text``."""
        # findall on a single-group pattern already yields a fresh list of
        # matched token strings.
        return self.regex.findall(text)
35
+
36
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # FIX: open the SPE merges file in a context manager so the handle is
        # closed once SPE_Tokenizer has consumed it (previously the open file
        # object was stored on the instance and never closed).
        with open(spe_file, 'r', encoding='utf-8') as spe_vocab:
            self.spe_tokenizer = SPE_Tokenizer(spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Size of the base vocabulary (without added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full vocabulary including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # SPE_Tokenizer returns a single space-joined string of merges.
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        """Encode a pre-tokenized sequence into BOS + ids + EOS tensors.

        NOTE(review): ids 2 and 3 are hard-coded as the BOS/EOS ids — assumed
        to match their positions in new_vocab.txt; confirm if the vocab changes.
        """
        token_ids = [2]  # BOS
        for token in token_array:
            token_ids.append(self._convert_token_to_id(token))
        token_ids.append(3)  # EOS
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode a (1, L) or (L,) tensor of ids back into a SMILES string."""
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # Stop decoding when the EOS id is encountered
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token = self._convert_id_to_token(idx)
            token_array.append(token)
        sequence = "".join(token_array)
        return sequence

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode a batch of id tensors into SMILES strings."""
        sequences = []
        for token_ids in batch_token_ids:
            # FIX: forward skip_special_tokens (it was silently dropped before).
            sequences.append(self.decode(token_ids, skip_special_tokens=skip_special_tokens))
        return sequences

    def get_token_split(self, token_ids):
        """Convert a batch of id sequences into lists of token strings."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        token_array = []
        for seq_ids in token_ids:
            seq_array = []
            for token_id in seq_ids:
                seq_array.append(self._convert_id_to_token(token_id))
            token_array.append(seq_array)

        return token_array

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
252
+
253
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        # FIX: load the vocabulary BEFORE calling super().__init__(), matching
        # SMILES_SPE_Tokenizer. Recent transformers versions touch
        # get_vocab()/added-token machinery during PreTrainedTokenizer.__init__,
        # which would fail if self.vocab is not yet set.
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Size of the base vocabulary (without added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full vocabulary including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # Delegate to the atom-level regex tokenizer.
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
utils/app.py ADDED
@@ -0,0 +1,1287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from io import StringIO
5
+ import rdkit
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, Draw
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from io import BytesIO
13
+ import tempfile
14
+ from rdkit import Chem
15
+
16
+ class PeptideAnalyzer:
17
+ def __init__(self):
18
+ self.bond_patterns = [
19
+ (r'OC\(=O\)', 'ester'), # Ester bond
20
+ (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
21
+ (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
22
+ (r'NC\(=O\)', 'peptide'), # Standard peptide bond
23
+ (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
24
+ (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
25
+ ]
26
+ # Three to one letter code mapping
27
+ self.three_to_one = {
28
+ 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
29
+ 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
30
+ 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
31
+ 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
32
+ 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
33
+ }
34
+
35
+ def is_amino_acid_sequence(self, seq):
36
+ """
37
+ Check if the input is a valid amino acid sequence.
38
+
39
+ Args:
40
+ seq: String to check
41
+
42
+ Returns:
43
+ bool: True if valid amino acid sequence, False otherwise
44
+ """
45
+ if not seq or not isinstance(seq, str):
46
+ return False
47
+
48
+ # Valid amino acid letters (20 standard + some common modifications)
49
+ valid_amino_acids = set('ACDEFGHIKLMNPQRSTVWY')
50
+
51
+ # Check if all characters are valid amino acids
52
+ # Allow for some special characters that might be in the sequence
53
+ seq_clean = seq.strip().upper()
54
+
55
+ # Must have at least 2 amino acids to be a peptide
56
+ if len(seq_clean) < 2:
57
+ return False
58
+
59
+ # Check if all characters are valid amino acids
60
+ return all(c in valid_amino_acids for c in seq_clean)
61
+
62
+ def is_peptide(self, smiles):
63
+ """Check if the SMILES represents a peptide structure"""
64
+ # First check if it's an amino acid sequence (not SMILES)
65
+ if self.is_amino_acid_sequence(smiles):
66
+ return True
67
+
68
+ # Otherwise check if it's a SMILES peptide
69
+ mol = Chem.MolFromSmiles(smiles)
70
+ if mol is None:
71
+ return False
72
+
73
+ # Look for peptide bonds: NC(=O) pattern
74
+ peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
75
+ if mol.HasSubstructMatch(peptide_bond_pattern):
76
+ return True
77
+
78
+ # Look for N-methylated peptide bonds: N(C)C(=O) pattern
79
+ n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
80
+ if mol.HasSubstructMatch(n_methyl_pattern):
81
+ return True
82
+
83
+ return False
84
+
85
+ def is_cyclic(self, smiles):
86
+ """Improved cyclic peptide detection"""
87
+ # Check for C-terminal carboxyl
88
+ if smiles.endswith('C(=O)O'):
89
+ return False, [], []
90
+
91
+ # Find all numbers used in ring closures
92
+ ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
93
+
94
+ # Find aromatic ring numbers
95
+ aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
96
+ aromatic_cycles = []
97
+ for match in aromatic_matches:
98
+ numbers = re.findall(r'[0-9]', match)
99
+ aromatic_cycles.extend(numbers)
100
+
101
+ # Numbers that aren't part of aromatic rings are peptide cycles
102
+ peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
103
+
104
+ is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
105
+ return is_cyclic, peptide_cycles, aromatic_cycles
106
+
107
+ def split_on_bonds(self, smiles):
108
+ """Split SMILES into segments with simplified Pro handling"""
109
+ positions = []
110
+ used = set()
111
+
112
+ # Find Gly pattern first
113
+ gly_pattern = r'NCC\(=O\)'
114
+ for match in re.finditer(gly_pattern, smiles):
115
+ if not any(p in range(match.start(), match.end()) for p in used):
116
+ positions.append({
117
+ 'start': match.start(),
118
+ 'end': match.end(),
119
+ 'type': 'gly',
120
+ 'pattern': match.group()
121
+ })
122
+ used.update(range(match.start(), match.end()))
123
+
124
+ for pattern, bond_type in self.bond_patterns:
125
+ for match in re.finditer(pattern, smiles):
126
+ if not any(p in range(match.start(), match.end()) for p in used):
127
+ positions.append({
128
+ 'start': match.start(),
129
+ 'end': match.end(),
130
+ 'type': bond_type,
131
+ 'pattern': match.group()
132
+ })
133
+ used.update(range(match.start(), match.end()))
134
+
135
+ # Sort by position
136
+ positions.sort(key=lambda x: x['start'])
137
+
138
+ # Create segments
139
+ segments = []
140
+
141
+ if positions:
142
+ # First segment
143
+ if positions[0]['start'] > 0:
144
+ segments.append({
145
+ 'content': smiles[0:positions[0]['start']],
146
+ 'bond_after': positions[0]['pattern']
147
+ })
148
+
149
+ # Process segments
150
+ for i in range(len(positions)-1):
151
+ current = positions[i]
152
+ next_pos = positions[i+1]
153
+
154
+ if current['type'] == 'gly':
155
+ segments.append({
156
+ 'content': 'NCC(=O)',
157
+ 'bond_before': positions[i-1]['pattern'] if i > 0 else None,
158
+ 'bond_after': next_pos['pattern']
159
+ })
160
+ else:
161
+ content = smiles[current['end']:next_pos['start']]
162
+ if content:
163
+ segments.append({
164
+ 'content': content,
165
+ 'bond_before': current['pattern'],
166
+ 'bond_after': next_pos['pattern']
167
+ })
168
+
169
+ # Last segment
170
+ if positions[-1]['end'] < len(smiles):
171
+ segments.append({
172
+ 'content': smiles[positions[-1]['end']:],
173
+ 'bond_before': positions[-1]['pattern']
174
+ })
175
+
176
+ return segments
177
+
178
+ def clean_terminal_carboxyl(self, segment):
179
+ """Remove C-terminal carboxyl only if it's the true terminus"""
180
+ content = segment['content']
181
+
182
+ # Only clean if:
183
+ # 1. Contains C(=O)O
184
+ # 2. No bond_after exists (meaning it's the last segment)
185
+ # 3. C(=O)O is at the end of the content
186
+ if 'C(=O)O' in content and not segment.get('bond_after'):
187
+ print('recognized?')
188
+ # Remove C(=O)O pattern regardless of position
189
+ cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
190
+ # Remove any leftover empty parentheses
191
+ cleaned = re.sub(r'\(\)', '', cleaned)
192
+ print(cleaned)
193
+ return cleaned
194
+ return content
195
+
196
+ def identify_residue(self, segment):
197
+ """Identify residue with Pro reconstruction"""
198
+ # Only clean terminal carboxyl if this is the last segment
199
+ content = self.clean_terminal_carboxyl(segment)
200
+ mods = self.get_modifications(segment)
201
+
202
+ # UAA pattern matching section - before regular residues
203
+ # Phenylglycine and derivatives
204
+ if 'c1ccccc1' in content:
205
+ if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
206
+ return '4', mods # Base phenylglycine
207
+
208
+ # 4-substituted phenylalanines
209
+ if 'Cc1ccc' in content:
210
+ if 'OMe' in content or 'OCc1ccc' in content:
211
+ return '0A1', mods # 4-methoxy-Phenylalanine
212
+ elif 'Clc1ccc' in content:
213
+ return '200', mods # 4-chloro-Phenylalanine
214
+ elif 'Brc1ccc' in content:
215
+ return '4BF', mods # 4-Bromo-phenylalanine
216
+ elif 'C#Nc1ccc' in content:
217
+ return '4CF', mods # 4-cyano-phenylalanine
218
+ elif 'Ic1ccc' in content:
219
+ return 'PHI', mods # 4-Iodo-phenylalanine
220
+ elif 'Fc1ccc' in content:
221
+ return 'PFF', mods # 4-Fluoro-phenylalanine
222
+
223
+ # Modified tryptophans
224
+ if 'c[nH]c2' in content:
225
+ if 'Oc2cccc2' in content:
226
+ return '0AF', mods # 7-hydroxy-tryptophan
227
+ elif 'Fc2cccc2' in content:
228
+ return '4FW', mods # 4-fluoro-tryptophan
229
+ elif 'Clc2cccc2' in content:
230
+ return '6CW', mods # 6-chloro-tryptophan
231
+ elif 'Brc2cccc2' in content:
232
+ return 'BTR', mods # 6-bromo-tryptophan
233
+ elif 'COc2cccc2' in content:
234
+ return 'MOT5', mods # 5-Methoxy-tryptophan
235
+ elif 'Cc2cccc2' in content:
236
+ return 'MTR5', mods # 5-Methyl-tryptophan
237
+
238
+ # Special amino acids
239
+ if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
240
+ return 'BUG', mods # Tertleucine
241
+
242
+ if 'CCCNC(=N)N' in content:
243
+ return 'CIR', mods # Citrulline
244
+
245
+ if '[SeH]' in content:
246
+ return 'CSE', mods # Selenocysteine
247
+
248
+ if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
249
+ return 'DAB', mods # Diaminobutyric acid
250
+
251
+ if 'C1CCCCC1' in content:
252
+ if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
253
+ return 'CHG', mods # Cyclohexylglycine
254
+ elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
255
+ return 'ALC', mods # 3-cyclohexyl-alanine
256
+
257
+ # Naphthalene derivatives
258
+ if 'c1cccc2c1cccc2' in content:
259
+ if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
260
+ return 'NAL', mods # 2-Naphthyl-alanine
261
+
262
+ # Heteroaromatic derivatives
263
+ if 'c1cncc' in content:
264
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
265
+ if 'c1cscc' in content:
266
+ return 'THA3', mods # 3-(3-thienyl)-alanine
267
+ if 'c1nnc' in content:
268
+ return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
269
+
270
+ # Modified serines and threonines
271
+ if 'OP(O)(O)O' in content:
272
+ if '[C@@H](COP' in content or '[C@H](COP' in content:
273
+ return 'SEP', mods # phosphoserine
274
+ elif '[C@@H](OP' in content or '[C@H](OP' in content:
275
+ return 'TPO', mods # phosphothreonine
276
+
277
+ # Specialized ring systems
278
+ if 'c1c2ccccc2cc2c1cccc2' in content:
279
+ return 'ANTH', mods # 3-(9-anthryl)-alanine
280
+ if 'c1csc2c1cccc2' in content:
281
+ return 'BTH3', mods # 3-(3-benzothienyl)-alanine
282
+ if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
283
+ return 'ADAM', mods # Adamanthane
284
+
285
+ # Fluorinated derivatives
286
+ if 'FC(F)(F)' in content:
287
+ if 'CC(F)(F)F' in content:
288
+ return 'FLA', mods # Trifluoro-alanine
289
+ if 'C(F)(F)F)c1' in content:
290
+ if 'c1ccccc1C(F)(F)F' in content:
291
+ return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
292
+ if 'c1cccc(c1)C(F)(F)F' in content:
293
+ return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
294
+ if 'c1ccc(cc1)C(F)(F)F' in content:
295
+ return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
296
+
297
+ # Multiple halogen patterns
298
+ if 'F' in content and 'c1' in content:
299
+ if 'c1ccc(c(c1)F)F' in content:
300
+ return 'F2F', mods # 3,4-Difluoro-phenylalanine
301
+ if 'cc(F)cc(c1)F' in content:
302
+ return 'WFP', mods # 3,5-Difluoro-phenylalanine
303
+ if 'Cl' in content and 'c1' in content:
304
+ if 'c1ccc(cc1Cl)Cl' in content:
305
+ return 'CP24', mods # 2,4-dichloro-phenylalanine
306
+ if 'c1ccc(c(c1)Cl)Cl' in content:
307
+ return 'CP34', mods # 3,4-dichloro-phenylalanine
308
+
309
+ # Hydroxy and amino derivatives
310
+ if 'O' in content and 'c1' in content:
311
+ if 'c1cc(O)cc(c1)O' in content:
312
+ return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
313
+ if 'c1ccc(c(c1)O)O' in content:
314
+ return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
315
+
316
+ # Cyclic amino acids
317
+ if 'C1CCCC1' in content:
318
+ return 'CPA3', mods # 3-Cyclopentyl-alanine
319
+ if 'C1CCCCC1' in content:
320
+ if 'CC1CCCCC1' in content:
321
+ return 'ALC', mods # 3-cyclohexyl-alanine
322
+ else:
323
+ return 'CHG', mods # Cyclohexylglycine
324
+
325
+ # Chain-length variants
326
+ if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
327
+ return 'NLE', mods # Norleucine
328
+ if 'CC[C@@H]' in content or 'CC[C@H]' in content:
329
+ if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
330
+ return 'ABA', mods # 2-Aminobutyric acid
331
+
332
+ # Modified histidines
333
+ if 'c1cnc' in content:
334
+ if '[C@@H]1CN[C@@H](N1)F' in content:
335
+ return '2HF', mods # 2-fluoro-l-histidine
336
+ if 'c1cnc([nH]1)F' in content:
337
+ return '2HF1', mods # 2-fluoro-l-histidine variant
338
+ if 'c1c[nH]c(n1)F' in content:
339
+ return '2HF2', mods # 2-fluoro-l-histidine variant
340
+
341
+ # Sulfur and selenium containing
342
+ if '[SeH]' in content:
343
+ return 'CSE', mods # Selenocysteine
344
+ if 'S' in content:
345
+ if 'CSCc1ccccc1' in content:
346
+ return 'BCS', mods # benzylcysteine
347
+ if 'CCSC' in content:
348
+ return 'ESC', mods # Ethionine
349
+ if 'CCS' in content:
350
+ return 'HCS', mods # homocysteine
351
+
352
+ # Additional modifications
353
+ if 'CN=[N]=N' in content:
354
+ return 'AZDA', mods # azido-alanine
355
+ if '[NH]=[C](=[NH2])=[NH2]' in content:
356
+ if 'CCC[NH]=' in content:
357
+ return 'AGM', mods # 5-methyl-arginine
358
+ if 'CC[NH]=' in content:
359
+ return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
360
+
361
+ if 'CCON' in content:
362
+ return 'CAN', mods # canaline
363
+ if '[C@@H]1C=C[C@@H](C=C1)' in content:
364
+ return 'ACZ', mods # cis-amiclenomycin
365
+ if 'CCC(=O)[NH3]' in content:
366
+ return 'ONL', mods # 5-oxo-l-norleucine
367
+ if 'c1ccncc1' in content:
368
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
369
+ if 'c1ccco1' in content:
370
+ return 'FUA2', mods # (2-furyl)-alanine
371
+
372
+ if 'c1ccc' in content:
373
+ if 'c1ccc(cc1)c1ccccc1' in content:
374
+ return 'BIF', mods # 4,4-biphenylalanine
375
+ if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
376
+ return 'PBF', mods # 4-benzoyl-phenylalanine
377
+ if 'c1ccc(cc1)C(C)(C)C' in content:
378
+ return 'TBP4', mods # 4-tert-butyl-phenylalanine
379
+ if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
380
+ return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
381
+ if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
382
+ return 'APM', mods # m-amidinophenyl-3-alanine
383
+
384
+ # Multiple hydroxy patterns
385
+ if 'O' in content:
386
+ if '[C@H]([C@H](C)O)O' in content:
387
+ return 'ILX', mods # 4,5-dihydroxy-isoleucine
388
+ if '[C@H]([C@@H](C)O)O' in content:
389
+ return 'ALO', mods # Allo-threonine
390
+ if '[C@H](COP(O)(O)O)' in content:
391
+ return 'SEP', mods # phosphoserine
392
+ if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
393
+ return 'TPO', mods # phosphothreonine
394
+ if '[C@H](c1ccc(O)cc1)O' in content:
395
+ return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
396
+ if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
397
+ return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
398
+
399
+ # Heterocyclic patterns
400
+ if 'n1' in content:
401
+ if 'n1cccn1' in content:
402
+ return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
403
+ if 'n1nncn1' in content:
404
+ return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
405
+ if 'c2c(n1)cccc2' in content:
406
+ return 'QU32', mods # 3-(2-Quinolyl)-alanine
407
+ if 'c1cnc2c(c1)cccc2' in content:
408
+ return 'QU33', mods # 3-(3-quinolyl)-alanine
409
+ if 'c1ccnc2c1cccc2' in content:
410
+ return 'QU34', mods # 3-(4-quinolyl)-alanine
411
+ if 'c1ccc2c(c1)nccc2' in content:
412
+ return 'QU35', mods # 3-(5-Quinolyl)-alanine
413
+ if 'c1ccc2c(c1)cncc2' in content:
414
+ return 'QU36', mods # 3-(6-Quinolyl)-alanine
415
+ if 'c1cnc2c(n1)cccc2' in content:
416
+ return 'QX32', mods # 3-(2-quinoxalyl)-alanine
417
+
418
+ # Multiple nitrogen patterns
419
+ if 'N' in content:
420
+ if '[NH3]CC[C@@H]' in content:
421
+ return 'DAB', mods # Diaminobutyric acid
422
+ if '[NH3]C[C@@H]' in content:
423
+ return 'DPP', mods # 2,3-Diaminopropanoic acid
424
+ if '[NH3]CCCCCC[C@@H]' in content:
425
+ return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
426
+ if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
427
+ return 'GBUT', mods # 2-Amino-4-guanidinobutryric acid
428
+ if '[NH]=[C](=S)=[NH2]' in content:
429
+ return 'THIC', mods # Thio-citrulline
430
+
431
+ # Chain modified amino acids
432
+ if 'CC' in content:
433
+ if 'CCCC[C@@H]' in content:
434
+ return 'AHP', mods # 2-Aminoheptanoic acid
435
+ if 'CCC([C@@H])(C)C' in content:
436
+ return 'I2M', mods # 3-methyl-l-alloisoleucine
437
+ if 'CC[C@H]([C@@H])C' in content:
438
+ return 'IIL', mods # Allo-Isoleucine
439
+ if '[C@H](CCC(C)C)' in content:
440
+ return 'HLEU', mods # Homoleucine
441
+ if '[C@@H]([C@@H](C)O)C' in content:
442
+ return 'HLU', mods # beta-hydroxyleucine
443
+
444
+ # Modified glutamate/aspartate patterns
445
+ if '[C@@H]' in content:
446
+ if '[C@@H](C[C@@H](F))' in content:
447
+ return 'FGA4', mods # 4-Fluoro-glutamic acid
448
+ if '[C@@H](C[C@@H](O))' in content:
449
+ return '3GL', mods # 4-hydroxy-glutamic-acid
450
+ if '[C@@H](C[C@H](C))' in content:
451
+ return 'LME', mods # (3r)-3-methyl-l-glutamic acid
452
+ if '[C@@H](CC[C@H](C))' in content:
453
+ return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
454
+
455
+ # Sulfur and selenium modifications
456
+ if 'S' in content:
457
+ if 'SCC[C@@H]' in content:
458
+ return 'HSER', mods # homoserine
459
+ if 'SCCN' in content:
460
+ return 'SLZ', mods # thialysine
461
+ if 'SC(=O)' in content:
462
+ return 'CSA', mods # s-acetonylcysteine
463
+ if '[S@@](=O)' in content:
464
+ return 'SME', mods # Methionine sulfoxide
465
+ if 'S(=O)(=O)' in content:
466
+ return 'OMT', mods # Methionine sulfone
467
+
468
+ # Double bond containing
469
+ if 'C=' in content:
470
+ if 'C=C[C@@H]' in content:
471
+ return '2AG', mods # 2-Allyl-glycine
472
+ if 'C=C[C@@H]' in content:
473
+ return 'LVG', mods # vinylglycine
474
+ if 'C=Cc1ccccc1' in content:
475
+ return 'STYA', mods # Styrylalanine
476
+
477
+ # Special cases
478
+ if '[C@@H]1Cc2c(C1)cccc2' in content:
479
+ return 'IGL', mods # alpha-amino-2-indanacetic acid
480
+ if '[C](=[C](=O)=O)=O' in content:
481
+ return '26P', mods # 2-amino-6-oxopimelic acid
482
+ if '[C](=[C](=O)=O)=C' in content:
483
+ return '2NP', mods # l-2-amino-6-methylene-pimelic acid
484
+ if 'c2cnc[nH]2' in content:
485
+ return 'HIS', mods # histidine core
486
+ if 'c1cccc2c1cc(O)cc2' in content:
487
+ return 'NAO1', mods # 5-hydroxy-1-naphthalene
488
+ if 'c1ccc2c(c1)cc(O)cc2' in content:
489
+ return 'NAO2', mods # 6-hydroxy-2-naphthalene
490
+
491
+ # Proline (P) - flexible ring numbers
492
+ if any([
493
+ # Check for any ring number in bond patterns
494
+ (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
495
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
496
+ for n in '123456789'
497
+ ]) or any([
498
+ # Check ending patterns with any ring number
499
+ (f'CCCN{n}' in content and content.endswith('=O') and
500
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
501
+ for n in '123456789'
502
+ ]) or any([
503
+ # Handle CCC[C@H]n patterns
504
+ (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
505
+ (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
506
+ # N-terminal Pro with any ring number
507
+ (f'N{n}CCC[C@H]{n}' in content) or
508
+ (f'N{n}CCC[C@@H]{n}' in content)
509
+ for n in '123456789'
510
+ ]):
511
+ return 'Pro', mods
512
+
513
+ # Tryptophan (W) - more specific indole pattern
514
+ if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
515
+ 'c[nH]c' in content.replace(' ', ''):
516
+ return 'Trp', mods
517
+
518
+ # Lysine (K) - both patterns
519
+ if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
520
+ return 'Lys', mods
521
+
522
+ # Arginine (R) - both patterns
523
+ if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
524
+ return 'Arg', mods
525
+
526
+ if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
527
+ return 'Nle', mods
528
+
529
+ # Ornithine (Orn) - 3-carbon chain with NH2
530
+ if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
531
+ return 'Orn', mods
532
+
533
+ # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
534
+ if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
535
+ return '2Nal', mods
536
+
537
+ # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
538
+ if 'N2CCCCC2' in content or 'CCCCC2' in content:
539
+ return 'Cha', mods
540
+
541
+ # Aminobutyric acid (Abu) - 2-carbon chain
542
+ if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
543
+ return 'Abu', mods
544
+
545
+ # Pipecolic acid (Pip) - 6-membered ring like Pro
546
+ if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
547
+ return 'Pip', mods
548
+
549
+ # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
550
+ if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
551
+ return 'Chg', mods
552
+
553
+ # 4-Fluorophenylalanine (4F-Phe)
554
+ if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
555
+ return '4F-Phe', mods
556
+
557
+ # Regular residue identification
558
+ if ('NCC(=O)' in content) or (content == 'C'):
559
+ # Middle case - between bonds
560
+ if segment.get('bond_before') and segment.get('bond_after'):
561
+ if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
562
+ return 'Gly', mods
563
+ # Terminal case - at the end
564
+ elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
565
+ return 'Gly', mods
566
+
567
+ if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
568
+ return 'Leu', mods
569
+ if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
570
+ return 'Leu', mods
571
+
572
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
573
+ return 'Thr', mods
574
+
575
+ if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
576
+ return 'Phe', mods
577
+
578
+ if ('[C@H](C(C)C)' in content or # With outer parentheses
579
+ '[C@@H](C(C)C)' in content or # With outer parentheses
580
+ '[C@H]C(C)C' in content or # Without outer parentheses
581
+ '[C@@H]C(C)C' in content): # Without outer parentheses
582
+ if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
583
+ return 'Val', mods
584
+
585
+ if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
586
+ return 'O-tBu', mods
587
+
588
+ if any([
589
+ 'CC[C@H](C)' in content,
590
+ 'CC[C@@H](C)' in content,
591
+ 'C(C)C[C@H]' in content and 'CC(C)C' not in content,
592
+ 'C(C)C[C@@H]' in content and 'CC(C)C' not in content
593
+ ]):
594
+ return 'Ile', mods
595
+
596
+ if ('[C@H](C)' in content or '[C@@H](C)' in content):
597
+ if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
598
+ return 'Ala', mods
599
+
600
+ # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
601
+ if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
602
+ return 'Tyr', mods
603
+
604
+
605
+ # Serine (Ser) - Hydroxymethyl side chain
606
+ if '[C@H](CO)' in content or '[C@@H](CO)' in content:
607
+ if not ('C(C)O' in content or 'COC' in content):
608
+ return 'Ser', mods
609
+
610
+ # Threonine (Thr) - 1-hydroxyethyl side chain
611
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
612
+ return 'Thr', mods
613
+
614
+ # Cysteine (Cys) - Thiol side chain
615
+ if '[C@H](CS)' in content or '[C@@H](CS)' in content:
616
+ return 'Cys', mods
617
+
618
+ # Methionine (Met) - Methylthioethyl side chain
619
+ if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
620
+ return 'Met', mods
621
+
622
+ # Asparagine (Asn) - Carbamoylmethyl side chain
623
+ if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
624
+ return 'Asn', mods
625
+
626
+ # Glutamine (Gln) - Carbamoylethyl side chain
627
+ if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
628
+ return 'Gln', mods
629
+
630
+ # Aspartic acid (Asp) - Carboxymethyl side chain
631
+ if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
632
+ return 'Asp', mods
633
+
634
+ # Glutamic acid (Glu) - Carboxyethyl side chain
635
+ if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
636
+ return 'Glu', mods
637
+
638
+ # Arginine (Arg) - 3-guanidinopropyl side chain
639
+ if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
640
+ return 'Arg', mods
641
+
642
+ # Histidine (His) - Imidazole side chain
643
+ if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
644
+ return 'His', mods
645
+
646
+ return None, mods
647
+
648
def get_modifications(self, segment):
    """Return backbone modifications implied by the bond after this segment.

    'N-Me' is reported for an N-methylated amide, 'O-linked' for an
    ester bond; an empty list means a plain peptide bond (or no bond).
    """
    bond_after = segment.get('bond_after')
    if not bond_after:
        return []
    found = []
    # N-methylation shows up as an N(C) motif in the trailing bond.
    if 'N(C)' in bond_after or bond_after.startswith('C(=O)N(C)'):
        found.append('N-Me')
    # An OC(=O) motif marks an ester (depsipeptide) linkage.
    if 'OC(=O)' in bond_after:
        found.append('O-linked')
    return found
657
+
658
def analyze_structure(self, smiles):
    """Analyze a peptide SMILES string, printing a per-segment debug trace.

    Splits the SMILES on backbone bonds, identifies each residue (plus
    any modifications such as N-Me), detects cyclization, and prints the
    intermediate results as it goes.

    Args:
        smiles: Peptide SMILES string to analyze.

    Returns:
        tuple: ``(three_letter_sequence, segment_count)`` where the
        sequence is '-'-joined and wrapped in ``cyclo(...)`` when the
        peptide is cyclic.
    """
    # Fix: removed unreachable dead code (a string-literal "return {...}"
    # expression) that previously sat after the return statement.
    print("\nAnalyzing structure:", smiles)

    # Split into segments
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for i, segment in enumerate(segments):
        print(f"\nSegment {i}:")
        print(f"Content: {segment['content']}")
        print(f"Bond before: {segment.get('bond_before', 'None')}")
        print(f"Bond after: {segment.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(segment)
        if residue:
            # Render modifications inline, e.g. "Ala(N-Me)".
            if mods:
                sequence.append(f"{residue}({','.join(mods)})")
            else:
                sequence.append(residue)
            print(f"Identified as: {residue}")
            print(f"Modifications: {mods}")
        else:
            print(f"Warning: Could not identify residue in segment: {segment['content']}")

    # Wrap the joined sequence in cyclo(...) when a macrocycle is present.
    is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
    three_letter = '-'.join(sequence)
    # Unknown residues map to 'X'; the '(mods)' suffix is stripped first.
    one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)

    if is_cyclic:
        three_letter = f"cyclo({three_letter})"
        one_letter = f"cyclo({one_letter})"

    print(f"\nFinal sequence: {three_letter}")
    print(f"One-letter code: {one_letter}")
    print(f"Is cyclic: {is_cyclic}")

    return three_letter, len(segments)
705
+
706
def return_sequence(self, smiles):
    """Split *smiles* on backbone bonds and return the identified residues.

    Prints a debug trace for each segment and returns a list of
    three-letter residue codes, with modifications appended in
    parentheses (e.g. ``'Ala(N-Me)'``).
    """
    print("\nAnalyzing structure:", smiles)

    # Break the SMILES into per-residue segments.
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for idx, seg in enumerate(segments):
        print(f"\nSegment {idx}:")
        print(f"Content: {seg['content']}")
        print(f"Bond before: {seg.get('bond_before', 'None')}")
        print(f"Bond after: {seg.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(seg)
        if not residue:
            print(f"Warning: Could not identify residue in segment: {seg['content']}")
            continue

        label = f"{residue}({','.join(mods)})" if mods else residue
        sequence.append(label)
        print(f"Identified as: {residue}")
        print(f"Modifications: {mods}")

    return sequence
733
+
734
+ """
735
+ def annotate_cyclic_structure(mol, sequence):
736
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
737
+ # Generate 2D coordinates
738
+ # Generate 2D coordinates
739
+ AllChem.Compute2DCoords(mol)
740
+
741
+ # Create drawer with larger size for annotations
742
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
743
+
744
+ # Get residue list and reverse it to match structural representation
745
+ if sequence.startswith('cyclo('):
746
+ residues = sequence[6:-1].split('-')
747
+ else:
748
+ residues = sequence.split('-')
749
+ residues = list(reversed(residues)) # Reverse the sequence
750
+
751
+ # Draw molecule first to get its bounds
752
+ drawer.drawOptions().addAtomIndices = False
753
+ drawer.DrawMolecule(mol)
754
+ drawer.FinishDrawing()
755
+
756
+ # Convert to PIL Image
757
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
758
+ draw = ImageDraw.Draw(img)
759
+
760
+ try:
761
+ # Try to use DejaVuSans as it's commonly available on Linux systems
762
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
763
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
764
+ except OSError:
765
+ try:
766
+ # Fallback to Arial if available (common on Windows)
767
+ font = ImageFont.truetype("arial.ttf", 60)
768
+ small_font = ImageFont.truetype("arial.ttf", 60)
769
+ except OSError:
770
+ # If no TrueType fonts are available, fall back to default
771
+ print("Warning: TrueType fonts not available, using default font")
772
+ font = ImageFont.load_default()
773
+ small_font = ImageFont.load_default()
774
+ # Get molecule bounds
775
+ conf = mol.GetConformer()
776
+ positions = []
777
+ for i in range(mol.GetNumAtoms()):
778
+ pos = conf.GetAtomPosition(i)
779
+ positions.append((pos.x, pos.y))
780
+
781
+ x_coords = [p[0] for p in positions]
782
+ y_coords = [p[1] for p in positions]
783
+ min_x, max_x = min(x_coords), max(x_coords)
784
+ min_y, max_y = min(y_coords), max(y_coords)
785
+
786
+ # Calculate scaling factors
787
+ scale = 150 # Increased scale factor
788
+ center_x = 1000 # Image center
789
+ center_y = 1000
790
+
791
+ # Add residue labels in a circular arrangement around the structure
792
+ n_residues = len(residues)
793
+ radius = 700 # Distance of labels from center
794
+
795
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
796
+ # Offset by -3 positions to align with structure
797
+ offset = 0 # Adjust this value to match the structure alignment
798
+ for i, residue in enumerate(residues):
799
+ # Calculate position in a circle around the structure
800
+ # Start from 0 (3 o'clock) and go counterclockwise
801
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
802
+
803
+ # Calculate label position
804
+ label_x = center_x + radius * np.cos(angle)
805
+ label_y = center_y + radius * np.sin(angle)
806
+
807
+ # Draw residue label
808
+ text = f"{i+1}. {residue}"
809
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
810
+ padding = 10
811
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
812
+ bbox[2]+padding, bbox[3]+padding],
813
+ fill='white', outline='white')
814
+ draw.text((label_x, label_y), text,
815
+ font=font, fill='black', anchor="mm")
816
+
817
+ # Add sequence at the top with white background
818
+ seq_text = f"Sequence: {sequence}"
819
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
820
+ padding = 10
821
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
822
+ bbox[2]+padding, bbox[3]+padding],
823
+ fill='white', outline='white')
824
+ draw.text((center_x, 100), seq_text,
825
+ font=small_font, fill='black', anchor="mm")
826
+
827
+ return img
828
+
829
+ """
830
def annotate_cyclic_structure(mol, sequence):
    """Render *mol* as a 2000x2000 2D depiction with a sequence header.

    Returns a PIL Image of the molecule with ``Sequence: ...`` painted
    on a white backing rectangle at the top.
    """
    # Lay the molecule out in 2D and rasterize it with Cairo.
    AllChem.Compute2DCoords(mol)
    canvas = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
    canvas.drawOptions().addAtomIndices = False
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()

    image = Image.open(BytesIO(canvas.GetDrawingText()))
    painter = ImageDraw.Draw(image)

    # Font fallback chain: DejaVu (Linux) -> Arial (Windows) -> PIL default.
    header_font = None
    for candidate in ("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
                      "arial.ttf"):
        try:
            header_font = ImageFont.truetype(candidate, 60)
            break
        except OSError:
            continue
    if header_font is None:
        print("Warning: TrueType fonts not available, using default font")
        header_font = ImageFont.load_default()

    # Paint the header over a padded white rectangle so it stays legible.
    seq_text = f"Sequence: {sequence}"
    box = painter.textbbox((1000, 100), seq_text, font=header_font)
    pad = 10
    painter.rectangle([box[0] - pad, box[1] - pad,
                       box[2] + pad, box[3] + pad],
                      fill='white', outline='white')
    painter.text((1000, 100), seq_text,
                 font=header_font, fill='black', anchor="mm")

    return image
866
+
867
def create_enhanced_linear_viz(sequence, smiles):
    """Create an enhanced linear representation using PeptideAnalyzer.

    Builds a two-panel matplotlib figure: the top panel draws the residues
    as boxes joined by bond lines (ester bonds dashed red, peptide bonds
    solid black, N-methylation annotated); the bottom panel lists a textual
    per-segment breakdown. Returns the matplotlib Figure.
    """
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots (top: overview, bottom: breakdown)
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues; strip the cyclo(...) wrapper if present
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using analyzer
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    # Spread residues evenly across 9 units of x; guard the single-residue case
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i in range(num_residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            # NOTE(review): assumes segments[i] describes the bond following
            # residue i — confirm against split_on_bonds' ordering.
            segment = segments[i] if i < len(segments) else None
            if segment:
                # Determine bond type from segment info
                bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
                is_n_methylated = 'N-Me' in segment.get('bond_after', '')

                bond_color = 'red' if bond_type == 'ester' else 'black'
                linestyle = '--' if bond_type == 'ester' else '-'

                # Draw bond line between the two residue boxes
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label above the midpoint of the bond line
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label below its box
        ax_struct.text(x_pos, y_pos-0.5, residues[i],
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown, one text row per segment, top to bottom
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond (no residue could be identified in the segment)
            text = f"Bond {i}: "
            if 'O-linked' in segment.get('bond_after', ''):
                text += "ester"
            elif 'N-Me' in segment.get('bond_after', ''):
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator (double-headed arrow across the top panel)
    if sequence.startswith('cyclo('):
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
985
+
986
class PeptideStructureGenerator:
    """A class to generate 3D structures of peptides using different embedding methods."""

    @staticmethod
    def prepare_molecule(smiles):
        """Parse *smiles* into an RDKit Mol with hydrogens added.

        Parses without immediate sanitization, then sanitizes with a
        reduced operation set (chirality cleanup, aromaticity, etc.) so
        unusual peptide structures are tolerated.

        Raises:
            ValueError: if the SMILES cannot be parsed at all.
        """
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError("Failed to create molecule from SMILES")

        # Calculate valence for each atom (non-strict, so odd valences pass)
        for atom in mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)

        # Sanitize with reduced requirements
        Chem.SanitizeMol(mol,
                         sanitizeOps=Chem.SANITIZE_FINDRADICALS|
                         Chem.SANITIZE_KEKULIZE|
                         Chem.SANITIZE_SETAROMATICITY|
                         Chem.SANITIZE_SETCONJUGATION|
                         Chem.SANITIZE_SETHYBRIDIZATION|
                         Chem.SANITIZE_CLEANUPCHIRALITY)

        mol = Chem.AddHs(mol)
        return mol

    @staticmethod
    def get_etkdg_params(attempt=0):
        """Get ETKDG parameters with optional modifications based on attempt number.

        Later attempts (>10) relax the embedding: experimental torsion
        preferences are disabled and the bond length is gradually stretched.
        """
        params = AllChem.ETKDGv3()
        params.randomSeed = -1  # non-deterministic seed for each attempt
        params.maxIterations = 200
        params.numThreads = 4  # Reduced for web interface
        params.useBasicKnowledge = True
        params.enforceChirality = True
        params.useExpTorsionAnglePrefs = True
        params.useSmallRingTorsions = True
        params.useMacrocycleTorsions = True
        params.ETversion = 2
        params.pruneRmsThresh = -1
        params.embedRmsThresh = 0.5

        if attempt > 10:
            # Progressively loosen geometry constraints for stubborn molecules
            params.bondLength = 1.5 + (attempt - 10) * 0.02
            params.useExpTorsionAnglePrefs = False

        return params

    def generate_structure_etkdg(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG without UFF optimization.

        Retries up to *max_attempts* times with progressively relaxed
        parameters; returns the first successfully embedded Mol.

        Raises:
            ValueError: if no attempt produces an embedding.
        """
        success = False
        mol = None

        for attempt in range(max_attempts):
            try:
                mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                # EmbedMolecule returns 0 on success
                if AllChem.EmbedMolecule(mol, params) == 0:
                    success = True
                    break
            except Exception as e:
                # Embedding/parsing failures just trigger another attempt
                continue

        if not success:
            raise ValueError("Failed to generate structure with ETKDG")

        return mol

    def generate_structure_uff(self, smiles, max_attempts=20):
        """Generate a 3D structure using ETKDG followed by UFF optimization.

        Keeps the lowest-UFF-energy conformer found over *max_attempts*
        independent embed+optimize attempts.

        Raises:
            ValueError: if no attempt yields an optimized structure.
        """
        best_mol = None
        lowest_energy = float('inf')

        for attempt in range(max_attempts):
            try:
                test_mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(test_mol, params) == 0:
                    # UFFOptimizeMolecule returns 0 when it converged
                    res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
                                                      vdwThresh=10.0, confId=0,
                                                      ignoreInterfragInteractions=True)

                    if res == 0:
                        ff = AllChem.UFFGetMoleculeForceField(test_mol)
                        if ff:
                            current_energy = ff.CalcEnergy()
                            if current_energy < lowest_energy:
                                lowest_energy = current_energy
                                # Copy so later attempts can't mutate the best one
                                best_mol = Chem.Mol(test_mol)
            except Exception:
                continue

        if best_mol is None:
            raise ValueError("Failed to generate optimized structure")

        return best_mol

    @staticmethod
    def mol_to_sdf_bytes(mol):
        """Convert an RDKit molecule to SDF file content as UTF-8 bytes."""
        # First write to StringIO in text mode
        sio = StringIO()
        writer = Chem.SDWriter(sio)
        writer.write(mol)
        writer.close()

        # Convert the string to bytes
        return sio.getvalue().encode('utf-8')
1096
+
1097
def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process input and create visualizations using PeptideAnalyzer.

    Exactly one of *smiles_input* (a single SMILES string) or *file_obj*
    (a file of SMILES, one per line) should be provided.

    Args:
        smiles_input: SMILES string of a peptide to analyze.
        file_obj: File object / path / bytes with one SMILES per line.
        show_linear: Also render the linear bar-style visualization.
        show_segment_details: Include the per-segment debug breakdown in the text.
        generate_3d: Generate ETKDG 3D structures (SDF files in a temp dir).
        use_uff: Additionally generate a UFF-optimized structure.

    Returns:
        tuple: ``(output_text, cyclic_image, linear_image, structure_files)``.
        Fix: the function previously returned 3-tuples from some paths and
        4-tuples from others; every path now returns a 4-tuple so callers
        can unpack a fixed number of values.
    """
    analyzer = PeptideAnalyzer()
    temp_dir = tempfile.mkdtemp() if generate_3d else None
    structure_files = []

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check if it's a peptide using analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            if generate_3d:
                generator = PeptideStructureGenerator()

                try:
                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use analyzer to get sequence
            segments = analyzer.split_on_bonds(smiles)

            # Process segments and build sequence
            sequence_parts = []
            output_text = ""

            # Only include segment analysis in output if requested
            if show_segment_details:
                output_text += "Segment Analysis:\n"
                for i, segment in enumerate(segments):
                    output_text += f"\nSegment {i}:\n"
                    output_text += f"Content: {segment['content']}\n"
                    output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                    output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"

                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)
                        output_text += f"Identified as: {residue}\n"
                        output_text += f"Modifications: {mods}\n"
                    else:
                        output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
                output_text += "\n"
            else:
                # Just build sequence without detailed analysis in output
                for segment in segments:
                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)

            # Check if cyclic using analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content: path-like objects are read from disk,
            # bytes are decoded, anything else is stringified
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if smiles:
                    # Check if it's a peptide
                    if not analyzer.is_peptide(smiles):
                        output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                        continue

                    # Process this SMILES
                    segments = analyzer.split_on_bonds(smiles)
                    sequence_parts = []

                    # Add segment details if requested
                    if show_segment_details:
                        output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                        for i, segment in enumerate(segments):
                            output_text += f"\nSegment {i}:\n"
                            output_text += f"Content: {segment['content']}\n"
                            output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                            output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)
                                output_text += f"Identified as: {residue}\n"
                                output_text += f"Modifications: {mods}\n"
                    else:
                        for segment in segments:
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)

                    # Get cyclicity and create sequence
                    is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                    sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                    output_text += f"\nSummary for SMILES: {smiles}\n"
                    output_text += f"Sequence: {sequence}\n"
                    output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                    if is_cyclic:
                        output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                    output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
1287
+
utils/timer.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time, torch
2
+ from collections import defaultdict
3
+ from contextlib import contextmanager
4
+
5
class StepTimer:
    """Accumulates wall-clock timings for named code sections.

    When the configured device is CUDA, the GPU is synchronized before
    and after each timed section so queued kernel work is included in
    the measurement.
    """

    def __init__(self, device=None):
        self.times = defaultdict(list)
        self.device = device
        # Synchronize only when the target is a CUDA device (torch.device
        # or a string such as "cuda" / "cuda:0").
        is_cuda_device = isinstance(device, torch.device) and device.type == "cuda"
        is_cuda_string = isinstance(device, str) and "cuda" in device
        self._use_cuda_sync = is_cuda_device or is_cuda_string

    @contextmanager
    def section(self, name):
        """Context manager recording the elapsed time of its body under *name*."""
        if self._use_cuda_sync:
            torch.cuda.synchronize()
        started = time.perf_counter()
        try:
            yield
        finally:
            if self._use_cuda_sync:
                torch.cuda.synchronize()
            self.times[name].append(time.perf_counter() - started)

    def summary(self, top_k=None):
        """Return per-section stats sorted by total time, descending.

        Each row is ``(name, count, total, mean, p50, p95)``; at most
        *top_k* rows are returned when top_k is truthy.
        """
        import numpy as np
        rows = []
        for label, samples in self.times.items():
            arr = np.array(samples, dtype=float)
            rows.append((label, len(arr), arr.sum(), arr.mean(),
                         np.median(arr), np.percentile(arr, 95)))
        rows.sort(key=lambda row: row[2], reverse=True)  # by total time
        return rows[:top_k] if top_k else rows
utils/utils.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import fsspec
9
+ import lightning
10
+ import torch
11
+ from timm.scheduler import CosineLRScheduler
12
+ import argparse
13
+ import numpy as np
14
+ import random
15
+ import os
16
+
17
def sample_categorical_logits(logits, dtype=torch.float64):
    """Draw one categorical sample per row via the Gumbel-max trick.

    ``logits`` may be unnormalized — no log-softmax is required, since
    adding Gumbel noise and taking the argmax is shift-invariant.
    """
    uniform = torch.rand_like(logits, dtype=dtype)
    # Gumbel(0, 1) noise; the 1e-10 offsets keep both log() calls away from 0.
    gumbel = -torch.log(1e-10 - torch.log(uniform + 1e-10))
    return torch.argmax(logits + gumbel, dim=-1)
21
+
22
def fsspec_exists(filename):
    """Return True if ``filename`` exists on its (possibly remote) filesystem."""
    filesystem, _ = fsspec.core.url_to_fs(filename)
    return filesystem.exists(filename)
26
+
27
+
28
def fsspec_listdir(dirname):
    """List the contents of ``dirname`` via fsspec (local or remote paths)."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    return filesystem.ls(dirname)
32
+
33
+
34
def fsspec_mkdirs(dirname, exist_ok=True):
    """Recursively create ``dirname`` via fsspec; existing dirs are OK by default."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    filesystem.makedirs(dirname, exist_ok=exist_ok)
38
+
39
+
40
def print_nans(tensor, name):
    """Debug helper: print ``name`` and ``tensor`` when any entry is NaN."""
    has_nan = torch.isnan(tensor).any()
    if has_nan:
        print(name, tensor)
43
+
44
+
45
class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """timm ``CosineLRScheduler`` adapted to PyTorch's ``step()``-only interface.

    Lightning always calls ``step()`` (the per-epoch entry point). This
    wrapper keeps its own counter and dispatches to timm's ``step`` or
    ``step_update`` depending on whether the schedule is defined per
    epoch or per optimizer update — otherwise a scheduler configured
    with interval "step" would receive the wrong learning rate.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        if self.t_in_epochs:
            # schedule indexed by epochs
            super().step(epoch=self._last_epoch)
        else:
            # schedule indexed by optimizer updates
            super().step_update(num_updates=self._last_epoch)
70
+
71
+
72
class LoggingContext:
    """Temporarily adjust a logger's level and/or attach a handler.

    On exit the previous level is restored; the handler (if any) is
    removed and, unless ``close=False``, also closed.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # remember the current level so __exit__ can restore it
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, exc_type, exc_value, traceback):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
94
+
95
+
96
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Each logging method is wrapped with Lightning's ``rank_zero_only`` so
    that in multi-GPU runs only rank 0 emits records; otherwise every
    process would duplicate each message.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    # NOTE: the loop variable previously shadowed the ``level`` parameter and
    # clobbered it; use a distinct name.
    for method_name in ('debug', 'info', 'warning', 'error',
                        'exception', 'fatal', 'critical'):
        setattr(logger,
                method_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, method_name)))

    return logger
112
+
113
+
114
def str2bool(v):
    """Parse a flexible boolean CLI argument ("yes"/"no", "t"/"f", "1"/"0", ...).

    Actual ``bool`` values pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognized boolean string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
123
+
124
+
125
def set_seed(seed, use_cuda):
    """Seed the hash, numpy, random, torch, and (optionally) CUDA RNGs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    # torch.backends.cudnn.deterministic = True
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups
    print(f'=> Seed of the run set to {seed}')
135
+