Upload TD3B code (inference, training, baselines)
Browse files- .gitattributes +1 -0
- README.md +128 -3
- baselines/__init__.py +19 -0
- baselines/baselines.py +746 -0
- baselines/run.sh +77 -0
- baselines/run_mcts_tr2d2.py +421 -0
- baselines/run_validation_td3b.py +548 -0
- baselines/sampling_setup.py +538 -0
- configs/finetune_config.py +122 -0
- configs/peptune_config.yaml +159 -0
- diffusion.py +1588 -0
- distributed_utils.py +28 -0
- env.yml +37 -0
- finetune_multi_target.py +1061 -0
- finetune_utils.py +571 -0
- inference.py +253 -0
- launch_multi_target.sh +175 -0
- noise_schedule.py +150 -0
- peptide_mcts.py +676 -0
- roformer.py +74 -0
- scoring/functions/binding.py +482 -0
- scoring/functions/classifiers/hemolysis-xgboost.json +0 -0
- scoring/functions/classifiers/nonfouling-xgboost.json +0 -0
- scoring/functions/classifiers/permeability-xgboost.json +3 -0
- scoring/functions/classifiers/solubility-xgboost.json +0 -0
- scoring/functions/hemolysis.py +63 -0
- scoring/functions/nonfouling.py +66 -0
- scoring/functions/permeability.py +171 -0
- scoring/functions/solubility.py +63 -0
- scoring/scoring_functions.py +104 -0
- setup.py +9 -0
- td3b/__init__.py +30 -0
- td3b/data_utils.py +392 -0
- td3b/direction_oracle.py +709 -0
- td3b/td3b_finetune.py +604 -0
- td3b/td3b_losses.py +527 -0
- td3b/td3b_mcts.py +307 -0
- td3b/td3b_scoring.py +400 -0
- tokenizer/my_tokenizers.py +424 -0
- tokenizer/new_splits.txt +159 -0
- tokenizer/new_vocab.txt +587 -0
- utils/app.py +1287 -0
- utils/timer.py +34 -0
- utils/utils.py +135 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,128 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation
|
| 2 |
+
|
| 3 |
+
TD3B is a sequence-based generative framework that designs peptide binders with specified agonist or antagonist behavior. It combines a Direction Oracle, a soft binding-affinity gate, and amortized fine-tuning of a pre-trained discrete diffusion model (MDLM).
|
| 4 |
+
|
| 5 |
+
## Installation
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
conda env create -f env.yml
|
| 9 |
+
conda activate td3b
|
| 10 |
+
pip install -e .
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
## Data and Checkpoints
|
| 14 |
+
|
| 15 |
+
Download the pretrained checkpoints and data from [Google Drive (TBA)](placeholder_link).
|
| 16 |
+
|
| 17 |
+
Place the files as follows:
|
| 18 |
+
|
| 19 |
+
```
|
| 20 |
+
TD3B/
|
| 21 |
+
├── checkpoints/
|
| 22 |
+
│ ├── pretrained.ckpt # Pre-trained MDLM weights
|
| 23 |
+
│ ├── td3b.ckpt # Fine-tuned TD3B model
|
| 24 |
+
│ └── direction_oracle.pt # Direction Oracle weights
|
| 25 |
+
├── data/
|
| 26 |
+
│ ├── train.csv # Training set (target-binder pairs)
|
| 27 |
+
│ └── test.csv # Test set
|
| 28 |
+
├── scoring/functions/classifiers/
|
| 29 |
+
│ ├── binding-affinity.pt
|
| 30 |
+
│ ├── hemolysis-xgboost.json
|
| 31 |
+
│ ├── nonfouling-xgboost.json
|
| 32 |
+
│ ├── permeability-xgboost.json
|
| 33 |
+
│ └── solubility-xgboost.json
|
| 34 |
+
└── tokenizer/
|
| 35 |
+
├── new_vocab.txt
|
| 36 |
+
└── new_splits.txt
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Code Structure
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
TD3B/
|
| 43 |
+
├── inference.py # Generate binders (main inference entry point)
|
| 44 |
+
├── finetune_multi_target.py # Multi-target TD3B training
|
| 45 |
+
├── finetune_utils.py # Training utilities
|
| 46 |
+
├── launch_multi_target.sh # Training launcher script
|
| 47 |
+
├── diffusion.py # MDLM backbone (TR2-D2)
|
| 48 |
+
├── roformer.py # RoFormer wrapper
|
| 49 |
+
├── noise_schedule.py # Noise schedules
|
| 50 |
+
├── peptide_mcts.py # MCTS tree search
|
| 51 |
+
├── td3b/
|
| 52 |
+
│ ├── direction_oracle.py # Direction Oracle (f_φ)
|
| 53 |
+
│ ├── td3b_scoring.py # Gated reward R = g_ψ · σ(d*·(f_φ−0.5)/τ)
|
| 54 |
+
│ ├── td3b_losses.py # L_WDCE + λ·L_ctr + β·L_KL
|
| 55 |
+
│ ├── td3b_mcts.py # TD3B-extended MCTS
|
| 56 |
+
│ ├── td3b_finetune.py # Training loop
|
| 57 |
+
│ └── data_utils.py # Data loading utilities
|
| 58 |
+
├── scoring/ # Affinity predictor (g_ψ) and property classifiers
|
| 59 |
+
├── baselines/ # CG, SMC, TDS, PepTune, Unguided baselines
|
| 60 |
+
├── tokenizer/ # SMILES tokenizer (vocab + splits)
|
| 61 |
+
├── configs/ # Model and training configs
|
| 62 |
+
└── utils/ # Misc utilities
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Inference
|
| 66 |
+
|
| 67 |
+
Generate agonist/antagonist binders for target proteins:
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
python inference.py \
|
| 71 |
+
--ckpt_path checkpoints/td3b.ckpt \
|
| 72 |
+
--val_csv data/test.csv \
|
| 73 |
+
--save_path results/ \
|
| 74 |
+
--seed 42 \
|
| 75 |
+
--num_pool 32 \
|
| 76 |
+
--val_samples_per_target 8 \
|
| 77 |
+
--resample_alpha 0.1
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
This generates 32 candidates per (target, direction), scores them with the Direction Oracle and affinity predictor, applies Algorithm 2 weighted resampling, and saves only valid peptide samples.
|
| 81 |
+
|
| 82 |
+
Output: `results/td3b_results_seed42.csv` with columns: target, sequence, direction, affinity, gated_reward, direction_oracle, direction_accuracy.
|
| 83 |
+
|
| 84 |
+
## Training
|
| 85 |
+
|
| 86 |
+
### Multi-target TD3B
|
| 87 |
+
|
| 88 |
+
1. Edit `launch_multi_target.sh` — set paths to checkpoints, data, and oracle:
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
BASE_PATH="/path/to/TD3B"
|
| 92 |
+
PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
|
| 93 |
+
TRAIN_CSV="${BASE_PATH}/data/train.csv"
|
| 94 |
+
ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
2. Launch training:
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
bash launch_multi_target.sh
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
Key hyperparameters (in `launch_multi_target.sh`):
|
| 104 |
+
- `CONTRASTIVE_WEIGHT=0.1` — λ for L_ctr
|
| 105 |
+
- `KL_BETA=0.1` — β for L_KL
|
| 106 |
+
- `SIGMOID_TEMPERATURE=0.1` — τ for gated reward
|
| 107 |
+
- `NUM_ITER=20` — MCTS iterations per round
|
| 108 |
+
- `NUM_CHILDREN=16` — Children per MCTS expansion
|
| 109 |
+
|
| 110 |
+
### Baselines
|
| 111 |
+
|
| 112 |
+
Run baseline methods (CG, SMC, TDS, PepTune, Unguided):
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
cd baselines/
|
| 116 |
+
bash run.sh --baseline cg --device cuda:0
|
| 117 |
+
bash run.sh --baseline smc --device cuda:0
|
| 118 |
+
bash run.sh --baseline tds --device cuda:0
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
## Citation
|
| 122 |
+
|
| 123 |
+
```bibtex
|
| 124 |
+
@article{caotd3b,
|
| 125 |
+
title={TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation},
|
| 126 |
+
author={Cao, Hanqun and Pal, Aastha and Tang, Sophia and Zhang, Yinuo and Zhang, Jingjie and Heng, Pheng-Ann and Chatterjee, Pranam}
|
| 127 |
+
}
|
| 128 |
+
```
|
baselines/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from baselines.baselines import (
|
| 2 |
+
RewardInputs,
|
| 3 |
+
RewardWrapper,
|
| 4 |
+
classifier_guidance,
|
| 5 |
+
peptune_mctg_sampling,
|
| 6 |
+
unguided_sampling,
|
| 7 |
+
sequential_monte_carlo,
|
| 8 |
+
twisted_diffusion_sampler,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"RewardInputs",
|
| 13 |
+
"RewardWrapper",
|
| 14 |
+
"classifier_guidance",
|
| 15 |
+
"peptune_mctg_sampling",
|
| 16 |
+
"unguided_sampling",
|
| 17 |
+
"sequential_monte_carlo",
|
| 18 |
+
"twisted_diffusion_sampler",
|
| 19 |
+
]
|
baselines/baselines.py
ADDED
|
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
from typing import Callable, Dict, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Shared numerical epsilon; presumably the default terminal diffusion time for
# samplers in this module — confirm at call sites (samplers take an explicit `eps`).
DEFAULT_EPS = 1e-5
# Module-level logger, used for soft-failure warnings (e.g. reward fallbacks).
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _sample_categorical(categorical_probs: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
gumbel = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
|
| 18 |
+
return (categorical_probs / gumbel).argmax(dim=-1).to(dtype=torch.long)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _normalize_probs(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
| 22 |
+
return probs / probs.sum(dim=dim, keepdim=True).clamp_min(1e-12)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _safe_resample_weights(weights: torch.Tensor) -> torch.Tensor:
|
| 26 |
+
if weights.numel() == 0:
|
| 27 |
+
return weights
|
| 28 |
+
weights = torch.where(torch.isfinite(weights), weights, torch.zeros_like(weights))
|
| 29 |
+
total = weights.sum()
|
| 30 |
+
if not torch.isfinite(total) or total <= 0:
|
| 31 |
+
return torch.full_like(weights, 1.0 / weights.numel())
|
| 32 |
+
return weights / total
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _sequence_logprob(
|
| 36 |
+
probs: torch.Tensor,
|
| 37 |
+
x_next: torch.Tensor,
|
| 38 |
+
x_current: torch.Tensor,
|
| 39 |
+
mask_idx: int,
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
gather = probs.gather(-1, x_next.unsqueeze(-1)).squeeze(-1).clamp_min(1e-12)
|
| 42 |
+
mask = (x_current == mask_idx).to(gather.dtype)
|
| 43 |
+
return (gather.log() * mask).sum(dim=-1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _transition_probs_from_logits(
|
| 47 |
+
log_probs: torch.Tensor,
|
| 48 |
+
t: torch.Tensor,
|
| 49 |
+
dt: torch.Tensor,
|
| 50 |
+
mask_idx: int,
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
change_prob_t = t[:, None, None]
|
| 53 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 54 |
+
q_xs = log_probs.exp() * (change_prob_t - change_prob_s)
|
| 55 |
+
q_xs[:, :, mask_idx] = change_prob_s[:, :, 0]
|
| 56 |
+
return q_xs
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _sample_from_q(
|
| 60 |
+
q_probs: torch.Tensor,
|
| 61 |
+
x_current: torch.Tensor,
|
| 62 |
+
mask_idx: int,
|
| 63 |
+
) -> torch.Tensor:
|
| 64 |
+
x_changed = _sample_categorical(q_probs)
|
| 65 |
+
copy_flag = (x_current != mask_idx)
|
| 66 |
+
return torch.where(copy_flag, x_current, x_changed)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _protein_tokens_to_device(tokens: torch.Tensor, device: torch.device) -> torch.Tensor:
|
| 70 |
+
if tokens.device != device:
|
| 71 |
+
return tokens.to(device)
|
| 72 |
+
return tokens
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _tokens_to_one_hot(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
|
| 76 |
+
return F.one_hot(tokens, num_classes=vocab_size).float()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _decode_sequences(tokenizer, token_ids: torch.Tensor) -> list:
|
| 80 |
+
return tokenizer.batch_decode(token_ids)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _affinity_from_scoring(
|
| 84 |
+
scoring_fn: Callable,
|
| 85 |
+
sequences: list,
|
| 86 |
+
device: torch.device,
|
| 87 |
+
protein_seq: Optional[str] = None,
|
| 88 |
+
) -> torch.Tensor:
|
| 89 |
+
if protein_seq is not None:
|
| 90 |
+
try:
|
| 91 |
+
scores = scoring_fn(sequences, protein_seq)
|
| 92 |
+
except TypeError:
|
| 93 |
+
try:
|
| 94 |
+
scores = scoring_fn(sequences, prot_seq=protein_seq)
|
| 95 |
+
except TypeError:
|
| 96 |
+
scores = scoring_fn(sequences)
|
| 97 |
+
else:
|
| 98 |
+
scores = scoring_fn(sequences)
|
| 99 |
+
if isinstance(scores, tuple):
|
| 100 |
+
scores = scores[0]
|
| 101 |
+
scores = np.asarray(scores)
|
| 102 |
+
if scores.ndim == 1:
|
| 103 |
+
affinity = scores
|
| 104 |
+
else:
|
| 105 |
+
affinity = scores[:, 0]
|
| 106 |
+
return torch.as_tensor(affinity, device=device, dtype=torch.float32)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _roformer_hidden_from_inputs(
    base_model,
    input_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the wrapped backbone and return its last hidden-state tensor.

    Exactly one of `input_ids` / `inputs_embeds` is expected by the underlying
    model — presumably a HuggingFace-style module; confirm against `roformer.py`.
    """
    model_out = base_model.backbone.model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        output_hidden_states=True,
        return_dict=True,
    )
    return model_out.hidden_states[-1]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _logits_from_inputs(
    base_model,
    input_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the wrapped backbone and return its output logits (no hidden states)."""
    model_out = base_model.backbone.model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        output_hidden_states=False,
        return_dict=True,
    )
    return model_out.logits
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@dataclass
class RewardInputs:
    """Per-target inputs required to evaluate the direction-gated reward."""

    # Tokenized target protein, fed to the direction oracle (a single row is
    # broadcast across the peptide batch by RewardWrapper).
    protein_tokens: torch.Tensor
    # Desired direction sign; multiplies (direction - 0.5) inside the sigmoid gate.
    d_star: float
    # Raw protein sequence string passed to the affinity scoring function;
    # RewardWrapper rejects None at construction.
    protein_seq: str
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class RewardWrapper:
    """Direction-gated reward over generated peptide token sequences.

    The reward is affinity * sigmoid(d_star * (direction - 0.5) / reward_alpha),
    where `affinity` comes from `scoring_fn` and `direction` from
    `direction_oracle` (see `_gated_reward`).
    """

    def __init__(
        self,
        scoring_fn: Callable,
        direction_oracle: torch.nn.Module,
        base_model,
        tokenizer,
        reward_inputs: RewardInputs,
        device: torch.device,
        fast_direction: bool = False,
        reward_alpha: float = 0.1,
    ):
        self.scoring_fn = scoring_fn
        self.direction_oracle = direction_oracle
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.reward_inputs = reward_inputs
        self.device = device
        self.fast_direction = fast_direction
        self.reward_alpha = reward_alpha
        # The hidden-state direction path needs all three oracle submodules.
        self._supports_hidden_direction = all(
            hasattr(direction_oracle, attr)
            for attr in ("protein_embedder", "fusion", "classifier")
        )
        self._supports_predict = hasattr(direction_oracle, "predict_with_confidence")
        if self.fast_direction and not self._supports_hidden_direction:
            logger.warning("fast_direction requested but oracle lacks hidden-direction modules; disabling fast_direction.")
            self.fast_direction = False
        # Protein embedding is computed lazily on first use and then reused.
        self._protein_emb_cache = None
        if self.reward_inputs.protein_seq is None:
            raise ValueError("RewardInputs.protein_seq is required for conditioned sampling.")

    def _protein_emb(self, batch_size: int) -> torch.Tensor:
        """Return the cached protein embedding expanded to `batch_size` rows."""
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        if self._protein_emb_cache is None:
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            prot_emb = self.direction_oracle.protein_embedder(prot_tokens)
            self._protein_emb_cache = prot_emb
        # assumes the cached embedding has a leading singleton batch dim — TODO confirm
        return self._protein_emb_cache.expand(batch_size, -1)

    def _direction_from_hidden(
        self,
        hidden: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Direction score from peptide hidden states: masked mean-pool, fuse with
        the protein embedding, then classify."""
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        # Masked mean over the sequence dimension; clamp avoids divide-by-zero
        # for an all-zero attention mask.
        mask = attn_mask.to(hidden.dtype).unsqueeze(-1)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
        protein_emb = self._protein_emb(pooled.size(0))
        fused = self.direction_oracle.fusion(pooled, protein_emb)
        return self.direction_oracle.classifier(fused).squeeze(-1)

    def _direction_from_probs(
        self,
        y_probs: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Direction score from soft token probabilities.

        Dispatch order: oracle-native `predict_from_probs`; else hard argmax
        tokens when the hidden path is unavailable; else expected embeddings
        (optionally passed through the backbone) into `_direction_from_hidden`.
        """
        if hasattr(self.direction_oracle, "predict_from_probs"):
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            return self.direction_oracle.predict_from_probs(y_probs, prot_tokens, attn_mask)
        if not self._supports_hidden_direction:
            token_ids = y_probs.argmax(dim=-1)
            return self._direction_from_tokens(token_ids)
        if self.fast_direction:
            # Fast path: treat the expected word embeddings themselves as the
            # "hidden" states, skipping the transformer forward pass.
            emb_weight = self.base_model.backbone.model.roformer.embeddings.word_embeddings.weight
            inputs_embeds = y_probs @ emb_weight
            hidden = inputs_embeds
        else:
            emb_weight = self.base_model.backbone.model.roformer.embeddings.word_embeddings.weight
            inputs_embeds = y_probs @ emb_weight
            hidden = _roformer_hidden_from_inputs(
                self.base_model,
                inputs_embeds=inputs_embeds,
                attn_mask=attn_mask,
            )
        return self._direction_from_hidden(hidden, attn_mask)

    def _direction_from_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Direction score from hard token ids, preferring the oracle's
        `predict_with_confidence` API when available."""
        prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
        # Broadcast a single protein row across the peptide batch.
        if prot_tokens.dim() == 2 and prot_tokens.size(0) == 1:
            prot_tokens = prot_tokens.expand(token_ids.size(0), -1)
        if self._supports_predict:
            direction, _ = self.direction_oracle.predict_with_confidence(token_ids, prot_tokens)
            return direction
        return self.direction_oracle(token_ids, prot_tokens)

    def _gated_reward(self, affinity: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
        """Gate affinity by how strongly `direction` agrees with the desired sign d_star."""
        d_star = torch.as_tensor(self.reward_inputs.d_star, device=self.device, dtype=direction.dtype)
        directional_score = (direction - 0.5) * d_star
        gate = torch.sigmoid(directional_score / self.reward_alpha)
        return affinity * gate

    def evaluate_tokens(self, token_ids: torch.Tensor, attn_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Decode and score a batch, returning sequences plus affinity, direction,
        and gated reward tensors."""
        sequences = _decode_sequences(self.tokenizer, token_ids)
        affinity = _affinity_from_scoring(
            self.scoring_fn,
            sequences,
            self.device,
            protein_seq=self.reward_inputs.protein_seq,
        )
        with torch.no_grad():
            direction = self._direction_from_tokens(token_ids)
        gated_reward = self._gated_reward(affinity, direction)
        return {
            "sequences": sequences,
            "affinity": affinity,
            "direction": direction,
            "gated_reward": gated_reward,
        }

    def reward_from_tokens(
        self,
        token_ids: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Gated reward for hard token ids (no per-sample breakdown)."""
        sequences = _decode_sequences(self.tokenizer, token_ids)
        affinity = _affinity_from_scoring(
            self.scoring_fn,
            sequences,
            self.device,
            protein_seq=self.reward_inputs.protein_seq,
        )
        with torch.no_grad():
            direction = self._direction_from_tokens(token_ids)
        return self._gated_reward(affinity, direction)

    def reward_from_probs(
        self,
        y_probs: torch.Tensor,
        token_ids_for_affinity: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Gated reward from soft token probabilities.

        Prefers the scorer's differentiable `forward_from_probs`; on any
        exception it logs a warning and falls back to decoding
        `token_ids_for_affinity` and scoring the hard sequences.
        """
        affinity = None
        if hasattr(self.scoring_fn, "forward_from_probs"):
            try:
                affinity = self.scoring_fn.forward_from_probs(
                    y_probs,
                    attn_mask,
                    prot_seq=self.reward_inputs.protein_seq,
                )
            except Exception as exc:
                logger.warning("Differentiable affinity failed; falling back to argmax. Error: %s", exc)
                affinity = None
        if affinity is None:
            sequences = _decode_sequences(self.tokenizer, token_ids_for_affinity)
            affinity = _affinity_from_scoring(
                self.scoring_fn,
                sequences,
                self.device,
                protein_seq=self.reward_inputs.protein_seq,
            )
        direction = self._direction_from_probs(y_probs, attn_mask)
        return self._gated_reward(affinity, direction)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class PepTuneSampler:
    """PepTune-style MCTS sampler over the masked-diffusion reverse process.

    Runs tree search from a fully masked sequence, rolls candidate children out
    to completion, scores them with the direction-gated reward, maintains a
    Pareto front of (affinity, direction-score) pairs, and finally returns a
    batch of token sequences selected from that front.
    """

    def __init__(
        self,
        base_model,
        reward_fn: RewardWrapper,
        seq_length: int,
        num_steps: int,
        mcts_iterations: int,
        num_children: int,
        sample_prob_weight: float,
        invalid_penalty: float,
        pareto_max_size: Optional[int],
        eps: float,
    ):
        # Local imports keep the heavy project dependencies out of module import time.
        from peptide_mcts import Node, updateParetoFront
        from utils.app import PeptideAnalyzer

        self.base_model = base_model
        self.reward_fn = reward_fn
        self.seq_length = seq_length
        self.num_steps = num_steps
        self.mcts_iterations = mcts_iterations
        self.num_children = num_children
        self.sample_prob_weight = sample_prob_weight
        self.invalid_penalty = invalid_penalty
        self.pareto_max_size = pareto_max_size
        self.eps = eps

        self.device = base_model.device
        self.mask_idx = base_model.mask_index
        self.tokenizer = base_model.tokenizer
        self.analyzer = PeptideAnalyzer()
        self.Node = Node
        self.updateParetoFront = updateParetoFront

        # Diffusion time grid from 1 down to eps, with uniform step width dt.
        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        self.dt = torch.as_tensor((1 - eps) / num_steps, device=self.device)
        self.args = SimpleNamespace(
            num_obj=1,
            total_num_steps=num_steps,
            seq_length=seq_length,
            num_children=num_children,
        )

    def _init_root(self):
        """Create the root MCTS node: a fully masked sequence at timestep 0."""
        masked_seq = torch.full((self.seq_length,), self.mask_idx, device=self.device, dtype=torch.long)
        attn_mask = torch.ones_like(masked_seq, device=self.device)
        tokens = {"seqs": masked_seq, "attention_mask": attn_mask}
        return self.Node(
            args=self.args,
            tokens=tokens,
            log_rnd=torch.zeros((), device=self.device),
            log_policy_step=torch.zeros((), device=self.device),
            log_pretrained_step=torch.zeros((), device=self.device),
            totalReward=np.zeros(self.args.num_obj),
            timestep=0,
        )

    def _select(self, root):
        """Walk down the tree via Node.selectNode until a non-continue status.

        NOTE(review): status 3 appears to mean "keep descending"; other codes
        terminate selection — confirm against peptide_mcts.Node.
        """
        node = root
        while True:
            node, status = node.selectNode()
            if status != 3:
                return node, status

    def _update_pareto(self, pareto_front, pareto_tokens, seq, token_ids, score_vector):
        """Insert (seq, score_vector) into the Pareto front and keep the token
        cache consistent with the surviving front entries."""
        pareto_front = self.updateParetoFront(
            pareto_front,
            seq,
            score_vector,
            totalSize=self.pareto_max_size,
        )
        # Drop cached tokens for sequences evicted from the front.
        pareto_tokens = {k: pareto_tokens[k] for k in pareto_front if k in pareto_tokens}
        if seq in pareto_front:
            pareto_tokens[seq] = token_ids.detach().clone()
        return pareto_front, pareto_tokens

    def _expand(self, parent, pareto_front, pareto_tokens):
        """Expand `parent` with num_children one-step samples, roll each out to
        t=eps, score the completed sequences, and back-propagate the mean reward."""
        parent_tokens = parent.tokens["seqs"].to(self.device)
        attn_mask = parent.tokens["attention_mask"].to(self.device)
        t = self.timesteps[parent.timestep] * torch.ones(1, 1, device=self.device)

        # One reverse step producing num_children candidate children.
        with torch.no_grad():
            _, x_children, log_policy_step, log_pretrained_step = self.base_model.batch_mcts_reverse_step(
                token_array=parent_tokens,
                t=t,
                dt=self.dt,
                batch_size=self.num_children,
                pretrained=self.base_model,
            )

        # Accumulated log Radon-Nikodym term (pretrained vs. policy) per child.
        child_log_rnd = parent.log_rnd + (log_pretrained_step - log_policy_step)
        log_policy_step = log_policy_step * self.sample_prob_weight

        # Roll each child out through the remaining timesteps.
        x_rollout = x_children
        t_step = self.timesteps[parent.timestep] * torch.ones(self.num_children, 1, device=self.device)
        for i in range(1, self.num_steps - parent.timestep):
            t_step = self.timesteps[parent.timestep + i] * torch.ones(self.num_children, 1, device=self.device)
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_reverse_step(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        # Force-unmask any leftover mask tokens before decoding.
        if (x_rollout == self.mask_idx).any().item():
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_noise_removal(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        sequences = self.tokenizer.batch_decode(x_rollout)
        valid_mask = [self.analyzer.is_peptide(seq) for seq in sequences]

        # Invalid children get a fixed negative penalty instead of a score.
        reward_values = np.full(self.num_children, -float(self.invalid_penalty), dtype=np.float32)
        if any(valid_mask):
            valid_tokens = x_rollout[valid_mask]
            valid_sequences = [seq for seq, keep in zip(sequences, valid_mask) if keep]
            affinity = _affinity_from_scoring(
                self.reward_fn.scoring_fn,
                valid_sequences,
                self.device,
                protein_seq=self.reward_fn.reward_inputs.protein_seq,
            )
            with torch.no_grad():
                direction = self.reward_fn._direction_from_tokens(valid_tokens)
            gated_reward = self.reward_fn._gated_reward(affinity, direction)
            d_star = self.reward_fn.reward_inputs.d_star
            dir_score = (direction - 0.5) * d_star

            # The Pareto front tracks raw (affinity, directional score) pairs,
            # not the gated scalar reward.
            for idx, seq in enumerate(valid_sequences):
                score_vector = np.array(
                    [float(affinity[idx].item()), float(dir_score[idx].item())],
                    dtype=np.float32,
                )
                pareto_front, pareto_tokens = self._update_pareto(
                    pareto_front,
                    pareto_tokens,
                    seq,
                    valid_tokens[idx],
                    score_vector,
                )

            reward_values[np.array(valid_mask)] = gated_reward.detach().cpu().numpy()

        # Attach children to the tree with their one-step (not rollout) tokens.
        reward_vectors = []
        for i in range(self.num_children):
            child_tokens = {"seqs": x_children[i].to(dtype=torch.long), "attention_mask": attn_mask}
            reward_vec = np.array([float(reward_values[i])], dtype=np.float32)
            parent.addChildNode(
                tokens=child_tokens,
                log_rnd=child_log_rnd[i],
                log_policy_step=log_policy_step[i],
                log_pretrained_step=log_pretrained_step[i],
                totalReward=reward_vec,
            )
            reward_vectors.append(reward_vec)

        # Back-propagate the mean child reward up to the root.
        avg_reward = np.mean(np.stack(reward_vectors, axis=0), axis=0)
        node = parent
        while node:
            node.updateNode(avg_reward)
            node = node.parentNode

        return pareto_front, pareto_tokens

    def _select_from_pareto(self, pareto_front, pareto_tokens, batch_size):
        """Pick `batch_size` token sequences ranked by recomputed gated reward;
        sample with replacement if the front is smaller than the batch."""
        if not pareto_front:
            # No valid candidate found: fall back to sampling from the prior.
            return self.base_model.sample_prior(batch_size, self.seq_length).to(self.device)

        seqs = list(pareto_front.keys())
        scores = np.stack([pareto_front[seq] for seq in seqs], axis=0)
        affinity = scores[:, 0]
        dir_score = scores[:, 1]
        # NumPy re-implementation of RewardWrapper._gated_reward on stored scores.
        gate = 1.0 / (1.0 + np.exp(-dir_score / max(self.reward_fn.reward_alpha, 1e-6)))
        gated = affinity * gate
        order = np.argsort(-gated)

        if len(order) >= batch_size:
            selected = [seqs[i] for i in order[:batch_size]]
        else:
            repeats = np.random.choice(order, size=batch_size, replace=True)
            selected = [seqs[i] for i in repeats]

        tokens = [pareto_tokens[seq] for seq in selected]
        return torch.stack(tokens, dim=0).to(self.device)

    def sample(self, batch_size):
        """Run the full MCTS loop and return a (batch_size, seq_length) token batch."""
        self.base_model.eval()
        root = self._init_root()
        pareto_front = {}
        pareto_tokens = {}

        for _ in range(self.mcts_iterations):
            leaf, status = self._select(root)
            if status == 1:
                # NOTE(review): status 1 presumably marks a terminal/unexpandable
                # leaf — confirm against peptide_mcts.Node.selectNode.
                continue
            pareto_front, pareto_tokens = self._expand(leaf, pareto_front, pareto_tokens)

        return self._select_from_pareto(pareto_front, pareto_tokens, batch_size)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _logits_and_probs_from_tokens(
    base_model,
    token_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Forward pass from hard token ids, returning SUBS-parameterized log-probs."""
    raw_logits = _logits_from_inputs(base_model, input_ids=token_ids, attn_mask=attn_mask)
    return base_model.subs_parameterization(raw_logits, token_ids)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def _logits_and_probs_from_one_hot(
    base_model,
    y_one_hot: torch.Tensor,
    token_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
    """Differentiable forward pass from a (possibly relaxed) one-hot batch.

    Embeds via matrix product with the embedding table so gradients can flow
    back into ``y_one_hot``; returns SUBS-parameterized log-probs.
    """
    embedding_table = base_model.backbone.model.roformer.embeddings.word_embeddings.weight
    relaxed_embeds = y_one_hot @ embedding_table
    raw_logits = _logits_from_inputs(base_model, inputs_embeds=relaxed_embeds, attn_mask=attn_mask)
    return base_model.subs_parameterization(raw_logits, token_ids)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def classifier_guidance(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Reverse diffusion with gradient-based classifier guidance.

    At each (guided) step the transition probabilities are tilted by
    ``exp(guidance_scale * d(reward)/d(one_hot))``, baselined against the
    mask-token gradient and clamped for numerical stability. If the reward is
    not differentiable, guidance is disabled for the rest of the run.

    Args:
        base_model: masked diffusion model (provides prior, mask index, vocab).
        reward_fn: differentiable reward over relaxed token probabilities.
        guidance_steps: if set, only the last ``guidance_steps`` steps are guided.

    Returns:
        {"tokens": final token tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)
        if not use_guidance:
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue

        y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
        y_one_hot.requires_grad_(True)
        token_ids = x.detach()
        log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
        y_probs = log_probs.exp()
        token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
        reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
        if not reward.requires_grad:
            if guidance_enabled:
                logger.warning(
                    "Reward does not require grad; disabling gradient guidance for classifier_guidance."
                )
                guidance_enabled = False
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue
        # BUGFIX: compute the gradient w.r.t. the one-hot input directly via
        # torch.autograd.grad instead of reward.sum().backward(); backward()
        # also accumulated gradients into the (frozen) base-model parameters
        # on every guided step — growing memory and leaving stale .grad
        # buffers — while guidance only needs d(reward)/d(y_one_hot).
        (grad,) = torch.autograd.grad(reward.sum(), y_one_hot)
        q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
        # Baseline against the mask-token gradient; clamp before exp to keep
        # the tilt numerically bounded.
        guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
        guidance = guidance.clamp(min=-50.0, max=50.0)
        q_guided = _normalize_probs(q_base * torch.exp(guidance))
        x = _sample_from_q(q_guided, x, mask_idx)

    return {"tokens": x}
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def unguided_sampling(
    base_model,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """Plain reverse diffusion from the prior with no reward guidance.

    Returns:
        {"tokens": final token tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    # BUGFIX/perf: this loop is pure inference, but the original ran the
    # denoiser with autograd enabled, building a graph over the model
    # parameters at every step. Disable gradient tracking for the whole loop.
    with torch.no_grad():
        for step in range(num_steps):
            t = timesteps[step].repeat(batch_size)
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)

    return {"tokens": x}
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def sequential_monte_carlo(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    alpha: float,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """SMC baseline: propose with the base reverse process, resample by reward.

    After each diffusion step, particles are resampled with replacement with
    weights proportional to ``exp((r_next - r_current) / alpha)``; ``alpha``
    acts as a temperature.

    Returns:
        {"tokens": final token tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    # BUGFIX/perf: the original only guarded the reward calls with no_grad,
    # so each denoiser forward still built an autograd graph over the frozen
    # model parameters. The whole loop is inference; disable grad globally.
    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
        for step in range(num_steps):
            t = timesteps[step].repeat(batch_size)
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x_next = _sample_from_q(q_base, x, mask_idx)

            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()
            # Importance weights; clamp after exp to avoid inf, then sanitize.
            weights = torch.exp((r_next - r_current) / alpha).clamp_max(1e6)
            weights = _safe_resample_weights(weights)
            indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
            x = x_next[indices]
            r_current = r_next[indices]

    return {"tokens": x}
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def twisted_diffusion_sampler(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Twisted SMC: gradient-guided proposals with importance reweighting.

    Proposals are tilted by the reward gradient (as in classifier guidance);
    the resampling weights correct for the proposal/base mismatch via the
    ``logp_base - logp_guided`` term in addition to the reward increment.
    Falls back to unguided proposals if the reward is not differentiable.

    Returns:
        {"tokens": final token tensor of shape (batch_size, seq_length)}.
    """
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)

        if use_guidance:
            y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
            y_one_hot.requires_grad_(True)
            token_ids = x.detach()
            log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
            y_probs = log_probs.exp()
            token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
            reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            if not reward.requires_grad:
                if guidance_enabled:
                    logger.warning(
                        "Reward does not require grad; disabling gradient guidance for twisted_diffusion_sampler."
                    )
                    guidance_enabled = False
                q_guided = q_base
            else:
                # BUGFIX: targeted torch.autograd.grad instead of
                # reward.sum().backward(); backward() also accumulated
                # gradients into the frozen base-model parameters every step
                # (memory growth, stale .grad buffers), while only
                # d(reward)/d(y_one_hot) is needed for the tilt.
                (grad,) = torch.autograd.grad(reward.sum(), y_one_hot)
                guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
                guidance = guidance.clamp(min=-50.0, max=50.0)
                q_guided = _normalize_probs(q_base * torch.exp(guidance))
        else:
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            q_guided = q_base

        x_next = _sample_from_q(q_guided, x, mask_idx)
        with torch.no_grad():
            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()

        # Twisted-SMC weights: reward increment plus proposal correction.
        logp_guided = _sequence_logprob(q_guided, x_next, x, mask_idx)
        logp_base = _sequence_logprob(q_base, x_next, x, mask_idx)
        weights = torch.exp((r_next - r_current) / alpha + (logp_base - logp_guided)).clamp_max(1e6)
        weights = _safe_resample_weights(weights)
        indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
        x = x_next[indices]
        r_current = r_next[indices]

    return {"tokens": x}
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def peptune_mctg_sampling(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    mcts_iterations: int,
    num_children: int,
    alpha: float,
    sample_prob_weight: float,
    invalid_penalty: float = 1.0,
    pareto_max_size: Optional[int] = None,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
    """Convenience wrapper: build a PepTune MCTS sampler and draw one batch.

    NOTE(review): ``alpha`` is accepted here but not forwarded to
    PepTuneSampler — confirm whether that is intentional.
    """
    mcts_sampler = PepTuneSampler(
        base_model=base_model,
        reward_fn=reward_fn,
        seq_length=seq_length,
        num_steps=num_steps,
        mcts_iterations=mcts_iterations,
        num_children=num_children,
        sample_prob_weight=sample_prob_weight,
        invalid_penalty=invalid_penalty,
        pareto_max_size=pareto_max_size,
        eps=eps,
    )
    sampled_tokens = mcts_sampler.sample(batch_size=batch_size)
    return {"tokens": sampled_tokens}
|
baselines/run.sh
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch baseline sampling (single- or multi-GPU) and merge per-rank outputs.
#
# Positional args (all optional, with defaults):
#   $1 CSV_PATH     targets CSV
#   $2 BASELINE     sampler name (unguided | peptune | tds | smc | cg)
#   $3 DEVICE       torch device string for the single-GPU path
#   $4 OUTPUT_DIR   where per-rank CSVs and merged CSVs are written
#   $5 NGPUS        >1 switches to a torch.distributed multi-GPU launch
#   $6 MASTER_PORT  rendezvous port for torch.distributed
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"

CSV_PATH="${1:-To Be Added}"
BASELINE="${2:-unguided}"
DEVICE="${3:-cuda:4}"
OUTPUT_DIR="${4:-${SCRIPT_DIR}/outputs}"
NGPUS="${5:-1}"
MASTER_PORT="${6:-29500}"

if [ "$NGPUS" -gt 1 ]; then
  echo "Running multi-GPU inference with $NGPUS GPUs (master port: $MASTER_PORT)"
  # Multi-GPU path ignores $DEVICE: a bare "cuda" is passed so each rank
  # binds its own device.
  LAUNCH_DEVICE="cuda"
  python -m torch.distributed.run \
    --nproc_per_node="$NGPUS" \
    --master_port="$MASTER_PORT" \
    "${SCRIPT_DIR}/sampling_setup.py" \
    --ckpt_path "${ROOT_DIR}/pretrained/peptune-pretrained.ckpt" \
    --device "${LAUNCH_DEVICE}" \
    --baseline "${BASELINE}" \
    --targets_csv "${CSV_PATH}" \
    --batch_size 8 \
    --num_steps 128 \
    --num_batches 1 \
    --output_dir "${OUTPUT_DIR}"

  # Merge per-rank CSV shards into single files; the heredoc Python reads
  # the two exported variables from the environment.
  export OUTPUT_DIR BASELINE
  python - <<'PY'
import glob
import os
import pandas as pd

out_dir = os.environ["OUTPUT_DIR"]
baseline = os.environ["BASELINE"]

def merge(pattern, output_name):
    files = sorted(glob.glob(os.path.join(out_dir, pattern)))
    if not files:
        return
    dfs = []
    for path in files:
        try:
            dfs.append(pd.read_csv(path))
        except Exception as exc:
            print(f"[merge] skip {path}: {exc}")
    if not dfs:
        return
    merged = pd.concat(dfs, ignore_index=True)
    merged.to_csv(os.path.join(out_dir, output_name), index=False)
    print(f"[merge] wrote {output_name} from {len(files)} shards")

merge(f"{baseline}_samples_rank*.csv", f"{baseline}_samples.csv")
merge("batch_times_rank*.csv", "batch_times.csv")
merge(f"{baseline}_metrics_rank*.csv", f"{baseline}_metrics.csv")
PY
  exit 0
fi

# Single-GPU / CPU path.
python "${SCRIPT_DIR}/sampling_setup.py" \
  --ckpt_path "${ROOT_DIR}/pretrained/peptune-pretrained.ckpt" \
  --device "${DEVICE}" \
  --baseline "${BASELINE}" \
  --targets_csv "${CSV_PATH}" \
  --batch_size 8 \
  --num_steps 128 \
  --num_batches 1 \
  --output_dir "${OUTPUT_DIR}"

# Example invocations:
# ./run.sh To Be Added peptune cuda:0 To Be Added
# ./run.sh To Be Added peptune cuda To Be Added 4 29501
# ./run.sh To Be Added tds cuda:1 To Be Added
# ./run.sh To Be Added smc cuda:2 To Be Added
# ./run.sh To Be Added cg cuda:3 To Be Added
|
baselines/run_mcts_tr2d2.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
from typing import Any, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
|
| 13 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 14 |
+
if ROOT_DIR not in sys.path:
|
| 15 |
+
sys.path.insert(0, ROOT_DIR)
|
| 16 |
+
|
| 17 |
+
from diffusion import Diffusion
|
| 18 |
+
from configs.finetune_config import (
|
| 19 |
+
DiffusionConfig,
|
| 20 |
+
RoFormerConfig,
|
| 21 |
+
NoiseConfig,
|
| 22 |
+
TrainingConfig,
|
| 23 |
+
SamplingConfig,
|
| 24 |
+
EvalConfig,
|
| 25 |
+
OptimConfig,
|
| 26 |
+
MCTSConfig,
|
| 27 |
+
)
|
| 28 |
+
from finetune_utils import load_tokenizer
|
| 29 |
+
from finetune_distributed_utils import setup_distributed, cleanup_distributed, is_main_process
|
| 30 |
+
from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
|
| 31 |
+
from td3b.direction_oracle import DirectionalOracle
|
| 32 |
+
from finetune_multi_target_tr2d2_ddp import TR2D2GatedReward, TargetDataset, create_tr2d2_mcts
|
| 33 |
+
from utils.app import PeptideAnalyzer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _load_checkpoint(ckpt_path: str, device: torch.device) -> Dict[str, Any]:
|
| 37 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 38 |
+
if not isinstance(ckpt, dict):
|
| 39 |
+
raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
|
| 40 |
+
return ckpt
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _extract_state_and_config(ckpt: Dict[str, Any]) -> Dict[str, Any]:
|
| 44 |
+
state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
|
| 45 |
+
config = ckpt.get("config") or {}
|
| 46 |
+
return {"state_dict": state_dict, "config": config}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _build_args(cfg: Dict[str, Any], cli: argparse.Namespace) -> argparse.Namespace:
|
| 50 |
+
defaults = {
|
| 51 |
+
"base_path": "To Be Added",
|
| 52 |
+
"seq_length": 200,
|
| 53 |
+
"sampling_eps": 1e-3,
|
| 54 |
+
"total_num_steps": 128,
|
| 55 |
+
"alpha": 0.1,
|
| 56 |
+
"hidden_dim": 768,
|
| 57 |
+
"num_layers": 8,
|
| 58 |
+
"num_heads": 8,
|
| 59 |
+
"min_affinity_threshold": 0.0,
|
| 60 |
+
"sigmoid_temperature": 0.1,
|
| 61 |
+
"val_samples_per_target": 8,
|
| 62 |
+
"direction_oracle_esm_name": "facebook/esm2_t33_650M_UR50D",
|
| 63 |
+
"direction_oracle_esm_cache_dir": None,
|
| 64 |
+
"direction_oracle_esm_local_files_only": False,
|
| 65 |
+
"direction_oracle_max_ligand_length": 768,
|
| 66 |
+
"direction_oracle_max_protein_length": 1024,
|
| 67 |
+
"direction_oracle_d_model": 256,
|
| 68 |
+
"direction_oracle_n_heads": 4,
|
| 69 |
+
"direction_oracle_n_self_attn_layers": 1,
|
| 70 |
+
"direction_oracle_n_bmca_layers": 2,
|
| 71 |
+
"direction_oracle_dropout": 0.3,
|
| 72 |
+
"num_iter": 20,
|
| 73 |
+
"num_children": 24,
|
| 74 |
+
"buffer_size": 32,
|
| 75 |
+
"exploration": 1.0,
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
merged = dict(defaults)
|
| 79 |
+
merged.update(cfg or {})
|
| 80 |
+
|
| 81 |
+
if cli.base_path is not None:
|
| 82 |
+
merged["base_path"] = cli.base_path
|
| 83 |
+
if cli.val_csv is not None:
|
| 84 |
+
merged["val_csv"] = cli.val_csv
|
| 85 |
+
if cli.save_path is not None:
|
| 86 |
+
merged["save_path"] = cli.save_path
|
| 87 |
+
if cli.device is not None:
|
| 88 |
+
merged["device"] = cli.device
|
| 89 |
+
if cli.val_samples_per_target is not None:
|
| 90 |
+
merged["val_samples_per_target"] = cli.val_samples_per_target
|
| 91 |
+
if cli.seq_length is not None:
|
| 92 |
+
merged["seq_length"] = cli.seq_length
|
| 93 |
+
if cli.total_num_steps is not None:
|
| 94 |
+
merged["total_num_steps"] = cli.total_num_steps
|
| 95 |
+
if cli.sampling_eps is not None:
|
| 96 |
+
merged["sampling_eps"] = cli.sampling_eps
|
| 97 |
+
if cli.alpha is not None:
|
| 98 |
+
merged["alpha"] = cli.alpha
|
| 99 |
+
if cli.num_iter is not None:
|
| 100 |
+
merged["num_iter"] = cli.num_iter
|
| 101 |
+
if cli.num_children is not None:
|
| 102 |
+
merged["num_children"] = cli.num_children
|
| 103 |
+
if cli.buffer_size is not None:
|
| 104 |
+
merged["buffer_size"] = cli.buffer_size
|
| 105 |
+
if cli.exploration is not None:
|
| 106 |
+
merged["exploration"] = cli.exploration
|
| 107 |
+
if cli.max_sequence_length is not None:
|
| 108 |
+
merged["max_sequence_length"] = cli.max_sequence_length
|
| 109 |
+
|
| 110 |
+
args = SimpleNamespace(**merged)
|
| 111 |
+
|
| 112 |
+
base_tr2d2_path = os.path.join(args.base_path, "tr2d2-pep")
|
| 113 |
+
if not getattr(args, "direction_oracle_ckpt", None):
|
| 114 |
+
args.direction_oracle_ckpt = os.path.join(base_tr2d2_path, "direction_oracle.pt")
|
| 115 |
+
if not getattr(args, "direction_oracle_tr2d2_checkpoint", None):
|
| 116 |
+
args.direction_oracle_tr2d2_checkpoint = os.path.join(
|
| 117 |
+
base_tr2d2_path, "pretrained", "peptune-pretrained.ckpt"
|
| 118 |
+
)
|
| 119 |
+
if not getattr(args, "direction_oracle_tokenizer_vocab", None):
|
| 120 |
+
args.direction_oracle_tokenizer_vocab = os.path.join(
|
| 121 |
+
base_tr2d2_path, "tokenizer", "new_vocab.txt"
|
| 122 |
+
)
|
| 123 |
+
if not getattr(args, "direction_oracle_tokenizer_splits", None):
|
| 124 |
+
args.direction_oracle_tokenizer_splits = os.path.join(
|
| 125 |
+
base_tr2d2_path, "tokenizer", "new_splits.txt"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if not getattr(args, "save_path", None):
|
| 129 |
+
args.save_path = os.path.join(base_tr2d2_path, "baselines", "outputs_mcts_tr2d2")
|
| 130 |
+
os.makedirs(args.save_path, exist_ok=True)
|
| 131 |
+
return args
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _build_model(args: argparse.Namespace, state_dict: Dict[str, Any], device: torch.device) -> Diffusion:
    """Assemble the diffusion model from resolved args, load weights
    non-strictly, and put it in eval mode."""
    diffusion_cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads,
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps,
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=getattr(args, "learning_rate", 3e-4)),
        mcts=MCTSConfig(),
    )

    model = Diffusion(
        config=diffusion_cfg,
        tokenizer=load_tokenizer(args.base_path),
        device=device,
    ).to(device)

    # Non-strict load: report (but tolerate) any key mismatches.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"[load] Missing keys: {len(result.missing_keys)}")
    if result.unexpected_keys:
        print(f"[load] Unexpected keys: {len(result.unexpected_keys)}")
    model.eval()
    return model
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _build_oracle(args: argparse.Namespace, device: torch.device) -> DirectionalOracle:
    """Instantiate the directional oracle from resolved args and set eval mode."""
    oracle_kwargs = dict(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only,
    )
    oracle = DirectionalOracle(**oracle_kwargs)
    oracle.eval()
    return oracle
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _compute_direction_accuracy(directions: np.ndarray, d_star: float) -> np.ndarray:
|
| 190 |
+
if directions.size == 0:
|
| 191 |
+
return directions
|
| 192 |
+
acc = np.full(directions.shape, np.nan, dtype=np.float32)
|
| 193 |
+
valid = np.isfinite(directions)
|
| 194 |
+
if not valid.any():
|
| 195 |
+
return acc
|
| 196 |
+
if d_star > 0:
|
| 197 |
+
acc[valid] = (directions[valid] >= 0.5).astype(np.float32)
|
| 198 |
+
else:
|
| 199 |
+
acc[valid] = (directions[valid] < 0.5).astype(np.float32)
|
| 200 |
+
return acc
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _nanmean(values: np.ndarray) -> float:
|
| 204 |
+
if values.size == 0:
|
| 205 |
+
return 0.0
|
| 206 |
+
finite = values[np.isfinite(values)]
|
| 207 |
+
return float(np.mean(finite)) if finite.size else 0.0
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _nanstd(values: np.ndarray) -> float:
|
| 211 |
+
if values.size == 0:
|
| 212 |
+
return 0.0
|
| 213 |
+
finite = values[np.isfinite(values)]
|
| 214 |
+
return float(np.std(finite)) if finite.size else 0.0
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def main() -> None:
    """Run MCTS-based validation of a finetuned TR2-D2 checkpoint.

    Loads the policy model and directional oracle from a checkpoint, runs MCTS
    sampling for each validation target in both agonist and antagonist
    directions, scores the collected sequences with the gated reward, and
    writes a per-sequence CSV plus a printed summary. Supports multi-GPU
    execution via torchrun (LOCAL_RANK / WORLD_SIZE environment variables).
    """
    parser = argparse.ArgumentParser(description="MCTS-based TR2-D2 evaluation.")
    parser.add_argument("--ckpt_path", required=True, help="Path to finetuned checkpoint (.ckpt)")
    parser.add_argument("--val_csv", required=True, help="Validation CSV path")
    parser.add_argument("--device", default="cuda", help="Device string (e.g., cuda:0 or cpu)")
    parser.add_argument("--base_path", default=None, help="Base path for TR2-D2")
    parser.add_argument("--save_path", default=None, help="Output directory for evaluation CSV")
    parser.add_argument("--epoch", type=int, default=0, help="Epoch number to label outputs")
    parser.add_argument("--val_samples_per_target", type=int, default=None, help="Samples per target (unused by MCTS)")
    parser.add_argument("--seq_length", type=int, default=None, help="Fallback sequence length")
    parser.add_argument("--total_num_steps", type=int, default=None, help="Diffusion steps")
    parser.add_argument("--sampling_eps", type=float, default=None, help="Sampling epsilon")
    parser.add_argument("--alpha", type=float, default=None, help="MCTS alpha temperature")
    parser.add_argument("--num_iter", type=int, default=None, help="MCTS iterations")
    parser.add_argument("--num_children", type=int, default=None, help="MCTS children per expand")
    parser.add_argument("--buffer_size", type=int, default=None, help="MCTS buffer size")
    parser.add_argument("--exploration", type=float, default=None, help="MCTS exploration constant")
    parser.add_argument("--max_sequence_length", type=int, default=1035)
    parser.add_argument("--max_attempts", type=int, default=3, help="Max MCTS attempts to reach target count")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    cli_args = parser.parse_args()

    # Distributed setup: under torchrun each rank processes a strided slice of targets.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        setup_distributed(rank, world_size)
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(cli_args.device)

    # Rank-offset seeding keeps runs deterministic while decorrelating ranks.
    if cli_args.seed is not None:
        torch.manual_seed(cli_args.seed + rank)
        np.random.seed(cli_args.seed + rank)

    # Checkpoint config is merged with CLI overrides into one args namespace.
    ckpt = _load_checkpoint(cli_args.ckpt_path, device)
    payload = _extract_state_and_config(ckpt)
    args = _build_args(payload["config"], cli_args)

    tokenizer = load_tokenizer(args.base_path)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer)

    policy_model = _build_model(args, payload["state_dict"], device)

    # Affinity predictor reuses the policy backbone for sequence embeddings.
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone,
    )

    directional_oracle = _build_oracle(args, device)
    analyzer = PeptideAnalyzer()

    val_targets = val_dataset.get_all_targets()
    if world_size > 1:
        # Round-robin assignment of targets across ranks.
        my_targets = val_targets[rank::world_size]
    else:
        my_targets = val_targets

    records: List[Dict[str, Any]] = []
    # Protein encodings are cached so each target is embedded only once.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    with torch.no_grad():
        for target_seq in my_targets:
            target_tokens = protein_token_cache.get(target_seq)
            if target_tokens is None:
                target_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_tokens

            # Each target is evaluated in both directions: agonist (d*=+1)
            # and antagonist (d*=-1).
            for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)
                if target_length > args.max_sequence_length:
                    target_length = args.max_sequence_length

                # seq_length is temporarily overridden per target/direction —
                # presumably read off `args` by create_tr2d2_mcts — and
                # restored after the attempts loop.
                original_seq_length = args.seq_length
                args.seq_length = int(target_length)

                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)
                reward_model = TR2D2GatedReward(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    temperature=args.sigmoid_temperature,
                )

                mcts = create_tr2d2_mcts(
                    args=args,
                    policy_model=policy_model,
                    reward_function=reward_model,
                    buffer_size=args.buffer_size,
                )

                target_count = int(args.val_samples_per_target)
                collected_sequences: List[str] = []
                attempt_valid_fractions: List[float] = []

                # Retry MCTS rollouts until enough sequences are collected or
                # attempts are exhausted; failures degrade to an empty batch.
                for attempt in range(max(cli_args.max_attempts, 1)):
                    try:
                        # NOTE(review): assumes mcts.forward returns sequences
                        # as its fifth element — confirm against create_tr2d2_mcts.
                        _, _, _, _, sequences = mcts.forward(resetTree=True)
                    except Exception as exc:
                        print(f"[mcts] failed for target={target_seq[:12]} dir={direction_name}: {exc}")
                        sequences = []

                    attempt_valid = float(np.mean(mcts.valid_fraction_log)) if getattr(mcts, "valid_fraction_log", None) else 0.0
                    attempt_valid_fractions.append(attempt_valid)

                    if sequences:
                        collected_sequences.extend(sequences)

                    if len(collected_sequences) >= target_count:
                        break

                args.seq_length = original_seq_length

                valid_fraction = _nanmean(np.asarray(attempt_valid_fractions, dtype=np.float32))

                # Emit a sentinel row so targets with zero sequences still
                # appear in the output CSV.
                if not collected_sequences:
                    records.append(
                        {
                            "target": target_seq[:20],
                            "sequence": "",
                            "target_direction": d_star,
                            "is_valid": False,
                            "valid_fraction": valid_fraction,
                            "affinity": np.nan,
                            "gated_reward": np.nan,
                            "direction_oracle": np.nan,
                            "consistency_reward": np.nan,
                            "direction_accuracy": np.nan,
                            "success_rate": np.nan,
                        }
                    )
                    continue

                if len(collected_sequences) > target_count:
                    collected_sequences = collected_sequences[:target_count]

                gated_rewards, affinities, confidences, directions = reward_model.reward_fn.compute_gated_reward(collected_sequences)
                direction_accuracy = _compute_direction_accuracy(directions, d_star)
                # Signed agreement with the target direction in [-0.5, 0.5].
                consistency = d_star * (directions - 0.5)
                success_rate = direction_accuracy * valid_fraction

                valid_mask = np.array([analyzer.is_peptide(seq) for seq in collected_sequences], dtype=bool)

                for idx, seq in enumerate(collected_sequences):
                    records.append(
                        {
                            "target": target_seq[:20],
                            "sequence": seq,
                            "target_direction": d_star,
                            "is_valid": bool(valid_mask[idx]) if valid_mask.size else False,
                            "valid_fraction": valid_fraction,
                            "affinity": float(affinities[idx]) if len(affinities) else np.nan,
                            "gated_reward": float(gated_rewards[idx]) if len(gated_rewards) else np.nan,
                            "direction_oracle": float(directions[idx]) if len(directions) else np.nan,
                            "consistency_reward": float(consistency[idx]) if len(consistency) else np.nan,
                            "direction_accuracy": float(direction_accuracy[idx]) if len(direction_accuracy) else np.nan,
                            "success_rate": float(success_rate[idx]) if len(success_rate) else np.nan,
                        }
                    )

    # Gather per-rank records on rank 0; non-main ranks exit here.
    if world_size > 1:
        gathered: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, records)
        if is_main_process():
            records = [item for sub in gathered for item in sub]
        else:
            cleanup_distributed()
            return

    if is_main_process():
        df = pd.DataFrame(records)
        output_path = os.path.join(args.save_path, f"mcts_validation_epoch_{cli_args.epoch}.csv")
        df.to_csv(output_path, index=False)
        print(f"MCTS validation sequences saved to {output_path}")

        affinities = df["affinity"].to_numpy(dtype=np.float32)
        gated_rewards = df["gated_reward"].to_numpy(dtype=np.float32)
        directions = df["direction_oracle"].to_numpy(dtype=np.float32)
        target_directions = df["target_direction"].to_numpy(dtype=np.float32)
        direction_correct = df["direction_accuracy"].to_numpy(dtype=np.float32)
        valid_fractions = df["valid_fraction"].to_numpy(dtype=np.float32)

        # Split the summary by target direction.
        pos_mask = target_directions == 1.0
        neg_mask = target_directions == -1.0

        print("MCTS validation summary")
        print(f" Affinity (d*=1): {_nanmean(affinities[pos_mask]):.4f} ± {_nanstd(affinities[pos_mask]):.4f}")
        print(f" Affinity (d*=-1): {_nanmean(affinities[neg_mask]):.4f} ± {_nanstd(affinities[neg_mask]):.4f}")
        print(f" Direction Accuracy (d*=1): {_nanmean(direction_correct[pos_mask]):.4f} ± {_nanstd(direction_correct[pos_mask]):.4f}")
        print(f" Direction Accuracy (d*=-1): {_nanmean(direction_correct[neg_mask]):.4f} ± {_nanstd(direction_correct[neg_mask]):.4f}")
        print(f" Gated Reward (overall): {_nanmean(gated_rewards):.4f} ± {_nanstd(gated_rewards):.4f}")
        print(f" Valid Fraction: {_nanmean(valid_fractions):.4f} ± {_nanstd(valid_fractions):.4f}")

    if world_size > 1:
        cleanup_distributed()


if __name__ == "__main__":
    main()
|
baselines/run_validation_td3b.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
from typing import Any, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
|
| 13 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 14 |
+
if ROOT_DIR not in sys.path:
|
| 15 |
+
sys.path.insert(0, ROOT_DIR)
|
| 16 |
+
|
| 17 |
+
from diffusion import Diffusion
|
| 18 |
+
from configs.finetune_config import (
|
| 19 |
+
DiffusionConfig,
|
| 20 |
+
RoFormerConfig,
|
| 21 |
+
NoiseConfig,
|
| 22 |
+
TrainingConfig,
|
| 23 |
+
SamplingConfig,
|
| 24 |
+
EvalConfig,
|
| 25 |
+
OptimConfig,
|
| 26 |
+
MCTSConfig,
|
| 27 |
+
)
|
| 28 |
+
from finetune_utils import load_tokenizer, create_reward_function
|
| 29 |
+
from finetune_multi_target import TargetDataset
|
| 30 |
+
from distributed_utils import setup_distributed, cleanup_distributed, is_main_process
|
| 31 |
+
from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
|
| 32 |
+
from td3b.direction_oracle import DirectionalOracle
|
| 33 |
+
from utils.app import PeptideAnalyzer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _load_checkpoint(ckpt_path: str, device: torch.device) -> Dict[str, Any]:
|
| 37 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 38 |
+
if not isinstance(ckpt, dict):
|
| 39 |
+
raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
|
| 40 |
+
return ckpt
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _extract_state_and_config(ckpt: Dict[str, Any]) -> Dict[str, Any]:
|
| 44 |
+
state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
|
| 45 |
+
config = ckpt.get("config") or {}
|
| 46 |
+
return {"state_dict": state_dict, "config": config}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _build_args(cfg: Dict[str, Any], cli: argparse.Namespace) -> argparse.Namespace:
|
| 50 |
+
defaults = {
|
| 51 |
+
"base_path": "To Be Added",
|
| 52 |
+
"seq_length": 200,
|
| 53 |
+
"sampling_eps": 1e-3,
|
| 54 |
+
"total_num_steps": 128,
|
| 55 |
+
"alpha": 0.1,
|
| 56 |
+
"hidden_dim": 768,
|
| 57 |
+
"num_layers": 8,
|
| 58 |
+
"num_heads": 8,
|
| 59 |
+
"min_affinity_threshold": 0.0,
|
| 60 |
+
"sigmoid_temperature": 0.1,
|
| 61 |
+
"val_samples_per_target": 8,
|
| 62 |
+
"direction_oracle_esm_name": "facebook/esm2_t33_650M_UR50D",
|
| 63 |
+
"direction_oracle_esm_cache_dir": None,
|
| 64 |
+
"direction_oracle_esm_local_files_only": False,
|
| 65 |
+
"direction_oracle_max_ligand_length": 768,
|
| 66 |
+
"direction_oracle_max_protein_length": 1024,
|
| 67 |
+
"direction_oracle_d_model": 256,
|
| 68 |
+
"direction_oracle_n_heads": 4,
|
| 69 |
+
"direction_oracle_n_self_attn_layers": 1,
|
| 70 |
+
"direction_oracle_n_bmca_layers": 2,
|
| 71 |
+
"direction_oracle_dropout": 0.3,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
merged = dict(defaults)
|
| 75 |
+
merged.update(cfg or {})
|
| 76 |
+
|
| 77 |
+
if cli.base_path is not None:
|
| 78 |
+
merged["base_path"] = cli.base_path
|
| 79 |
+
if cli.val_csv is not None:
|
| 80 |
+
merged["val_csv"] = cli.val_csv
|
| 81 |
+
if cli.save_path is not None:
|
| 82 |
+
merged["save_path"] = cli.save_path
|
| 83 |
+
if cli.device is not None:
|
| 84 |
+
merged["device"] = cli.device
|
| 85 |
+
if cli.val_samples_per_target is not None:
|
| 86 |
+
merged["val_samples_per_target"] = cli.val_samples_per_target
|
| 87 |
+
if getattr(cli, "num_pool", None) is not None:
|
| 88 |
+
merged["num_pool"] = cli.num_pool
|
| 89 |
+
if cli.seq_length is not None:
|
| 90 |
+
merged["seq_length"] = cli.seq_length
|
| 91 |
+
if cli.total_num_steps is not None:
|
| 92 |
+
merged["total_num_steps"] = cli.total_num_steps
|
| 93 |
+
if cli.sampling_eps is not None:
|
| 94 |
+
merged["sampling_eps"] = cli.sampling_eps
|
| 95 |
+
if cli.seed is not None:
|
| 96 |
+
merged["seed"] = cli.seed
|
| 97 |
+
|
| 98 |
+
args = SimpleNamespace(**merged)
|
| 99 |
+
|
| 100 |
+
base_tr2d2_path = os.path.join(args.base_path, "tr2d2-pep")
|
| 101 |
+
if not getattr(args, "direction_oracle_ckpt", None):
|
| 102 |
+
args.direction_oracle_ckpt = os.path.join(base_tr2d2_path, "direction_oracle.pt")
|
| 103 |
+
if not getattr(args, "direction_oracle_tr2d2_checkpoint", None):
|
| 104 |
+
args.direction_oracle_tr2d2_checkpoint = os.path.join(
|
| 105 |
+
base_tr2d2_path, "pretrained", "peptune-pretrained.ckpt"
|
| 106 |
+
)
|
| 107 |
+
if not getattr(args, "direction_oracle_tokenizer_vocab", None):
|
| 108 |
+
args.direction_oracle_tokenizer_vocab = os.path.join(
|
| 109 |
+
base_tr2d2_path, "tokenizer", "new_vocab.txt"
|
| 110 |
+
)
|
| 111 |
+
if not getattr(args, "direction_oracle_tokenizer_splits", None):
|
| 112 |
+
args.direction_oracle_tokenizer_splits = os.path.join(
|
| 113 |
+
base_tr2d2_path, "tokenizer", "new_splits.txt"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if not getattr(args, "save_path", None):
|
| 117 |
+
args.save_path = os.path.join(base_tr2d2_path, "results", "validation_runs")
|
| 118 |
+
|
| 119 |
+
os.makedirs(args.save_path, exist_ok=True)
|
| 120 |
+
return args
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _build_model(args: argparse.Namespace, state_dict: Dict[str, Any], device: torch.device) -> Diffusion:
    """Instantiate the Diffusion policy model and load checkpoint weights.

    The architecture (hidden size, layers, heads, step counts) is rebuilt from
    *args*; weights are loaded non-strictly so partial checkpoints still load,
    with key mismatches reported to stdout. The model is returned in eval mode.
    """
    config = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads,
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps,
        ),
        eval_cfg=EvalConfig(),
        # learning_rate may be absent from older checkpoints; 3e-4 is the fallback.
        optim=OptimConfig(lr=getattr(args, "learning_rate", 3e-4)),
        mcts=MCTSConfig(),
    )

    tokenizer = load_tokenizer(args.base_path)
    model = Diffusion(
        config=config,
        tokenizer=tokenizer,
        device=device,
    ).to(device)
    # strict=False: tolerate missing/extra keys (e.g. optimizer or EMA state),
    # but surface the counts so silent architecture drift is noticeable.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"[load] Missing keys: {len(load_result.missing_keys)}")
    if load_result.unexpected_keys:
        print(f"[load] Unexpected keys: {len(load_result.unexpected_keys)}")
    model.eval()
    return model
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _build_oracle(args: argparse.Namespace, device: torch.device) -> DirectionalOracle:
    """Construct the DirectionalOracle from `direction_oracle_*` fields on *args*.

    All hyperparameters and file paths come straight from the merged run
    configuration; the oracle is returned in eval mode for inference-only use.
    """
    oracle = DirectionalOracle(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only,
    )
    oracle.eval()
    return oracle
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _sample_sequences(
    model: Diffusion,
    batch_size: int,
    seq_length: int,
    total_num_steps: int,
    sampling_eps: float,
) -> torch.Tensor:
    """Sample token sequences from the diffusion model by reverse stepping.

    Starts from the model prior and runs `total_num_steps` reverse steps over
    timesteps linearly spaced from 1 down to `sampling_eps`. Returns the final
    token tensor (long dtype) on the model's device.
    """
    model.backbone.eval()
    model.noise.eval()

    x_rollout = model.sample_prior(batch_size, seq_length).to(model.device, dtype=torch.long)

    # total_num_steps + 1 endpoints; each step i uses timesteps[i] with a
    # constant step size dt.
    timesteps = torch.linspace(1, sampling_eps, total_num_steps + 1, device=model.device)
    dt = torch.tensor((1 - sampling_eps) / total_num_steps, device=model.device)

    for i in range(total_num_steps):
        t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=model.device)
        _, x_next = model.single_reverse_step(x_rollout, t=t, dt=dt)
        x_rollout = x_next.to(model.device)

    # NOTE(review): final cleanup of residual mask tokens, reusing the last
    # loop value of `t` — confirm this branch is meant to run after the loop
    # rather than inside it; it would raise NameError if total_num_steps == 0.
    if (x_rollout == model.mask_index).any().item():
        _, x_next = model.single_noise_removal(x_rollout, t=t, dt=dt)
        x_rollout = x_next.to(model.device)

    return x_rollout
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _score_sequences(reward_model, sequences: List[str]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 206 |
+
if not sequences:
|
| 207 |
+
empty = np.array([], dtype=np.float32)
|
| 208 |
+
return empty, empty, empty, empty
|
| 209 |
+
|
| 210 |
+
try:
|
| 211 |
+
result = reward_model(sequences)
|
| 212 |
+
if isinstance(result, tuple):
|
| 213 |
+
total_rewards, info = result
|
| 214 |
+
affinity = np.asarray(info.get("affinities", total_rewards), dtype=np.float32)
|
| 215 |
+
confidence = np.asarray(info.get("confidences", np.ones_like(affinity)), dtype=np.float32)
|
| 216 |
+
directions = np.asarray(info.get("directions", np.zeros_like(affinity)), dtype=np.float32)
|
| 217 |
+
else:
|
| 218 |
+
total_rewards = np.asarray(result, dtype=np.float32)
|
| 219 |
+
if total_rewards.ndim > 1:
|
| 220 |
+
affinity = total_rewards[:, 0]
|
| 221 |
+
else:
|
| 222 |
+
affinity = total_rewards
|
| 223 |
+
confidence = np.ones_like(affinity, dtype=np.float32)
|
| 224 |
+
directions = np.zeros_like(affinity, dtype=np.float32)
|
| 225 |
+
return np.asarray(total_rewards, dtype=np.float32), affinity, directions, confidence
|
| 226 |
+
except Exception:
|
| 227 |
+
total_rewards = np.full(len(sequences), np.nan, dtype=np.float32)
|
| 228 |
+
affinity = np.full(len(sequences), np.nan, dtype=np.float32)
|
| 229 |
+
directions = np.full(len(sequences), np.nan, dtype=np.float32)
|
| 230 |
+
confidence = np.full(len(sequences), np.nan, dtype=np.float32)
|
| 231 |
+
for idx, seq in enumerate(sequences):
|
| 232 |
+
try:
|
| 233 |
+
result = reward_model([seq])
|
| 234 |
+
if isinstance(result, tuple):
|
| 235 |
+
rewards, info = result
|
| 236 |
+
total_rewards[idx] = float(np.asarray(rewards)[0])
|
| 237 |
+
affinity[idx] = float(np.asarray(info.get("affinities", rewards))[0])
|
| 238 |
+
confidence[idx] = float(np.asarray(info.get("confidences", [np.nan]))[0])
|
| 239 |
+
directions[idx] = float(np.asarray(info.get("directions", [np.nan]))[0])
|
| 240 |
+
else:
|
| 241 |
+
reward = np.asarray(result)
|
| 242 |
+
total_rewards[idx] = float(reward[0]) if reward.size else np.nan
|
| 243 |
+
affinity[idx] = total_rewards[idx]
|
| 244 |
+
except Exception:
|
| 245 |
+
continue
|
| 246 |
+
return total_rewards, affinity, directions, confidence
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _compute_direction_accuracy(directions: np.ndarray, d_star: float) -> np.ndarray:
|
| 250 |
+
if directions.size == 0:
|
| 251 |
+
return directions
|
| 252 |
+
acc = np.full(directions.shape, np.nan, dtype=np.float32)
|
| 253 |
+
valid = np.isfinite(directions)
|
| 254 |
+
if not valid.any():
|
| 255 |
+
return acc
|
| 256 |
+
if d_star > 0:
|
| 257 |
+
acc[valid] = (directions[valid] >= 0.5).astype(np.float32)
|
| 258 |
+
else:
|
| 259 |
+
acc[valid] = (directions[valid] < 0.5).astype(np.float32)
|
| 260 |
+
return acc
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _nanmean(values: np.ndarray) -> float:
|
| 264 |
+
return float(np.nanmean(values)) if values.size else float("nan")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _nanstd(values: np.ndarray) -> float:
|
| 268 |
+
return float(np.nanstd(values)) if values.size else float("nan")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def main() -> None:
    """CLI entry point for TD3B validation.

    For every target in the validation CSV, samples a pool of peptide
    sequences from the checkpointed policy, scores them (gated reward,
    affinity, direction-oracle score, confidence), optionally performs
    reward-weighted resampling, writes per-sequence CSVs, and prints a
    summary.  Works single-process or under torchrun: targets are sharded
    across ranks and results are gathered onto rank 0.
    """
    parser = argparse.ArgumentParser(description="Run TD3B validation from a saved checkpoint.")
    parser.add_argument("--ckpt_path", required=True, help="Path to saved checkpoint (.ckpt)")
    parser.add_argument("--val_csv", required=True, help="Validation CSV path")
    parser.add_argument("--device", default="cuda", help="Device string (e.g., cuda:0 or cpu)")
    parser.add_argument("--base_path", default=None, help="Base path for TR2-D2")
    parser.add_argument("--save_path", default=None, help="Output directory for validation CSV")
    parser.add_argument("--epoch", type=int, default=0, help="Epoch number to label outputs")
    parser.add_argument("--val_samples_per_target", type=int, default=None, help="Samples per target")
    parser.add_argument("--num_pool", type=int, default=None,
                        help="Number of candidate sequences to sample before resampling")
    parser.add_argument("--seq_length", type=int, default=None, help="Fallback sequence length")
    parser.add_argument("--total_num_steps", type=int, default=None, help="Diffusion steps")
    parser.add_argument("--sampling_eps", type=float, default=None, help="Sampling epsilon")
    parser.add_argument("--seed", type=int, default=None, help="Base random seed")
    parser.add_argument("--no_resample", action="store_true", help="Disable reward-weighted resampling")
    parser.add_argument("--resample_without_replacement", action="store_true",
                        help="Resample without replacement when possible")
    parser.add_argument("--resample_alpha", type=float, default=None,
                        help="Override alpha for resampling weights")
    cli_args = parser.parse_args()

    # torchrun sets LOCAL_RANK / WORLD_SIZE; defaults give a single-process run.
    rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        setup_distributed(rank, world_size)
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(cli_args.device)

    # Offset the seed by rank so each shard samples different sequences.
    if cli_args.seed is not None:
        torch.manual_seed(cli_args.seed + rank)
        np.random.seed(cli_args.seed + rank)

    # Rebuild the run configuration from the checkpoint, with CLI overrides.
    ckpt = _load_checkpoint(cli_args.ckpt_path, device)
    payload = _extract_state_and_config(ckpt)
    args = _build_args(payload["config"], cli_args)

    tokenizer = load_tokenizer(args.base_path)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer)

    policy_model = _build_model(args, payload["state_dict"], device)

    # Shared affinity predictor; per-target views are created in the loop below.
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone,
    )

    directional_oracle = _build_oracle(args, device)
    analyzer = PeptideAnalyzer()
    # Cache encoded protein tokens per target so each protein is embedded once.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    resample_enabled = not cli_args.no_resample
    resample_with_replacement = not cli_args.resample_without_replacement
    resample_alpha = cli_args.resample_alpha if cli_args.resample_alpha is not None else args.alpha

    all_targets = val_dataset.get_all_targets()
    # Round-robin shard targets across ranks.
    if world_size > 1:
        my_targets = all_targets[rank::world_size]
    else:
        my_targets = all_targets

    records: List[Dict[str, Any]] = []
    resampled_records: List[Dict[str, Any]] = []
    resampled_affinity_pos: List[float] = []
    resampled_affinity_neg: List[float] = []
    resampled_acc_pos: List[float] = []
    resampled_acc_neg: List[float] = []
    resampled_gated_rewards: List[float] = []

    with torch.no_grad():
        for target_seq in my_targets:
            target_protein_tokens = protein_token_cache.get(target_seq)
            if target_protein_tokens is None:
                target_protein_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_protein_tokens

            # Evaluate both binding directions for every target.
            for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)
                # Hard cap on generated sequence length.
                # NOTE(review): 1035 is a magic number here — presumably the
                # model's maximum context; confirm against the backbone config.
                max_len = 1035
                if target_length > max_len:
                    target_length = max_len

                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)
                reward_model = create_reward_function(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_protein_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    use_confidence_weighting=True,
                    temperature=args.sigmoid_temperature,
                )

                # Sample a (possibly larger) candidate pool; resampling below
                # narrows it back to val_samples_per_target.
                pool_size = args.val_samples_per_target
                if getattr(args, "num_pool", None) is not None:
                    pool_size = int(args.num_pool)
                    if pool_size < args.val_samples_per_target:
                        print(
                            f"[warn] num_pool ({pool_size}) < val_samples_per_target "
                            f"({args.val_samples_per_target}); using val_samples_per_target."
                        )
                        pool_size = args.val_samples_per_target

                x_eval = _sample_sequences(
                    policy_model,
                    batch_size=pool_size,
                    seq_length=target_length,
                    total_num_steps=args.total_num_steps,
                    sampling_eps=args.sampling_eps,
                )

                sequences = tokenizer.batch_decode(x_eval)
                valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences], dtype=bool)
                valid_fraction = float(valid_mask.mean()) if valid_mask.size else 0.0

                # NOTE(review): confidences is collected but not used below.
                gated_rewards, affinities, directions, confidences = _score_sequences(reward_model, sequences)
                direction_accuracy = _compute_direction_accuracy(directions, d_star)
                # Signed agreement of the oracle score with the target direction.
                consistency = d_star * (directions - 0.5)
                success_rate = direction_accuracy * valid_fraction

                if resample_enabled:
                    finite_rewards = np.isfinite(gated_rewards)
                    if np.any(finite_rewards):
                        # Softmax over rewards/alpha gives the resampling weights;
                        # alpha is clamped away from zero to avoid division blow-up.
                        rewards_t = torch.as_tensor(gated_rewards[finite_rewards], device=device)
                        alpha = max(float(resample_alpha), 1e-6)
                        weights = torch.softmax(rewards_t / alpha, dim=0)
                        if resample_with_replacement:
                            num_samples = args.val_samples_per_target
                            idx = torch.multinomial(weights, num_samples=num_samples, replacement=True)
                        else:
                            # Without replacement we cannot draw more than the
                            # number of finite-reward candidates.
                            num_samples = min(args.val_samples_per_target, int(finite_rewards.sum()))
                            idx = torch.multinomial(weights, num_samples=num_samples, replacement=False)

                        # Map indices over the finite subset back to pool indices.
                        valid_idx = np.where(finite_rewards)[0]
                        chosen = valid_idx[idx.detach().cpu().numpy()]
                        if d_star > 0:
                            resampled_affinity_pos.extend(affinities[chosen].tolist())
                            resampled_acc_pos.extend(direction_accuracy[chosen].tolist())
                        else:
                            resampled_affinity_neg.extend(affinities[chosen].tolist())
                            resampled_acc_neg.extend(direction_accuracy[chosen].tolist())
                        resampled_gated_rewards.extend(gated_rewards[chosen].tolist())

                        for picked in chosen.tolist():
                            resampled_records.append({
                                "target": target_seq[:20],
                                "sequence": sequences[picked],
                                "target_direction": d_star,
                                "is_valid": bool(valid_mask[picked]) if valid_mask.size else False,
                                "affinity": float(affinities[picked]) if affinities.size else np.nan,
                                "gated_reward": float(gated_rewards[picked]) if gated_rewards.size else np.nan,
                                "direction_oracle": float(directions[picked]) if directions.size else np.nan,
                                "consistency_reward": float(consistency[picked]) if consistency.size else np.nan,
                                "direction_accuracy": float(direction_accuracy[picked]) if direction_accuracy.size else np.nan,
                                "success_rate": float(success_rate[picked]) if success_rate.size else np.nan,
                            })

                # Record every pool candidate (not only the resampled subset).
                for idx, seq in enumerate(sequences):
                    records.append({
                        "target": target_seq[:20],
                        "sequence": seq,
                        "target_direction": d_star,
                        "is_valid": bool(valid_mask[idx]) if valid_mask.size else False,
                        "affinity": float(affinities[idx]) if affinities.size else np.nan,
                        "gated_reward": float(gated_rewards[idx]) if gated_rewards.size else np.nan,
                        "direction_oracle": float(directions[idx]) if directions.size else np.nan,
                        "consistency_reward": float(consistency[idx]) if consistency.size else np.nan,
                        "direction_accuracy": float(direction_accuracy[idx]) if direction_accuracy.size else np.nan,
                        "success_rate": float(success_rate[idx]) if success_rate.size else np.nan,
                    })

    # Gather per-rank records onto the main process.
    if world_size > 1:
        gathered: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, records)
        if is_main_process():
            all_records = [item for sub in gathered for item in sub]
        else:
            all_records = []
    else:
        all_records = records

    if world_size > 1:
        gathered_resampled_records: List[List[Dict[str, Any]]] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_resampled_records, resampled_records)
        if is_main_process():
            all_resampled_records = [item for sub in gathered_resampled_records for item in sub]
        else:
            all_resampled_records = []
    else:
        all_resampled_records = resampled_records

    # Gather the resampled summary statistics lists as one payload per rank.
    if world_size > 1:
        resampled_payload = {
            "aff_pos": resampled_affinity_pos,
            "aff_neg": resampled_affinity_neg,
            "acc_pos": resampled_acc_pos,
            "acc_neg": resampled_acc_neg,
            "gated": resampled_gated_rewards,
        }
        gathered_resampled = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_resampled, resampled_payload)
        if is_main_process():
            # Rebuild the aggregate lists from all ranks' payloads.
            resampled_affinity_pos = []
            resampled_affinity_neg = []
            resampled_acc_pos = []
            resampled_acc_neg = []
            resampled_gated_rewards = []
            for payload in gathered_resampled:
                resampled_affinity_pos.extend(payload.get("aff_pos", []))
                resampled_affinity_neg.extend(payload.get("aff_neg", []))
                resampled_acc_pos.extend(payload.get("acc_pos", []))
                resampled_acc_neg.extend(payload.get("acc_neg", []))
                resampled_gated_rewards.extend(payload.get("gated", []))

    # Only rank 0 writes CSVs and prints the summary.
    if is_main_process():
        df = pd.DataFrame(all_records)
        output_path = os.path.join(args.save_path, f"validation_epoch_{cli_args.epoch}.csv")
        df.to_csv(output_path, index=False)
        print(f"Validation sequences saved to {output_path}")

        if resample_enabled:
            if all_resampled_records:
                resampled_df = pd.DataFrame(all_resampled_records)
                resampled_path = os.path.join(args.save_path, f"validation_epoch_{cli_args.epoch}_resampled.csv")
                resampled_df.to_csv(resampled_path, index=False)
                print(f"Resampled sequences saved to {resampled_path}")
            else:
                print("Resampling enabled but no finite rewards were available to select.")

        # Summary statistics: prefer the resampled subset when available,
        # otherwise fall back to statistics over all pool candidates.
        if resample_enabled and resampled_gated_rewards:
            aff_mean_pos = _nanmean(np.asarray(resampled_affinity_pos, dtype=np.float32))
            aff_std_pos = _nanstd(np.asarray(resampled_affinity_pos, dtype=np.float32))
            acc_mean_pos = _nanmean(np.asarray(resampled_acc_pos, dtype=np.float32))
            acc_std_pos = _nanstd(np.asarray(resampled_acc_pos, dtype=np.float32))

            aff_mean_neg = _nanmean(np.asarray(resampled_affinity_neg, dtype=np.float32))
            aff_std_neg = _nanstd(np.asarray(resampled_affinity_neg, dtype=np.float32))
            acc_mean_neg = _nanmean(np.asarray(resampled_acc_neg, dtype=np.float32))
            acc_std_neg = _nanstd(np.asarray(resampled_acc_neg, dtype=np.float32))

            gated = np.asarray(resampled_gated_rewards, dtype=np.float32)
            gated_mean = _nanmean(gated)
            gated_std = _nanstd(gated)
        else:
            def _stats_for_direction(d_star: float) -> Tuple[float, float, float, float]:
                # (affinity mean, affinity std, accuracy mean, accuracy std)
                # over all recorded candidates for one direction.
                subset = df[df["target_direction"] == d_star]
                affinity = subset["affinity"].to_numpy(dtype=np.float32)
                direction_acc = subset["direction_accuracy"].to_numpy(dtype=np.float32)
                return _nanmean(affinity), _nanstd(affinity), _nanmean(direction_acc), _nanstd(direction_acc)

            aff_mean_pos, aff_std_pos, acc_mean_pos, acc_std_pos = _stats_for_direction(1.0)
            aff_mean_neg, aff_std_neg, acc_mean_neg, acc_std_neg = _stats_for_direction(-1.0)
            gated = df["gated_reward"].to_numpy(dtype=np.float32)
            gated_mean = _nanmean(gated)
            gated_std = _nanstd(gated)

        print("Validation summary")
        print(f" Affinity (d*=1): {aff_mean_pos:.4f} ± {aff_std_pos:.4f}")
        print(f" Affinity (d*=-1): {aff_mean_neg:.4f} ± {aff_std_neg:.4f}")
        print(f" Direction Accuracy (d*=1): {acc_mean_pos:.4f} ± {acc_std_pos:.4f}")
        print(f" Direction Accuracy (d*=-1): {acc_mean_neg:.4f} ± {acc_std_neg:.4f}")
        print(f" Gated Reward (overall): {gated_mean:.4f} ± {gated_std:.4f}")

    if world_size > 1:
        cleanup_distributed()
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# Script entry point (see the example torchrun command at the bottom of the file).
if __name__ == "__main__":
    main()
|
| 546 |
+
|
| 547 |
+
# Running command:
|
| 548 |
+
# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29501 run_validation_td3b.py --ckpt_path To Be Added --val_csv To Be Added --device cuda:0 --save_path To Be Added --epoch 99 --val_samples_per_target 8 --seed 42 --resample_alpha 0.1
|
baselines/sampling_setup.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import time
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
| 10 |
+
if ROOT_DIR not in sys.path:
|
| 11 |
+
sys.path.insert(0, ROOT_DIR)
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from hydra import compose, initialize_config_dir
|
| 15 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 16 |
+
|
| 17 |
+
from diffusion import Diffusion
|
| 18 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 19 |
+
from scoring.functions.binding import MultiTargetBindingAffinity
|
| 20 |
+
from td3b.direction_oracle import DirectionalOracle, resolve_device
|
| 21 |
+
from td3b.data_utils import peptide_seq_to_smiles, smiles_token_length
|
| 22 |
+
|
| 23 |
+
from baselines.baselines import (
|
| 24 |
+
RewardInputs,
|
| 25 |
+
RewardWrapper,
|
| 26 |
+
classifier_guidance,
|
| 27 |
+
peptune_mctg_sampling,
|
| 28 |
+
sequential_monte_carlo,
|
| 29 |
+
twisted_diffusion_sampler,
|
| 30 |
+
unguided_sampling,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Canonical 20-residue alphabet; ids 1..20 follow this order (0 is padding).
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


@dataclass
class ProteinTokenizer:
    """Minimal amino-acid -> integer-id tokenizer (pad_id used for unknowns)."""

    aa_to_id: Dict[str, int]
    pad_id: int = 0

    @classmethod
    def default(cls) -> "ProteinTokenizer":
        """Build the standard tokenizer: residues in AMINO_ACIDS get ids 1..20."""
        mapping: Dict[str, int] = {}
        for token_id, residue in enumerate(AMINO_ACIDS, start=1):
            mapping[residue] = token_id
        return cls(aa_to_id=mapping, pad_id=0)

    def encode(self, seq: str) -> torch.Tensor:
        """Encode *seq* as a (1, len(seq)) long tensor; unknown residues map to pad_id."""
        lookup = self.aa_to_id.get
        token_ids = [lookup(residue, self.pad_id) for residue in seq]
        return torch.tensor([token_ids], dtype=torch.long)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def load_base_model(
    ckpt_path: str,
    device: str,
    config_name: str = "peptune_config.yaml",
) -> Diffusion:
    """Load the base Diffusion model from *ckpt_path* in eval mode.

    First tries the PyTorch Lightning ``load_from_checkpoint`` path; if that
    raises for any reason, falls back to constructing a fresh ``Diffusion``
    from the hydra config and loading a raw state dict non-strictly.

    Args:
        ckpt_path: Checkpoint file (Lightning .ckpt or a raw state-dict dump).
        device: Device string used both for map_location and the model.
        config_name: Config file name inside the sibling ``configs`` directory.

    Returns:
        The loaded ``Diffusion`` model, switched to eval mode.

    Raises:
        ValueError: If the fallback checkpoint payload is not a dict.
    """
    # Reset hydra's global state so repeated calls in one process don't clash.
    GlobalHydra.instance().clear()
    config_dir = os.path.join(os.path.dirname(__file__), "..", "configs")
    # NOTE(review): initialize_config_dir normally expects an absolute path —
    # this relative join works only from certain working directories; verify.
    initialize_config_dir(config_dir=config_dir, job_name="load_model")
    cfg = compose(config_name=config_name)
    try:
        model = Diffusion.load_from_checkpoint(
            ckpt_path,
            config=cfg,
            mode="eval",
            device=device,
            map_location=device,
        )
        model.eval()
        return model
    except Exception as exc:
        # Deliberate broad catch: any Lightning load failure triggers the
        # raw state-dict fallback below.
        print(f"[load_base_model] Lightning load failed, falling back to raw state_dict: {exc}")

    # Fallback: pull a state dict out of whichever wrapper format was saved.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    if isinstance(checkpoint, dict):
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        elif "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            # Assume the dict itself is the state dict.
            state_dict = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model = Diffusion(
        config=cfg,
        mode="eval",
        device=device,
    )
    # Non-strict load: report (but tolerate) key mismatches.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"[load_base_model] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[load_base_model] Unexpected keys: {len(unexpected)}")
    model.eval()
    model.to(device)
    return model
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_reward_models(
    prot_seq: Optional[str],
    device: str,
    base_model: Optional[Diffusion] = None,
    base_path: Optional[str] = None,
    multi_target: bool = False,
    score_func_names: Optional[List[str]] = None,
):
    """Build the reward model for baseline sampling.

    In multi-target mode this returns a ``MultiTargetBindingAffinity`` backed
    by the base model's tokenizer and backbone; otherwise it returns a
    ``ScoringFunctions`` bundle for the single protein sequence, defaulting to
    affinity + four developability classifiers.
    """
    if multi_target:
        if base_model is None or base_path is None:
            raise ValueError("base_model and base_path are required for multi-target affinity.")
        return MultiTargetBindingAffinity(
            tokenizer=base_model.tokenizer,
            base_path=base_path,
            device=device,
            emb_model=base_model.backbone,
        )

    # Single-target path.
    if prot_seq is None:
        raise ValueError("prot_seq is required for single-target scoring.")
    names = score_func_names
    if names is None:
        names = [
            "binding_affinity1",
            "solubility",
            "hemolysis",
            "nonfouling",
            "permeability",
        ]
    return ScoringFunctions(names, prot_seqs=[prot_seq], device=device)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_direction_oracle(args, device: str) -> DirectionalOracle:
    """Instantiate the directional oracle from the ``direction_oracle_*`` CLI
    arguments and switch it to eval mode."""
    oracle_kwargs = dict(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only,
    )
    oracle = DirectionalOracle(device=device, **oracle_kwargs)
    oracle.eval()
    return oracle
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def run_baseline(
    baseline: str,
    base_model: Diffusion,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    guidance_steps: Optional[int],
    mcts_iterations: int,
    num_children: int,
    sample_prob_weight: float,
    invalid_penalty: float,
    pareto_max_size: Optional[int],
) -> Dict[str, torch.Tensor]:
    """Dispatch to the requested baseline sampler.

    Supported baselines (case-insensitive): ``cg`` (classifier guidance),
    ``unguided``, ``smc`` (sequential Monte Carlo), ``tds`` (twisted
    diffusion sampler) and ``peptune`` (MCTS-guided sampling).  Each branch
    forwards only the parameters that sampler consumes.

    Raises:
        ValueError: For an unrecognized baseline name.
    """
    baseline = baseline.lower()
    # Sampling-shape parameters shared by every sampler.
    common = dict(batch_size=batch_size, seq_length=seq_length, num_steps=num_steps)
    if baseline == "cg":
        return classifier_guidance(
            base_model,
            reward_fn,
            guidance_scale=guidance_scale,
            guidance_steps=guidance_steps,
            **common,
        )
    if baseline == "unguided":
        return unguided_sampling(base_model, **common)
    if baseline == "smc":
        return sequential_monte_carlo(base_model, reward_fn, alpha=alpha, **common)
    if baseline == "tds":
        return twisted_diffusion_sampler(
            base_model,
            reward_fn,
            guidance_scale=guidance_scale,
            alpha=alpha,
            guidance_steps=guidance_steps,
            **common,
        )
    if baseline == "peptune":
        return peptune_mctg_sampling(
            base_model,
            reward_fn,
            mcts_iterations=mcts_iterations,
            num_children=num_children,
            alpha=alpha,
            sample_prob_weight=sample_prob_weight,
            invalid_penalty=invalid_penalty,
            pareto_max_size=pareto_max_size,
            **common,
        )
    raise ValueError(f"Unknown baseline: {baseline}")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def main():
|
| 224 |
+
parser = argparse.ArgumentParser()
|
| 225 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
| 226 |
+
parser.add_argument("--device", type=str, default="cuda:0")
|
| 227 |
+
parser.add_argument("--baseline", type=str, default="cg", choices=["cg", "smc", "tds", "unguided", "peptune"])
|
| 228 |
+
parser.add_argument("--prot_seq", type=str, default=None)
|
| 229 |
+
parser.add_argument("--targets_csv", type=str, default=None)
|
| 230 |
+
parser.add_argument("--d_star", type=float, default=1.0)
|
| 231 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
| 232 |
+
parser.add_argument("--seq_length", type=int, default=200)
|
| 233 |
+
parser.add_argument("--binder_seq", type=str, default=None)
|
| 234 |
+
parser.add_argument("--num_steps", type=int, default=128)
|
| 235 |
+
parser.add_argument("--guidance_scale", type=float, default=1.0)
|
| 236 |
+
parser.add_argument("--alpha", type=float, default=0.1)
|
| 237 |
+
parser.add_argument("--reward_alpha", type=float, default=None)
|
| 238 |
+
parser.add_argument("--mcts_iterations", type=int, default=20)
|
| 239 |
+
parser.add_argument("--num_children", type=int, default=24)
|
| 240 |
+
parser.add_argument("--sample_prob_weight", type=float, default=0.1)
|
| 241 |
+
parser.add_argument("--invalid_penalty", type=float, default=1.0)
|
| 242 |
+
parser.add_argument("--pareto_max_size", type=int, default=None)
|
| 243 |
+
parser.add_argument("--guidance_steps", type=int, default=None)
|
| 244 |
+
parser.add_argument("--fast_direction", action="store_true", default=False)
|
| 245 |
+
parser.add_argument("--num_batches", type=int, default=1)
|
| 246 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
| 247 |
+
parser.add_argument("--shard_id", type=int, default=None)
|
| 248 |
+
parser.add_argument("--num_shards", type=int, default=None)
|
| 249 |
+
parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
|
| 250 |
+
parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
|
| 251 |
+
parser.add_argument("--direction_oracle_tokenizer_vocab", type=str, default=None)
|
| 252 |
+
parser.add_argument("--direction_oracle_tokenizer_splits", type=str, default=None)
|
| 253 |
+
parser.add_argument("--direction_oracle_esm_name", type=str, default="facebook/esm2_t33_650M_UR50D")
|
| 254 |
+
parser.add_argument("--direction_oracle_esm_cache_dir", type=str, default=None)
|
| 255 |
+
parser.add_argument("--direction_oracle_esm_local_files_only", action="store_true", default=False)
|
| 256 |
+
parser.add_argument("--direction_oracle_max_ligand_length", type=int, default=768)
|
| 257 |
+
parser.add_argument("--direction_oracle_max_protein_length", type=int, default=1024)
|
| 258 |
+
parser.add_argument("--direction_oracle_d_model", type=int, default=256)
|
| 259 |
+
parser.add_argument("--direction_oracle_n_heads", type=int, default=4)
|
| 260 |
+
parser.add_argument("--direction_oracle_n_self_attn_layers", type=int, default=1)
|
| 261 |
+
parser.add_argument("--direction_oracle_n_bmca_layers", type=int, default=2)
|
| 262 |
+
parser.add_argument("--direction_oracle_dropout", type=float, default=0.3)
|
| 263 |
+
args = parser.parse_args()
|
| 264 |
+
|
| 265 |
+
rank_env = os.environ.get("LOCAL_RANK")
|
| 266 |
+
world_env = os.environ.get("WORLD_SIZE")
|
| 267 |
+
if rank_env is not None or world_env is not None:
|
| 268 |
+
rank = int(rank_env or 0)
|
| 269 |
+
world_size = int(world_env or 1)
|
| 270 |
+
else:
|
| 271 |
+
rank = int(args.shard_id) if args.shard_id is not None else 0
|
| 272 |
+
world_size = int(args.num_shards) if args.num_shards is not None else 1
|
| 273 |
+
if world_size < 1:
|
| 274 |
+
world_size = 1
|
| 275 |
+
if world_size > 1 and str(args.device).lower() in {"cuda", "cuda:0", "auto"}:
|
| 276 |
+
args.device = f"cuda:{rank}"
|
| 277 |
+
|
| 278 |
+
resolved_device = resolve_device(args.device)
|
| 279 |
+
args.device = str(resolved_device)
|
| 280 |
+
|
| 281 |
+
tr2d2_root = ROOT_DIR
|
| 282 |
+
if args.direction_oracle_ckpt is None:
|
| 283 |
+
args.direction_oracle_ckpt = os.path.join(
|
| 284 |
+
tr2d2_root, "direction_oracle.pt"
|
| 285 |
+
)
|
| 286 |
+
if args.direction_oracle_tr2d2_checkpoint is None:
|
| 287 |
+
args.direction_oracle_tr2d2_checkpoint = os.path.join(
|
| 288 |
+
tr2d2_root, "pretrained", "peptune-pretrained.ckpt"
|
| 289 |
+
)
|
| 290 |
+
if args.direction_oracle_tokenizer_vocab is None:
|
| 291 |
+
args.direction_oracle_tokenizer_vocab = os.path.join(
|
| 292 |
+
tr2d2_root, "tokenizer", "new_vocab.txt"
|
| 293 |
+
)
|
| 294 |
+
if args.direction_oracle_tokenizer_splits is None:
|
| 295 |
+
args.direction_oracle_tokenizer_splits = os.path.join(
|
| 296 |
+
tr2d2_root, "tokenizer", "new_splits.txt"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
if args.targets_csv is None and args.prot_seq is None:
|
| 300 |
+
raise ValueError("--prot_seq is required when --targets_csv is not provided.")
|
| 301 |
+
|
| 302 |
+
base_model = load_base_model(args.ckpt_path, args.device)
|
| 303 |
+
base_path = os.path.abspath(os.path.join(ROOT_DIR, ".."))
|
| 304 |
+
multi_target = args.targets_csv is not None
|
| 305 |
+
scoring_fn = load_reward_models(
|
| 306 |
+
args.prot_seq if not multi_target else None,
|
| 307 |
+
args.device,
|
| 308 |
+
base_model=base_model,
|
| 309 |
+
base_path=base_path,
|
| 310 |
+
multi_target=multi_target,
|
| 311 |
+
)
|
| 312 |
+
direction_oracle = load_direction_oracle(args, args.device)
|
| 313 |
+
reward_alpha = args.reward_alpha if args.reward_alpha is not None else args.alpha
|
| 314 |
+
|
| 315 |
+
if args.targets_csv:
|
| 316 |
+
import pandas as pd
|
| 317 |
+
|
| 318 |
+
df = pd.read_csv(args.targets_csv)
|
| 319 |
+
if "Target_Sequence" not in df.columns:
|
| 320 |
+
raise ValueError("targets_csv must contain a 'Target_Sequence' column.")
|
| 321 |
+
if "Ligand_Sequence" not in df.columns:
|
| 322 |
+
raise ValueError("targets_csv must contain a 'Ligand_Sequence' column.")
|
| 323 |
+
|
| 324 |
+
targets = []
|
| 325 |
+
for row_idx, row in df.iterrows():
|
| 326 |
+
target_seq = str(row["Target_Sequence"]) if pd.notna(row["Target_Sequence"]) else None
|
| 327 |
+
if not target_seq:
|
| 328 |
+
continue
|
| 329 |
+
binder_seq = row["Ligand_Sequence"]
|
| 330 |
+
if pd.isna(binder_seq):
|
| 331 |
+
binder_seq = None
|
| 332 |
+
else:
|
| 333 |
+
binder_seq = str(binder_seq)
|
| 334 |
+
if binder_seq.strip() == "":
|
| 335 |
+
binder_seq = None
|
| 336 |
+
targets.append(
|
| 337 |
+
{
|
| 338 |
+
"target_seq": target_seq,
|
| 339 |
+
"binder_seq": binder_seq,
|
| 340 |
+
"row_index": int(row_idx),
|
| 341 |
+
}
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
targets = [{"target_seq": args.prot_seq, "binder_seq": args.binder_seq, "row_index": 0}]
|
| 345 |
+
|
| 346 |
+
if world_size > 1:
|
| 347 |
+
targets = [item for idx, item in enumerate(targets) if idx % world_size == rank]
|
| 348 |
+
print(f"[shard] rank {rank}/{world_size}: {len(targets)} targets")
|
| 349 |
+
|
| 350 |
+
output_dir = args.output_dir
|
| 351 |
+
if output_dir is None:
|
| 352 |
+
output_dir = os.path.join(os.path.dirname(__file__), "outputs")
|
| 353 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 354 |
+
|
| 355 |
+
from utils.app import PeptideAnalyzer
|
| 356 |
+
|
| 357 |
+
analyzer = PeptideAnalyzer()
|
| 358 |
+
all_rows = []
|
| 359 |
+
batch_rows = []
|
| 360 |
+
metrics_rows = []
|
| 361 |
+
def resolve_seq_length(binder_seq: Optional[str]) -> int:
|
| 362 |
+
if not binder_seq:
|
| 363 |
+
return args.seq_length
|
| 364 |
+
try:
|
| 365 |
+
smiles = peptide_seq_to_smiles(binder_seq)
|
| 366 |
+
if not smiles:
|
| 367 |
+
return args.seq_length
|
| 368 |
+
if base_model.tokenizer is None:
|
| 369 |
+
return len(smiles)
|
| 370 |
+
return smiles_token_length(smiles, base_model.tokenizer)
|
| 371 |
+
except Exception as exc:
|
| 372 |
+
print(f"Warning: failed to derive seq_length from binder_seq; using {args.seq_length}. Error: {exc}")
|
| 373 |
+
return args.seq_length
|
| 374 |
+
|
| 375 |
+
for target_idx, target_info in enumerate(targets):
|
| 376 |
+
target_seq = target_info["target_seq"]
|
| 377 |
+
binder_seq = target_info.get("binder_seq")
|
| 378 |
+
row_index = target_info.get("row_index", target_idx)
|
| 379 |
+
seq_length = resolve_seq_length(binder_seq)
|
| 380 |
+
protein_tokens = direction_oracle.encode_protein(target_seq)
|
| 381 |
+
for direction_name, d_star in [("agonist", 1.0), ("antagonist", -1.0)]:
|
| 382 |
+
|
| 383 |
+
reward_inputs = RewardInputs(
|
| 384 |
+
protein_tokens=protein_tokens,
|
| 385 |
+
d_star=d_star,
|
| 386 |
+
protein_seq=target_seq,
|
| 387 |
+
)
|
| 388 |
+
reward_fn = RewardWrapper(
|
| 389 |
+
scoring_fn=scoring_fn,
|
| 390 |
+
direction_oracle=direction_oracle,
|
| 391 |
+
base_model=base_model,
|
| 392 |
+
tokenizer=base_model.tokenizer,
|
| 393 |
+
reward_inputs=reward_inputs,
|
| 394 |
+
device=torch.device(args.device),
|
| 395 |
+
fast_direction=args.fast_direction,
|
| 396 |
+
reward_alpha=reward_alpha,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
num_batches = 1 if multi_target else args.num_batches
|
| 400 |
+
for batch_idx in range(num_batches):
|
| 401 |
+
start = time.perf_counter()
|
| 402 |
+
result = run_baseline(
|
| 403 |
+
args.baseline,
|
| 404 |
+
base_model,
|
| 405 |
+
reward_fn,
|
| 406 |
+
batch_size=args.batch_size,
|
| 407 |
+
seq_length=seq_length,
|
| 408 |
+
num_steps=args.num_steps,
|
| 409 |
+
guidance_scale=args.guidance_scale,
|
| 410 |
+
alpha=args.alpha,
|
| 411 |
+
guidance_steps=args.guidance_steps,
|
| 412 |
+
mcts_iterations=args.mcts_iterations,
|
| 413 |
+
num_children=args.num_children,
|
| 414 |
+
sample_prob_weight=args.sample_prob_weight,
|
| 415 |
+
invalid_penalty=args.invalid_penalty,
|
| 416 |
+
pareto_max_size=args.pareto_max_size,
|
| 417 |
+
)
|
| 418 |
+
elapsed = time.perf_counter() - start
|
| 419 |
+
|
| 420 |
+
scores = reward_fn.evaluate_tokens(
|
| 421 |
+
result["tokens"],
|
| 422 |
+
torch.ones_like(result["tokens"], device=result["tokens"].device),
|
| 423 |
+
)
|
| 424 |
+
sequences = scores["sequences"]
|
| 425 |
+
affinity = scores["affinity"].detach().cpu().numpy()
|
| 426 |
+
direction = scores["direction"].detach().cpu().numpy()
|
| 427 |
+
gated_reward = scores["gated_reward"].detach().cpu().numpy()
|
| 428 |
+
valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences], dtype=np.float32)
|
| 429 |
+
valid_fraction = float(valid_mask.mean()) if len(valid_mask) else 0.0
|
| 430 |
+
consistency = d_star * (direction - 0.5)
|
| 431 |
+
if d_star > 0:
|
| 432 |
+
direction_correct = (direction >= 0.5).astype(np.float32)
|
| 433 |
+
else:
|
| 434 |
+
direction_correct = (direction < 0.5).astype(np.float32)
|
| 435 |
+
success = direction_correct * valid_mask
|
| 436 |
+
direction_mean = float(np.mean(direction))
|
| 437 |
+
direction_std = float(np.std(direction))
|
| 438 |
+
affinity_mean = float(np.mean(affinity))
|
| 439 |
+
affinity_std = float(np.std(affinity))
|
| 440 |
+
consistency_mean = float(np.mean(consistency))
|
| 441 |
+
consistency_std = float(np.std(consistency))
|
| 442 |
+
gated_reward_mean = float(np.mean(gated_reward))
|
| 443 |
+
gated_reward_std = float(np.std(gated_reward))
|
| 444 |
+
direction_acc_mean = float(np.mean(direction_correct))
|
| 445 |
+
direction_acc_std = float(np.std(direction_correct))
|
| 446 |
+
success_rate_mean = float(np.mean(success))
|
| 447 |
+
success_rate_std = float(np.std(success))
|
| 448 |
+
batch_metrics = {
|
| 449 |
+
"direction_mean": direction_mean,
|
| 450 |
+
"direction_std": direction_std,
|
| 451 |
+
"affinity_mean": affinity_mean,
|
| 452 |
+
"affinity_std": affinity_std,
|
| 453 |
+
"consistency_mean": consistency_mean,
|
| 454 |
+
"consistency_std": consistency_std,
|
| 455 |
+
"gated_reward_mean": gated_reward_mean,
|
| 456 |
+
"gated_reward_std": gated_reward_std,
|
| 457 |
+
"direction_accuracy_mean": direction_acc_mean,
|
| 458 |
+
"direction_accuracy_std": direction_acc_std,
|
| 459 |
+
"valid_fraction": valid_fraction,
|
| 460 |
+
"success_rate_mean": success_rate_mean,
|
| 461 |
+
"success_rate_std": success_rate_std,
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
for i, seq in enumerate(sequences):
|
| 465 |
+
all_rows.append(
|
| 466 |
+
{
|
| 467 |
+
"rank": rank,
|
| 468 |
+
"sequence": seq,
|
| 469 |
+
"affinity": float(affinity[i]),
|
| 470 |
+
"direction": float(direction[i]),
|
| 471 |
+
"d_star": float(d_star),
|
| 472 |
+
"direction_name": direction_name,
|
| 473 |
+
"target_seq": target_seq,
|
| 474 |
+
"target_index": target_idx,
|
| 475 |
+
"row_index": row_index,
|
| 476 |
+
"binder_seq": binder_seq,
|
| 477 |
+
"seq_length": seq_length,
|
| 478 |
+
"gated_reward": float(gated_reward[i]),
|
| 479 |
+
"consistency_reward": float(consistency[i]),
|
| 480 |
+
"direction_accuracy": float(direction_correct[i]),
|
| 481 |
+
"valid": float(valid_mask[i]),
|
| 482 |
+
"success": float(success[i]),
|
| 483 |
+
"batch_index": batch_idx,
|
| 484 |
+
"batch_time_sec": elapsed,
|
| 485 |
+
**batch_metrics,
|
| 486 |
+
}
|
| 487 |
+
)
|
| 488 |
+
batch_rows.append(
|
| 489 |
+
{
|
| 490 |
+
"rank": rank,
|
| 491 |
+
"batch_index": batch_idx,
|
| 492 |
+
"batch_time_sec": elapsed,
|
| 493 |
+
"target_index": target_idx,
|
| 494 |
+
"row_index": row_index,
|
| 495 |
+
"binder_seq": binder_seq,
|
| 496 |
+
"seq_length": seq_length,
|
| 497 |
+
"direction_name": direction_name,
|
| 498 |
+
}
|
| 499 |
+
)
|
| 500 |
+
metrics_rows.append(
|
| 501 |
+
{
|
| 502 |
+
"rank": rank,
|
| 503 |
+
"target_index": target_idx,
|
| 504 |
+
"target_seq": target_seq,
|
| 505 |
+
"row_index": row_index,
|
| 506 |
+
"binder_seq": binder_seq,
|
| 507 |
+
"seq_length": seq_length,
|
| 508 |
+
"direction_name": direction_name,
|
| 509 |
+
"d_star": float(d_star),
|
| 510 |
+
"batch_index": batch_idx,
|
| 511 |
+
"num_samples": len(sequences),
|
| 512 |
+
**batch_metrics,
|
| 513 |
+
}
|
| 514 |
+
)
|
| 515 |
+
print(
|
| 516 |
+
f"Target {target_idx} dir {direction_name}: "
|
| 517 |
+
f"generated {len(sequences)} sequences in {elapsed:.3f}s"
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
import pandas as pd
|
| 521 |
+
|
| 522 |
+
if world_size > 1:
|
| 523 |
+
output_csv = os.path.join(output_dir, f"{args.baseline}_samples_rank{rank}.csv")
|
| 524 |
+
batch_csv = os.path.join(output_dir, f"batch_times_rank{rank}.csv")
|
| 525 |
+
metrics_csv = os.path.join(output_dir, f"{args.baseline}_metrics_rank{rank}.csv")
|
| 526 |
+
else:
|
| 527 |
+
output_csv = os.path.join(output_dir, f"{args.baseline}_samples.csv")
|
| 528 |
+
batch_csv = os.path.join(output_dir, "batch_times.csv")
|
| 529 |
+
metrics_csv = os.path.join(output_dir, f"{args.baseline}_metrics.csv")
|
| 530 |
+
pd.DataFrame(all_rows).to_csv(output_csv, index=False)
|
| 531 |
+
pd.DataFrame(batch_rows).to_csv(batch_csv, index=False)
|
| 532 |
+
pd.DataFrame(metrics_rows).to_csv(metrics_csv, index=False)
|
| 533 |
+
|
| 534 |
+
print(f"Saved samples to {output_csv}")
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# Script entry point: run the baseline generation/evaluation pipeline above.
if __name__ == "__main__":
    main()
|
configs/finetune_config.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared Configuration Classes for TD3B Finetuning
|
| 3 |
+
|
| 4 |
+
This module contains all configuration dataclasses used by both:
|
| 5 |
+
- finetune_v1.py (single-target training)
|
| 6 |
+
- finetune_multi_target.py (multi-target training)
|
| 7 |
+
|
| 8 |
+
Extracted to avoid code duplication and ensure consistency.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class RoFormerConfig:
    """Configuration for RoFormer model architecture."""
    # NOTE(review): unlike the other configs in this module, this one is not
    # frozen=True — confirm whether in-place mutation is actually relied on.
    hidden_size: int  # Transformer embedding dimension.
    n_layers: int  # Number of transformer layers.
    n_heads: int  # Number of attention heads per layer.
    max_position_embeddings: int = 1035  # Must match pretrained model
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
class NoiseConfig:
    """Configuration for noise scheduling."""
    type: str = 'loglinear'  # Noise schedule family name.
    sigma_min: float = 1e-4  # Lower bound of the noise level.
    sigma_max: float = 20.0  # Upper bound of the noise level.
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass(frozen=True)
class TrainingConfig:
    """Configuration for training parameters."""
    # Presumably the minimum diffusion time epsilon; mirrors
    # training.sampling_eps (1e-3) in peptune_config.yaml — confirm.
    sampling_eps: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass(frozen=True)
class SamplingConfig:
    """Configuration for sampling parameters."""
    steps: int  # Number of denoising steps.
    sampling_eps: float  # Epsilon cutoff for the sampling time grid.
    predictor: str = 'ddpm_cache'  # Sampler type: analytic / ddpm / ddpm_cache.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass(frozen=True)
class EvalConfig:
    """Configuration for evaluation parameters."""
    # HF model name or path used for generative-perplexity evaluation.
    gen_ppl_eval_model_name_or_path: str = 'gpt2-large'
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass(frozen=True)
class OptimConfig:
    """Configuration for optimizer parameters."""
    lr: float  # Learning rate.
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass(frozen=True)
class MCTSConfig:
    """Configuration for MCTS parameters."""
    sampling: int = 0  # 0 for Gumbel sampling; > 0 samples children from top-k probs.
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DiffusionConfig:
    """
    Complete configuration for the Diffusion model.

    Aggregates the nested configuration objects the Diffusion model expects
    (``config.roformer``, ``config.noise``, ``config.training``,
    ``config.sampling``, ``config.eval``, ``config.optim``, ``config.mcts``)
    plus a handful of fixed parameters.

    The nested attributes are exposed as plain mutable namespace objects
    (``types.SimpleNamespace``) rather than the possibly-frozen dataclasses
    passed in, preserving the original behavior where downstream code may
    overwrite individual fields.  SimpleNamespace replaces the previous
    anonymous ``type('Obj', (), {...})()`` instances: identical attribute
    access, but with a useful ``repr`` and instance-level attributes.
    """

    def __init__(
        self,
        roformer: "RoFormerConfig",
        noise: "NoiseConfig",
        training: "TrainingConfig",
        sampling: "SamplingConfig",
        eval_cfg: "EvalConfig",
        optim: "OptimConfig",
        mcts: "MCTSConfig",
    ):
        # Local import keeps the module's top-level imports unchanged.
        from types import SimpleNamespace

        self.roformer = SimpleNamespace(
            hidden_size=roformer.hidden_size,
            n_layers=roformer.n_layers,
            n_heads=roformer.n_heads,
            max_position_embeddings=roformer.max_position_embeddings,
        )
        self.noise = SimpleNamespace(
            type=noise.type,
            sigma_min=noise.sigma_min,
            sigma_max=noise.sigma_max,
        )
        self.training = SimpleNamespace(sampling_eps=training.sampling_eps)
        self.sampling = SimpleNamespace(
            steps=sampling.steps,
            sampling_eps=sampling.sampling_eps,
            predictor=sampling.predictor,
        )
        # Attribute is named ``eval`` for backward compatibility with callers;
        # it only shadows the builtin as an attribute, which is harmless.
        self.eval = SimpleNamespace(
            gen_ppl_eval_model_name_or_path=eval_cfg.gen_ppl_eval_model_name_or_path,
        )
        self.optim = SimpleNamespace(lr=optim.lr)
        self.mcts = SimpleNamespace(sampling=mcts.sampling)

        # Fixed (non-configurable) parameters.
        self.backbone = 'roformer'
        self.parameterization = 'subs'
        self.time_conditioning = False
        self.T = 0
|
configs/peptune_config.yaml
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
noise:
|
| 2 |
+
type: loglinear
|
| 3 |
+
sigma_min: 1e-4
|
| 4 |
+
sigma_max: 20
|
| 5 |
+
state_dependent: True
|
| 6 |
+
|
| 7 |
+
mode: ppl_eval # train / ppl_eval / sample_eval
|
| 8 |
+
diffusion: absorbing_state
|
| 9 |
+
vocab: old_smiles # old_smiles / new_smiles / selfies / helm
|
| 10 |
+
backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
|
| 11 |
+
parameterization: subs # subs
|
| 12 |
+
time_conditioning: False
|
| 13 |
+
T: 0 # 0 (continuous time) / 1000
|
| 14 |
+
subs_masking: False
|
| 15 |
+
|
| 16 |
+
seed: 42
|
| 17 |
+
|
| 18 |
+
mcts:
|
| 19 |
+
num_children: 50
|
| 20 |
+
num_objectives: 5
|
| 21 |
+
topk: 100
|
| 22 |
+
mask_token: 4
|
| 23 |
+
num_iter: 128
|
| 24 |
+
sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
|
| 25 |
+
invalid_penalty: 0.5
|
| 26 |
+
sample_prob: 1.0
|
| 27 |
+
perm: True
|
| 28 |
+
dual: False
|
| 29 |
+
single: False
|
| 30 |
+
time_dependent: True
|
| 31 |
+
|
| 32 |
+
lr_scheduler:
|
| 33 |
+
_target_: transformers.get_constant_schedule_with_warmup
|
| 34 |
+
num_warmup_steps: 2500
|
| 35 |
+
|
| 36 |
+
data:
|
| 37 |
+
train: To Be Added
|
| 38 |
+
valid: To Be Added
|
| 39 |
+
batching: wrapping # padding / wrapping
|
| 40 |
+
|
| 41 |
+
loader:
|
| 42 |
+
global_batch_size: 64
|
| 43 |
+
eval_global_batch_size: ${.global_batch_size}
|
| 44 |
+
# Note: batch_size and eval_batch_size are **per machine**
|
| 45 |
+
batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 46 |
+
eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
|
| 47 |
+
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
|
| 48 |
+
pin_memory: True
|
| 49 |
+
|
| 50 |
+
sampling:
|
| 51 |
+
predictor: ddpm_cache # analytic, ddpm, ddpm_cache
|
| 52 |
+
num_sequences: 100
|
| 53 |
+
sampling_eps: 1e-3
|
| 54 |
+
steps: 128
|
| 55 |
+
seq_length: 100
|
| 56 |
+
noise_removal: True
|
| 57 |
+
num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
|
| 58 |
+
num_sample_log: 2
|
| 59 |
+
stride_length: 1
|
| 60 |
+
num_strides: 1
|
| 61 |
+
|
| 62 |
+
training:
|
| 63 |
+
antithetic_sampling: True
|
| 64 |
+
sampling_eps: 1e-3
|
| 65 |
+
focus_mask: False
|
| 66 |
+
#dynamic_batching: True
|
| 67 |
+
accumulator: False
|
| 68 |
+
|
| 69 |
+
eval:
|
| 70 |
+
checkpoint_path:
|
| 71 |
+
disable_ema: False
|
| 72 |
+
compute_generative_perplexity: False
|
| 73 |
+
perplexity_batch_size: 8
|
| 74 |
+
compute_perplexity_on_sanity: False
|
| 75 |
+
gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
|
| 76 |
+
generate_samples: True
|
| 77 |
+
generation_model:
|
| 78 |
+
|
| 79 |
+
optim:
|
| 80 |
+
weight_decay: 0.075
|
| 81 |
+
lr: 3e-4
|
| 82 |
+
beta1: 0.9
|
| 83 |
+
beta2: 0.999
|
| 84 |
+
eps: 1e-8
|
| 85 |
+
|
| 86 |
+
pepclm:
|
| 87 |
+
hidden_size: 768
|
| 88 |
+
cond_dim: 256
|
| 89 |
+
n_heads: 20
|
| 90 |
+
n_blocks: 4
|
| 91 |
+
dropout: 0.5
|
| 92 |
+
length: 512
|
| 93 |
+
#scale_by_sigma: True
|
| 94 |
+
|
| 95 |
+
model:
|
| 96 |
+
type: ddit
|
| 97 |
+
hidden_size: 768
|
| 98 |
+
cond_dim: 128
|
| 99 |
+
length: 512
|
| 100 |
+
n_blocks: 12
|
| 101 |
+
n_heads: 12
|
| 102 |
+
scale_by_sigma: True
|
| 103 |
+
dropout: 0.1
|
| 104 |
+
|
| 105 |
+
roformer:
|
| 106 |
+
hidden_size: 768
|
| 107 |
+
n_layers: 8
|
| 108 |
+
n_heads: 8
|
| 109 |
+
max_position_embeddings: 1035
|
| 110 |
+
|
| 111 |
+
helmgpt:
|
| 112 |
+
hidden_size: 256
|
| 113 |
+
embd_pdrop: 0.1
|
| 114 |
+
resid_pdrop: 0.1
|
| 115 |
+
attn_pdrop: 0.1
|
| 116 |
+
ff_dropout: 0.
|
| 117 |
+
block_size: 140
|
| 118 |
+
n_layer: 8
|
| 119 |
+
n_heads: 8
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
trainer:
|
| 123 |
+
_target_: lightning.Trainer
|
| 124 |
+
accelerator: cuda
|
| 125 |
+
num_nodes: 1
|
| 126 |
+
devices: ${device_count:}
|
| 127 |
+
accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
|
| 128 |
+
gradient_clip_val: 1.0
|
| 129 |
+
precision: 64-true
|
| 130 |
+
num_sanity_val_steps: 2
|
| 131 |
+
max_epochs: 100
|
| 132 |
+
max_steps: 1_000_000
|
| 133 |
+
log_every_n_steps: 10
|
| 134 |
+
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
|
| 135 |
+
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
|
| 136 |
+
#val_check_interval: 40 #954
|
| 137 |
+
check_val_every_n_epoch: 1
|
| 138 |
+
|
| 139 |
+
hydra:
|
| 140 |
+
run:
|
| 141 |
+
dir: ./${now:%Y.%m.%d}/
|
| 142 |
+
job:
|
| 143 |
+
chdir: True
|
| 144 |
+
|
| 145 |
+
checkpointing:
|
| 146 |
+
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
|
| 147 |
+
save_dir: ${cwd:}
|
| 148 |
+
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
|
| 149 |
+
resume_from_ckpt: True
|
| 150 |
+
resume_ckpt_path:
|
| 151 |
+
|
| 152 |
+
callbacks:
|
| 153 |
+
model_checkpoint:
|
| 154 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
| 155 |
+
every_n_epochs: 1
|
| 156 |
+
monitor: "val/nll"
|
| 157 |
+
save_top_k: 10
|
| 158 |
+
mode: "min"
|
| 159 |
+
dirpath:
|
diffusion.py
ADDED
|
@@ -0,0 +1,1588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import sys
|
| 3 |
+
import itertools
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import math
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
import random as rd
|
| 11 |
+
import lightning as L
|
| 12 |
+
import torchmetrics
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import gc
|
| 15 |
+
import utils.utils as utils
|
| 16 |
+
|
| 17 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 18 |
+
import noise_schedule
|
| 19 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 20 |
+
import roformer as roformer
|
| 21 |
+
from utils.app import PeptideAnalyzer
|
| 22 |
+
import pandas as pd
|
| 23 |
+
|
| 24 |
+
# Placeholder: replace 'To Be Added' with the real project/data root before use.
base_path = 'To Be Added'
|
| 25 |
+
|
| 26 |
+
def _sample_categorical(categorical_probs):
|
| 27 |
+
gumbel_norm = (
|
| 28 |
+
1e-10
|
| 29 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 30 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
|
| 31 |
+
|
| 32 |
+
def _sample_categorical_gradient(categorical_probs, temp = 1.0):
|
| 33 |
+
gumbel_norm = (
|
| 34 |
+
1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 35 |
+
output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
|
| 36 |
+
return output
|
| 37 |
+
|
| 38 |
+
def _unsqueeze(x, reference):
|
| 39 |
+
return x.view(
|
| 40 |
+
* x.shape,
|
| 41 |
+
* ((1,) * (len(reference.shape) - len(x.shape))))
|
| 42 |
+
|
| 43 |
+
def sample_batched_categorical(categorical_probs, batch_size):
    """
    Sample ``batch_size`` sequences from per-token categorical distributions
    using the Gumbel-max trick (adds Gumbel noise to log-probabilities and
    takes the argmax, which yields exact categorical samples).

    Args:
        categorical_probs (torch.Tensor): probabilities of shape
            (1, sequence_length, vocab_size). Any leading batch dimension
            other than 1 must be broadcastable against ``batch_size``.
        batch_size (int): number of sequences to sample.

    Returns:
        torch.Tensor: long tensor of shape (batch_size, sequence_length)
        where each row is a sampled sequence of category indices.
    """
    _, sequence_length, vocab_size = categorical_probs.shape

    # Gumbel(0, 1) noise, generated directly on the probs' device (the
    # original built it on CPU and copied it over); 1e-10 guards log(0).
    gumbel_noise = -torch.log(
        -torch.log(
            torch.rand(
                batch_size,
                sequence_length,
                vocab_size,
                device=categorical_probs.device,
            )
            + 1e-10
        )
        + 1e-10
    )
    # Add Gumbel noise to log probabilities.
    noisy_scores = torch.log(categorical_probs) + gumbel_noise

    # Select the highest score (most likely category after Gumbel noise).
    sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long)

    return sampled_sequences
|
| 67 |
+
|
| 68 |
+
def sample_batched_top_k(categorical_probs, batch_size, k):
    """
    Sample `batch_size` sequences from the top-k Gumbel-perturbed scores of
    each position, reducing bias towards only the single most likely token.

    Args:
        categorical_probs (torch.Tensor): (1, sequence_length, vocab_length)
            categorical probabilities, broadcast across the sampled batch.
        batch_size (int): number of sequences to sample.
        k (int): number of top candidates to sample among per position.

    Returns:
        torch.Tensor: (batch_size, sequence_length) long tensor where each row
        is a sampled sequence of category indices.
    """
    _, sequence_length, vocab_length = categorical_probs.shape

    # Gumbel(0, 1) noise on the probabilities' device.
    gumbel_noise = -torch.log(
        -torch.log(torch.rand(batch_size, sequence_length, vocab_length,
                              device=categorical_probs.device) + 1e-10) + 1e-10)
    # BUGFIX: the original added `categorical_probs[None, :, :]` (4-D) to the
    # 3-D noise, yielding a 4-D score tensor whose final `torch.gather` failed
    # with an index-rank mismatch. The (1, L, V) input already broadcasts over
    # the batch dimension without the extra axis.
    noisy_scores = torch.log(categorical_probs) + gumbel_noise  # (B, L, V)

    # Top-k categories after Gumbel perturbation.
    top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1)  # (B, L, k)

    # Renormalize the top-k scores into a sampling distribution.
    top_k_probs = torch.softmax(top_k_scores, dim=-1)  # (B, L, k)

    # Sample one of the k candidates per position.
    sampled_indices_in_top_k = torch.multinomial(
        top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1)
    sampled_indices_in_top_k = sampled_indices_in_top_k.view(
        batch_size, sequence_length)  # (B, L)

    # Map positions within the top-k back to vocabulary indices.
    sampled_sequences = torch.gather(
        top_k_indices, -1,
        sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(dtype=torch.long)

    return sampled_sequences
|
| 103 |
+
|
| 104 |
+
@dataclass
class Loss:
    """Container bundling the diffusion training loss with its components."""
    # scalar training loss
    loss: torch.FloatTensor
    # per-token negative log-likelihoods
    nlls: torch.FloatTensor
    # attention mask weighting which tokens contributed to the loss
    attn_mask: torch.FloatTensor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
    """Running mean of negative log-likelihood (nats per token)."""
    pass
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class BPD(NLL):
    """Bits-per-dimension metric derived from the running NLL."""

    def compute(self) -> Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        # Weighted-mean NLL is in nats; dividing by ln(2) converts to bits.
        mean_nll = self.mean_value / self.weight
        return mean_nll / math.log(2)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Perplexity(NLL):
    """Perplexity metric derived from the running NLL."""

    def compute(self) -> Tensor:
        """Computes the Perplexity.

        Returns:
            Perplexity
        """
        # exp of the weighted-mean NLL (nats).
        avg_nll = self.mean_value / self.weight
        return torch.exp(avg_nll)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Diffusion(L.LightningModule):
|
| 136 |
+
    def __init__(
        self,
        config,
        tokenizer = None,
        mode="finetune",
        device=None,
    ):
        """Masked discrete diffusion model over peptide SMILES tokens.

        Args:
            config: experiment config; this reads `sampling.predictor`, `T`,
                `noise`, `time_conditioning`,
                `eval.gen_ppl_eval_model_name_or_path`, `optim.lr`, and
                `training.sampling_eps`.
            tokenizer: optional pre-built tokenizer; when None, the
                SMILES_SPE_Tokenizer is loaded from files under `base_path`.
            mode: "finetune" (freeze backbone, then unfreeze top 8 layers),
                "eval" (fully frozen, eval mode), or "train" (fully trainable).
            device: forwarded to the Roformer backbone constructor.
        """
        super().__init__()
        self.config = config
        #self.save_hyperparameters()

        # PeptideCLM tokenizer
        if tokenizer is None:
            self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tr2d2-pep/tokenizer/new_vocab.txt',
                                f'{base_path}/tr2d2-pep/tokenizer/new_splits.txt')
        else:
            self.tokenizer = tokenizer

        self.vocab_size = self.tokenizer.vocab_size
        self.mask_index = self.tokenizer.mask_token_id
        self.sampler = self.config.sampling.predictor
        # validity checker used to filter generated peptide sequences
        self.analyzer = PeptideAnalyzer()

        # backbone LM PeptideCLM model
        self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device)
        if mode == "finetune":
            # freeze everything, then re-enable gradients for the top 8 layers
            self.backbone.freeze_model()
            self.backbone.unfreeze_n_layers(n=8)
        elif mode == "eval":
            self.backbone.freeze_model()
            self.backbone.requires_grad_(False)
            self.backbone.eval()
        elif mode == "train":
            self.backbone.requires_grad_(True)
            self.backbone.train()

        # large negative constant standing in for -inf in log-space math
        self.neg_infinity = -1000000.0
        self.T = config.T
        # noise schedule for non-peptide bond tokens (default to log-linear)
        self.noise = noise_schedule.get_noise(config)

        # noise schedule for peptide bonds (log-polynomial)
        self.bond_noise = noise_schedule.LogPolyNoise()
        self.time_conditioning = self.config.time_conditioning
        self.fast_forward_epochs = None
        self.fast_forward_batches = None

        self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
        self.gen_ppl_metric = Perplexity()

        self.lr = self.config.optim.lr
        self.sampling_eps = self.config.training.sampling_eps

        # NLL / bits-per-dim / perplexity tracked separately per split
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='trainer/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')
|
| 199 |
+
|
| 200 |
+
### FOR THE EXPANSION AND ROLLOUT STEP ###
|
| 201 |
+
    def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
        """Roll out the finetuned policy and accumulate the per-trajectory log
        Radon-Nikodym term (log pretrained - log policy), then add scaled
        rewards for the sequences that decode to valid peptides.

        Args:
            args: needs `total_num_steps`, `batch_size`, `seq_length`, `alpha`.
            reward_model: callable mapping `input_seqs` to per-objective scores.
            pretrained: frozen reference model used for the RND correction.
            eps: smallest time value of the reverse-time grid.

        Returns:
            (valid_x_final, log_rnd, scalar_rewards) restricted to the valid
            peptides of the batch.
        """
        num_steps = args.total_num_steps
        B = args.batch_size
        x_rollout = self.sample_prior(
            B, args.seq_length).to(self.device)

        # per-trajectory accumulator of log q(pretrained) - log p(policy)
        log_rnd = torch.zeros(args.batch_size, device=self.device)

        # reverse-time grid from t=1 down to t=eps
        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = (1 - eps) / num_steps

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)

            log_p, x_next, log_policy_step, log_pretrained_step = \
                self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)

            log_rnd += log_pretrained_step - log_policy_step

            x_rollout = x_next

        # if mask token remains, fully unmask
        mask_positions = (x_rollout == self.mask_index) # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item() # true if mask remains
        if any_mask_global:
            # NOTE(review): reuses `t`/`dt` from the last loop iteration;
            # raises NameError when num_steps == 0 — confirm callers always
            # pass total_num_steps >= 1.
            log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)

            x_rollout = x_next

        childSequences = self.tokenizer.batch_decode(x_rollout)

        # change rewards for peptides
        valid_x_final = []
        validSequences = []
        valid_log_rnd = []

        for i in range(B):
            # string sequence
            childSeq = childSequences[i]

            # check if the peptide is valid
            if self.analyzer.is_peptide(childSeq):
                valid_x_final.append(x_rollout[i])
                validSequences.append(childSeq)
                valid_log_rnd.append(log_rnd[i])

        # compute multi-objective rewards
        # NOTE(review): if no sequence decodes to a valid peptide, the
        # `torch.stack` calls below raise on empty lists — verify callers
        # tolerate/avoid that case.
        score_vectors = reward_model(input_seqs=validSequences)
        # scalarize multi-objective scores by summing over the last axis
        scalar_rewards = np.sum(score_vectors, axis=-1)
        scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device)

        print(f"scalar reward dim{len(scalar_rewards)}")
        valid_log_rnd = torch.stack(valid_log_rnd, dim=0)

        log_rnd = valid_log_rnd + (scalar_rewards / args.alpha) # scale down by alpha
        valid_x_final = torch.stack(valid_x_final, dim=0)

        return valid_x_final, log_rnd, scalar_rewards
|
| 261 |
+
|
| 262 |
+
    def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5):
        """Sample a batch from the finetuned model, keep the valid peptides,
        and score them with `reward_model`.

        Args:
            args: needs `total_num_steps`, `batch_size`, `seq_length`.
            reward_model: callable scoring decoded sequences; may return a raw
                score array (base ScoringFunctions) or a (total_rewards, info)
                tuple (TD3B) — both are handled below.
            batch_size: overrides `args.batch_size` when provided.
            dataframe: when True, additionally return a pandas DataFrame.
            eps: smallest time value of the reverse-time grid.

        Returns:
            (x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction)
            with a trailing DataFrame appended when `dataframe=True`.
        """
        torch.cuda.empty_cache()
        self.backbone.eval()
        self.noise.eval()
        print(f"device:{self.device}")

        if batch_size is None:
            batch_size = args.batch_size

        num_steps = args.total_num_steps
        x_rollout = self.sample_prior(
            batch_size,
            args.seq_length).to(self.device, dtype=torch.long)

        # reverse-time grid from t=1 down to t=eps
        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = torch.tensor((1 - eps) / num_steps, device=self.device)

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)

            log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)

            x_rollout = x_next
            x_rollout = x_rollout.to(self.device)

        # if mask token remains, fully unmask
        mask_positions = (x_rollout == self.mask_index) # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item() # true if mask remains
        if any_mask_global:
            # NOTE(review): reuses `t` from the final loop iteration; raises
            # NameError when num_steps == 0.
            log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)

            x_rollout = x_next
            x_rollout = x_rollout.to(self.device)

        childSequences = self.tokenizer.batch_decode(x_rollout)
        valid_x_final = []
        validSequences = []

        # keep only sequences that decode to chemically valid peptides
        for idx, seq in enumerate(childSequences):
            if self.analyzer.is_peptide(seq):
                valid_x_final.append(x_rollout[idx])
                validSequences.append(seq)

        valid_fraction = len(validSequences) / batch_size

        if (len(validSequences) != 0):
            # add scores to log
            result = reward_model(input_seqs=validSequences)

            # Handle both TD3B (returns tuple) and base ScoringFunctions (returns array directly)
            if isinstance(result, tuple):
                # TD3BRewardFunction returns (total_rewards, info) tuple
                # info contains 'score_vectors' which is (N, 2) array [affinities, total_rewards]
                total_rewards, info = result
                affinity = info['affinities']
                # TD3B doesn't compute sol/hemo/nf/permeability, set to zeros
                sol = np.zeros_like(affinity)
                hemo = np.zeros_like(affinity)
                nf = np.zeros_like(affinity)
                permeability = np.zeros_like(affinity)
            else:
                # Base scoring functions return (N, num_objectives) array directly
                score_vectors = np.asarray(result)
                if score_vectors.ndim == 1:
                    score_vectors = score_vectors[:, None]
                # transpose to (num_objectives, N) so each row is one objective
                average_scores = score_vectors.T

                # objective order is assumed to be
                # [affinity, solubility, hemolysis, nonfouling, permeability];
                # missing trailing objectives default to zeros
                affinity = average_scores[0] if average_scores.shape[0] > 0 else np.zeros((0,))
                sol = average_scores[1] if average_scores.shape[0] > 1 else np.zeros_like(affinity)
                hemo = average_scores[2] if average_scores.shape[0] > 2 else np.zeros_like(affinity)
                nf = average_scores[3] if average_scores.shape[0] > 3 else np.zeros_like(affinity)
                permeability = average_scores[4] if average_scores.shape[0] > 4 else np.zeros_like(affinity)

        else:
            zeros = [0.0]

            affinity = zeros
            sol = zeros
            hemo = zeros
            nf = zeros
            permeability = zeros

        if dataframe:
            # NOTE(review): when validSequences is empty, "Peptide Sequence"
            # has length 0 while the score columns have length 1 — pandas will
            # raise; confirm this path is only hit with at least one valid
            # sequence.
            df = pd.DataFrame({
                "Peptide Sequence": validSequences,
                "Binding Affinity": affinity if len(validSequences) else [0.0],
                "Solubility": sol if len(validSequences) else [0.0],
                "Hemolysis": hemo if len(validSequences) else [0.0],
                "Nonfouling": nf if len(validSequences) else [0.0],
                "Permeability": permeability if len(validSequences) else [0.0],
            })
            return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df

        return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction
|
| 358 |
+
|
| 359 |
+
    def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None):
        """Log-probability under the current policy of the tokens unmasked in
        the transition `token_array` -> `x_next`.

        Args:
            token_array: (B, L) or (L,) partially-masked parent sequences.
            x_next: (B, L) or (L,) child sequences after one reverse step.
            t: (B,) or (B, 1) time of the parent state.
            dt: scalar step size.
            attn_mask: optional attention mask; defaults to all ones.

        Returns:
            (B,) summed log-probabilities of the newly-unmasked tokens.
        """
        torch.cuda.empty_cache()
        self.backbone.eval()
        self.noise.eval()

        sigma_t, _ = self.noise(t)

        # promote unbatched inputs to batch size 1
        if token_array.ndim == 1:
            token_array = token_array.unsqueeze(0)

        if x_next.ndim == 1:
            x_next = x_next.unsqueeze(0)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        # unmask probabilities at times t and s = t - dt (log-linear schedule)
        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        if attn_mask is None:
            attn_mask = torch.ones_like(token_array).to(self.device)

        log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
        p_x0 = log_p.exp()

        assert change_prob_t.ndim == p_x0.ndim

        # NOTE(review): q_xs is computed (as in the sampling steps) but never
        # used in this method — only log_p contributes to the result.
        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # zero-masking probability
        q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

        copy_flag = (token_array != self.mask_index)

        assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
        changed_mask = (~copy_flag)

        # log-prob of each chosen child token under the policy
        log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)

        # count only positions that went MASK -> real token this step
        unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
        log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)

        # returns:
        # log_policy_step (B, ) log probability x_next tokens under policy
        if log_policy_step.ndim == 1:
            # NOTE(review): squeeze(0) on a 1-D tensor only removes the dim
            # when B == 1; it is a no-op otherwise.
            log_policy_step = log_policy_step.squeeze(0)

        return log_policy_step
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
    def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None):
        """One ancestral reverse-diffusion step: each masked position either
        stays MASK (with probability ~ s = t - dt) or is unmasked to a token
        drawn from the denoiser distribution.

        Args:
            token_array: (B, L) current token ids.
            t: (B, 1) or (B,) current diffusion time.
            dt: scalar step size.
            p_x0: optional precomputed denoiser probabilities (B, L, V); when
                given, the backbone forward pass is skipped and log_p is None.
            attn_mask: optional (B, L) attention mask; defaults to all ones.

        Returns:
            (log_p, x_next): (B, L, V) policy log-probs (or None when p_x0 was
            supplied) and the (B, L) next token ids.
        """
        torch.cuda.empty_cache()
        dev = self.device
        self.backbone.to(dev).eval()
        self.noise.eval()

        t = t.to(dev)
        dt = torch.as_tensor(dt, device=dev, dtype=t.dtype)
        # the closed-form transition below assumes the log-linear schedule
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)
        sigma_t = sigma_t.to(dev)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        if attn_mask is None:
            attn_mask = torch.ones_like(token_array, device=dev, dtype=torch.long)
        else:
            attn_mask = attn_mask.to(dev)

        if p_x0 is None:
            log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
            p_x0 = log_p.exp()
        else:
            # ensure provided p_x0 is on dev
            log_p = None
            p_x0 = p_x0.to(dev)

        assert change_prob_t.ndim == p_x0.ndim

        # probability of unmasking into each vocab entry over [s, t]
        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # zero-masking probability
        q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

        x_changed = _sample_categorical(q_xs)
        if x_changed.device != dev or x_changed.dtype != token_array.dtype:
            x_changed = x_changed.to(dev, dtype=token_array.dtype)

        # already-decoded positions are carried through unchanged
        copy_flag = (token_array != self.mask_index)

        int_copy_flag = copy_flag.to(token_array.dtype)
        x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed

        # returns:
        # log_p (B, L, D) log probabilties of each token under the policy model
        # x_next (B, L) next sequences
        return log_p, x_next
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None):
|
| 470 |
+
torch.cuda.empty_cache()
|
| 471 |
+
self.backbone.eval()
|
| 472 |
+
self.noise.eval()
|
| 473 |
+
|
| 474 |
+
assert self.config.noise.type == 'loglinear'
|
| 475 |
+
sigma_t, _ = self.noise(t)
|
| 476 |
+
|
| 477 |
+
if t.ndim > 1:
|
| 478 |
+
t = t.squeeze(-1)
|
| 479 |
+
assert t.ndim == 1
|
| 480 |
+
|
| 481 |
+
change_prob_t = t[:, None, None]
|
| 482 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 483 |
+
|
| 484 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 485 |
+
|
| 486 |
+
if attn_mask is None:
|
| 487 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 488 |
+
|
| 489 |
+
if p_x0 is None:
|
| 490 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 491 |
+
p_x0 = log_p.exp()
|
| 492 |
+
|
| 493 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 494 |
+
|
| 495 |
+
# changed for noise removal
|
| 496 |
+
p_x0 = p_x0.clone()
|
| 497 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 498 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 499 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 500 |
+
|
| 501 |
+
x_changed = _sample_categorical(q_xs)
|
| 502 |
+
|
| 503 |
+
copy_flag = (token_array != self.mask_index)
|
| 504 |
+
|
| 505 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 506 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 507 |
+
|
| 508 |
+
# returns:
|
| 509 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 510 |
+
# x_next (B, L) next sequences
|
| 511 |
+
return log_p, x_next
|
| 512 |
+
|
| 513 |
+
def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
|
| 514 |
+
torch.cuda.empty_cache()
|
| 515 |
+
self.backbone.eval()
|
| 516 |
+
self.noise.eval()
|
| 517 |
+
assert self.config.noise.type == 'loglinear'
|
| 518 |
+
sigma_t, _ = self.noise(t)
|
| 519 |
+
|
| 520 |
+
if t.ndim > 1:
|
| 521 |
+
t = t.squeeze(-1)
|
| 522 |
+
assert t.ndim == 1
|
| 523 |
+
|
| 524 |
+
change_prob_t = t[:, None, None]
|
| 525 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 526 |
+
|
| 527 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 528 |
+
|
| 529 |
+
if attn_mask is None:
|
| 530 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 531 |
+
|
| 532 |
+
if p_x0 is None:
|
| 533 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 534 |
+
p_x0 = log_p.exp()
|
| 535 |
+
|
| 536 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 537 |
+
|
| 538 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 539 |
+
|
| 540 |
+
# zero-masking probability
|
| 541 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 542 |
+
|
| 543 |
+
x_changed = _sample_categorical(q_xs)
|
| 544 |
+
|
| 545 |
+
copy_flag = (token_array != self.mask_index)
|
| 546 |
+
|
| 547 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 548 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 549 |
+
|
| 550 |
+
# compute the log-probability under pretrained model at each step
|
| 551 |
+
with torch.no_grad():
|
| 552 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 553 |
+
log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 554 |
+
|
| 555 |
+
# log-prob of the *sampled token* at each position
|
| 556 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 557 |
+
|
| 558 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 559 |
+
|
| 560 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 561 |
+
changed_mask = (~copy_flag)
|
| 562 |
+
# mask of tokens that were unmasked in this step
|
| 563 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 564 |
+
|
| 565 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 566 |
+
|
| 567 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 568 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 569 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 570 |
+
|
| 571 |
+
# returns:
|
| 572 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 573 |
+
# x_next (B, L) next sequences
|
| 574 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 575 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 576 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 577 |
+
|
| 578 |
+
def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
|
| 579 |
+
torch.cuda.empty_cache()
|
| 580 |
+
self.backbone.eval()
|
| 581 |
+
self.noise.eval()
|
| 582 |
+
|
| 583 |
+
assert self.config.noise.type == 'loglinear'
|
| 584 |
+
sigma_t, _ = self.noise(t)
|
| 585 |
+
|
| 586 |
+
if t.ndim > 1:
|
| 587 |
+
t = t.squeeze(-1)
|
| 588 |
+
assert t.ndim == 1
|
| 589 |
+
|
| 590 |
+
change_prob_t = t[:, None, None]
|
| 591 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 592 |
+
|
| 593 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 594 |
+
|
| 595 |
+
if attn_mask is None:
|
| 596 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 597 |
+
|
| 598 |
+
if p_x0 is None:
|
| 599 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 600 |
+
p_x0 = log_p.exp()
|
| 601 |
+
|
| 602 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 603 |
+
|
| 604 |
+
# changed for noise removal
|
| 605 |
+
p_x0 = p_x0.clone()
|
| 606 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 607 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 608 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 609 |
+
|
| 610 |
+
x_changed = _sample_categorical(q_xs)
|
| 611 |
+
|
| 612 |
+
copy_flag = (token_array != self.mask_index)
|
| 613 |
+
|
| 614 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 615 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 616 |
+
|
| 617 |
+
# compute the log-probability under pretrained model at each step
|
| 618 |
+
with torch.no_grad():
|
| 619 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 620 |
+
log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 621 |
+
|
| 622 |
+
# log-prob of the *sampled token* at each position
|
| 623 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 624 |
+
|
| 625 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 626 |
+
|
| 627 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 628 |
+
changed_mask = (~copy_flag)
|
| 629 |
+
# mask of tokens that were unmasked in this step
|
| 630 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 631 |
+
|
| 632 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 633 |
+
|
| 634 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 635 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 636 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 637 |
+
|
| 638 |
+
# returns:
|
| 639 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 640 |
+
# x_next (B, L) next sequences
|
| 641 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 642 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 643 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 644 |
+
|
| 645 |
+
# first step in expansion
|
| 646 |
+
    def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None):
        """Expand one parent state into `batch_size` child sequences in a
        single reverse step (MCTS expansion), scoring each child's newly
        unmasked tokens under both the policy and the pretrained model.

        Args:
            token_array: (L,) or (1, L) parent token ids.
            t: (1,) or (1, 1) current diffusion time.
            dt: scalar step size.
            batch_size: number of distinct children to draw from the parent.
            pretrained: frozen reference model with the same forward signature.
            p_x0: optional precomputed denoiser probabilities.
                NOTE(review): if supplied, `log_p` is never assigned and the
                gather below raises NameError — callers appear to always pass
                None; confirm before relying on this parameter.
            attn_mask: optional attention mask; defaults to all ones.

        Returns:
            (log_p, x_children, log_policy_step, log_pretrained_step).
        """
        torch.cuda.empty_cache()
        self.backbone.eval()
        self.noise.eval()

        # the closed-form transition below assumes the log-linear schedule
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)

        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1

        change_prob_t = t[:, None, None]
        change_prob_s = (t - dt)[:, None, None]

        assert change_prob_t.ndim == 3, change_prob_t.shape

        if token_array.dim() == 1:
            token_array = token_array.unsqueeze(0)

        # expand to match (num_children, L)
        if attn_mask is None:
            attn_mask = torch.ones_like(token_array).to(self.device)

        token_array = token_array.to(self.device)
        sigma_t = sigma_t.to(self.device)

        # ====== INPUT VALIDATION for batch_mcts_reverse_step ======
        token_min = token_array.min().item()
        token_max = token_array.max().item()
        if token_min < 0 or token_max >= self.vocab_size:
            raise ValueError(
                f"batch_mcts_reverse_step: Invalid token IDs in token_array: "
                f"min={token_min}, max={token_max}, vocab_size={self.vocab_size}"
            )

        if p_x0 is None:
            log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
            p_x0 = log_p.exp()

        assert change_prob_t.ndim == p_x0.ndim

        q_xs = p_x0 * (change_prob_t - change_prob_s)

        # zero-masking probability
        q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

        # repeat the parent token along the first dimension which will be unmasked into distinct sequences
        token_array_expanded = token_array.repeat(batch_size, 1)

        # mcts.sampling == 0 -> unrestricted Gumbel sampling; otherwise the
        # value is used as k for top-k sampling
        if self.config.mcts.sampling == 0:
            x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
        else:
            x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)

        copy_flag = (token_array_expanded != self.mask_index)

        int_copy_flag = copy_flag.to(token_array.dtype)
        x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed


        # compute the log-probability under pretrained model at each step
        with torch.no_grad():
            # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
            log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)

            # expand to match the shape of x_children
            log_pre = log_pre.repeat(batch_size, 1, 1)

            # log-prob of the *sampled token* at each position
            log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # [B*batch,L]

            # sum only over the sites actually sampled this step (i.e., where parent was mask)

            assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
            changed_mask = (~copy_flag)
            # mask of tokens that were unmasked in this step
            unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)

            log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)

            # compute the per-child log-probability under the policy model
            log_p = log_p.repeat(batch_size, 1, 1)
            log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1)  # (B, L) probability of each chosen token
            #print(log_policy_token)
            log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)

        # returns:
        # log_p (B, L, D) log probabilties of each token under the policy model
        # x_children (B, L) child sequences
        # log_policy_step (B, ) log probability of all unmasked tokens under policy
        # log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
        return log_p, x_children, log_policy_step, log_pretrained_step
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
    def compute_invalid_loss(self, logits, k=None, temp=None):
        """
        Penalizes logits whose argmax decoding fails the `is_peptide` validity
        check, scaling the penalty by the probability of each sampled token.

        Args:
            logits: Tensor of shape [batch_size, seq_len, vocab_size].
            k: unused (kept for API compatibility; was the Gumbel-Rao sample
                count in the commented-out variant below).
            temp: unused (kept for API compatibility; was the softmax
                temperature in the commented-out variant below).

        Returns:
            Tensor of shape [batch_size, seq_len]: per-token penalties; rows
            that decode to valid peptides are all zeros. (Note: this is a
            per-token tensor, not a reduced scalar.)
        """

        #samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)

        # Convert logits to sequences using the tokenizer (argmax decode)
        batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
        sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)

        # Check validity of each sampled sequence (not differentiable)
        penalties = torch.tensor(
            [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
            dtype=torch.float32,
            device=self.device
        )
        #print(penalties)

        # Compute probabilities for each token (batch_size, seq_length)
        sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)

        # scale penalties by softmax probability of sampled tokens, keeping a
        # gradient path through the softmax even though the validity check
        # itself is non-differentiable
        scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)

        return scaled_penalty.to(self.device)
|
| 776 |
+
|
| 777 |
+
### DIFFUSION LOSS ###
|
| 778 |
+
|
| 779 |
+
def sample_t(self, n, device):
    """Draw ``n`` random diffusion timesteps in [sampling_eps, 1).

    With antithetic (stratified) sampling enabled, each draw is confined to
    its own 1/n-wide bin so the batch covers [0, 1) evenly, reducing the
    variance of the loss estimate.
    """
    # uniform draws in [0, 1)
    u = torch.rand(n, device=device)

    if self.config.training.antithetic_sampling:
        # evenly spaced bin offsets: 0/n, 1/n, ..., (n-1)/n
        strata = torch.arange(n, device=device) / n
        # shrink each draw into a 1/n bin and shift it to its stratum
        u = (u / n + strata) % 1

    # affine map keeps timesteps away from exactly 0
    eps = self.config.training.sampling_eps
    return (1 - eps) * u + eps
|
| 796 |
+
|
| 797 |
+
"""def mask_samples(self, x0, mask_prob):
|
| 798 |
+
|
| 799 |
+
# generate array of values in range [0, 1] uniformly at random
|
| 800 |
+
# will be used to determine which tokens are masked
|
| 801 |
+
mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
|
| 802 |
+
|
| 803 |
+
# select tokens to mask if the random value in mask_indices is less than mask_prob
|
| 804 |
+
# this will mask approximately the fraction of tokens indicated by mask_prob
|
| 805 |
+
zt = torch.where(mask_indices < mask_prob, self.mask_index, x0)
|
| 806 |
+
|
| 807 |
+
return zt"""
|
| 808 |
+
|
| 809 |
+
def q_xt(self, x, mask_prob):
    """Computes the noisy sample xt by randomly masking tokens of x.

    At most 75% of each sequence's non-padding tokens are masked, so every
    corrupted sample keeps some visible context.

    Args:
        x: int torch.Tensor with shape (batch_size,
            diffusion_model_input_length), input token ids.
        mask_prob: float torch.Tensor broadcastable to x.shape, per-position
            probability that a token is masked.

    Returns:
        Tensor shaped like x with selected positions replaced by the
        tokenizer's mask token id.
    """

    # count non-padding tokens per sequence
    # (assumes pad token id == 0 -- TODO confirm against tokenizer)
    actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
    #print(actual_seq_length)

    # cap on how many tokens may be masked per sequence (75% of real length)
    max_mask_length = (actual_seq_length * 0.75).long()

    # Bernoulli draw per position: candidate mask locations
    mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob

    restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)

    # enforce the per-sequence cap
    for i in range(x.shape[0]):
        true_positions = torch.where(mask_indices[i])[0]
        if len(true_positions) > max_mask_length[i]:
            # NOTE(review): keeps the FIRST max_mask_length candidates, which
            # biases masking toward the start of the sequence -- confirm intended
            selected_positions = true_positions[:max_mask_length[i].item()]
            restricted_move_indices[i, selected_positions] = True
        else:
            restricted_move_indices[i] = mask_indices[i]

    # replace selected positions with the mask token
    # NOTE(review): uses tokenizer.mask_token_id here while other methods use
    # self.mask_index -- confirm the two always agree
    xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)

    return xt
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def sample_prior(self, *batch_dims):
    """Return a fully masked int64 tensor with shape ``batch_dims``.

    This is the prior of the absorbing-state diffusion: every position
    starts out as the mask token.
    """
    return torch.full(batch_dims, self.mask_index, dtype=torch.int64)
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
### COMPUTING LOSS ###
|
| 848 |
+
|
| 849 |
+
def compute_diffusion_loss(self, model_output, xt, x0, t):
    """
    Computes the diffusion loss term of the ELBO
    (evaluates how accurately the model predicts the token probabilities at each time step).

    Inputs:
    - model_output: (batch_size, seq_length, vocab_size) log-probabilities for each token at each position
    - xt: corrupted version of original input x0 at timestep t
    - x0: original input sequence, (batch_size, seq_length)
    - t: timestep
    """
    # NOTE(review): t is broadcast against (B, L)-shaped tensors below; confirm
    # its shape upstream (sample_t returns (B,)).

    # compute interval between each timestep
    dt = 1 / self.T

    # compute vectorized alpha scaling terms for the logits at timestep s and t
    alpha_t = 1 - t + torch.zeros_like(x0)
    # s = t - dt
    alpha_s = 1 - (t - dt) + torch.zeros_like(x0)

    # gather vector of log-probabilities for each token in x0
    # log<x_theta, x>
    log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None])  # shape (B, L, 1)
    # gather log-probabilities for assigning a masked token at each position in the sequence at time t
    # log<x_theta, m>
    log_x_theta_at_m = model_output[:, :, self.mask_index]
    # obtain non-log probability of assigning a masked token
    # <xt, m>
    x_theta_at_m = log_x_theta_at_m.exp()

    # first term of diffusion loss
    term_1_coef = dt / t
    term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
    term_1_log_denom = log_x_theta_at_x0

    # second term of diffusion loss
    term_2_coef = 1 - (dt / t)
    term_2_log_numerator = term_1_log_numerator
    term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)

    L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
                   term_2_coef * (term_2_log_numerator - term_2_log_denom))

    # only positions that are masked in xt contribute: multiply by <zt, m> term
    L_vb = L_vb_masked * (xt == self.mask_index)

    # scale by T and return
    return self.T * L_vb
|
| 896 |
+
|
| 897 |
+
def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
    """
    Training forward pass: corrupt x0 at a random timestep and score how well
    the reverse diffusion model x_theta reconstructs it.

    Args:
        x0: (batch, seq_length) clean token ids.
        attn_mask: (batch, seq_length) attention mask.
        bond_mask: (batch, seq_length) optional indicator of peptide-bond
            positions, which receive a state-dependent masking schedule.
        mask: optional explicit keep-mask; when given, positions where
            mask != 1 are replaced by the mask token instead of sampling
            a random corruption via q_xt.

    Returns:
        (batch, seq_length) per-token loss tensor.
    """
    # randomly sample time steps to start the denoising process for each x0 in batch
    t = self.sample_t(x0.shape[0], self.device)

    # if we are training the intermediate transition blocks
    if self.T > 0:
        # scale by total timesteps T and cast to integer
        t = (t * self.T).to(torch.int)
        # scale down by T to get a multiple of 1/T
        t = t / self.T
        # add 1/T to ensure no 0 values
        t += (1 / self.T)

    # get noise and rate of noise at timestep t
    # sigma = -log(1-t); dsigma = 1 / (1-t)
    sigma, dsigma = self.noise(t)
    time_conditioning = sigma[:, None]

    # Get masking probabilities for all tokens for each batch
    # log-linear: 1 - alpha = t
    base_mask_prob = 1 - torch.exp(-sigma[:, None])  # (batch_size, 1)

    if self.config.noise.state_dependent and (bond_mask is not None):
        # log-polynomial masking schedule: alpha = 1 - t^w
        # bond_sigma = -log(1-t^w) for w = 3 (default)
        # bond_dsigma = -wt^(w-1) / (1-t^w)
        bond_sigma, bond_dsigma = self.bond_noise(t)  # scalar
        # expand dimensions for broadcasting to (B, L)
        bond_sigma = bond_sigma[:, None]
        bond_dsigma = bond_dsigma[:, None]
        sigma = sigma[:, None]
        dsigma = dsigma[:, None]

        # compute masking probability for peptide bonds 1 - bond_alpha = t^w
        bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
        # piece together (B, L) tensor with modified masking prob at peptide-bond locations
        mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
        #print(mask_prob)
        dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
        sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
    else:
        mask_prob = base_mask_prob.to(self.device)

    # get masked samples at different timesteps
    if mask is None:
        zt = self.q_xt(x0, mask_prob).to(self.device)
    else:
        # explicit mask provided: keep positions where mask == 1, mask the rest
        zt = x0.where(mask==1, torch.full_like(x0, self.mask_index)).to(self.device)

    model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)

    # debugging
    assert not torch.isnan(model_output).any()
    assert model_output.is_cuda
    utils.print_nans(model_output, 'model_output')

    # compute invalid loss
    invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device)  # (B, L)
    #print(invalid_loss)

    if self.T > 0:
        # compute diffusion loss
        # NOTE(review): invalid_loss is not added on this path -- confirm intended
        diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
        return diffusion_loss

    # compute loss for the final step that converts from z0 to x0
    # -log(p_theta)
    # get (batch_size, L) array of log-probabilities
    log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device)  # (B, L)

    if self.config.noise.state_dependent and (bond_mask is not None):
        # dsigma/sigma are already (B, L) on this path
        return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
    else:
        # dsigma/sigma are (B,) here, so broadcast over the sequence dimension
        return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
|
| 976 |
+
|
| 977 |
+
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
    """Aggregate the per-token diffusion loss into a mean NLL.

    Returns a ``Loss`` tuple with the mean per-token loss, the masked
    per-token NLLs, and the attention mask used for normalization.
    """
    per_token = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)

    # zero out padding positions before aggregating
    nlls = per_token * attn_mask

    # normalize by the number of real (attended) tokens
    token_count = attn_mask.sum()
    mean_nll = nlls.sum() / token_count

    return Loss(loss=mean_nll.to(self.device),
                nlls=nlls.to(self.device),
                attn_mask=attn_mask.to(self.device))
|
| 992 |
+
|
| 993 |
+
def _compute_loss(self, batch, prefix, bond_mask=None):
    """Compute the loss for one batch and update the metrics for `prefix`.

    Args:
        batch: dict with 'input_ids', 'attention_mask', and optionally
            'mask' and 'bond_mask' tensors.
        prefix: one of 'train', 'val', 'test' -- selects which metric
            collection is updated.
        bond_mask: optional (B, L) bond mask used only when the batch does
            not carry its own 'bond_mask'.

    Returns:
        Scalar mean per-token loss.

    Raises:
        ValueError: if prefix is not one of 'train'/'val'/'test'.
    """
    attn_mask = batch['attention_mask'].to(self.device)

    mask = batch['mask'].to(self.device) if 'mask' in batch else None

    # BUGFIX: the previous version unconditionally overwrote the bond_mask
    # parameter with None when the batch had no 'bond_mask', silently
    # discarding a caller-supplied mask. Prefer the batch value, then fall
    # back to the explicit argument.
    if 'bond_mask' in batch:
        bond_mask = batch['bond_mask'].to(self.device)
    elif bond_mask is not None:
        bond_mask = bond_mask.to(self.device)

    losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
    loss = losses.loss

    if prefix == 'train':
        self.train_metrics.update(
            losses.nlls.to(self.device),
            losses.attn_mask.to(self.device)
        )
        metrics = self.train_metrics
    elif prefix == 'val':
        self.valid_metrics.update(
            losses.nlls.to(self.device),
            losses.attn_mask.to(self.device)
        )
        metrics = self.valid_metrics
    elif prefix == 'test':
        self.test_metrics.update(losses.nlls, losses.attn_mask)
        metrics = self.test_metrics
    else:
        raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)

    return loss
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
### SAMPLING ###
|
| 1037 |
+
|
| 1038 |
+
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
    """Generate sequences by reverse diffusion from fully masked inputs.

    Args:
        num_samples: number of sequences to generate.
        seq_length: sequence length; defaults to config.sampling.seq_length.
        sample_steps: number of reverse-diffusion steps. NOTE: the default is
            128, so the `sample_steps is None` fallback below only triggers
            when a caller passes None explicitly.
        eps: smallest timestep, keeps t away from exactly 0.

    Returns:
        (num_samples, seq_length) tensor of generated token ids.

    NOTE(review): relies on self.single_reverse_step, which only appears
    inside a commented-out string later in this file -- confirm it is
    defined elsewhere before calling this method.
    """
    # get number of timesteps
    if sample_steps is None:
        sample_steps = self.config.sampling.steps

    if seq_length is None:
        seq_length = self.config.sampling.seq_length

    # sample fully masked sequences
    z = self.sample_prior(num_samples, seq_length).to(self.device)

    # create vector of sample_steps timesteps
    timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)

    # compute interval between timesteps
    dt = (1 - eps) / sample_steps

    for i in range(sample_steps):
        t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)

        z = self.single_reverse_step(z, t, dt)

    return z
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
### SAMPLING STEP ###
|
| 1064 |
+
"""
|
| 1065 |
+
def single_reverse_step(self, zt, t, dt, attn_mask=None):
|
| 1066 |
+
# get sigma values that determine masking prob
|
| 1067 |
+
sigma_t, _ = self.noise(t)
|
| 1068 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1069 |
+
|
| 1070 |
+
# reshape sigmas
|
| 1071 |
+
if sigma_t.ndim > 1:
|
| 1072 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1073 |
+
if sigma_s.ndim > 1:
|
| 1074 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1075 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1076 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1077 |
+
|
| 1078 |
+
# compute masking probabilities for each timestep
|
| 1079 |
+
change_prob_t = 1 - torch.exp(-sigma_t)
|
| 1080 |
+
change_prob_s = 1 - torch.exp(-sigma_s)
|
| 1081 |
+
|
| 1082 |
+
# expand dimensions
|
| 1083 |
+
change_prob_t = change_prob_t[:, None, None]
|
| 1084 |
+
change_prob_s = change_prob_s[:, None, None]
|
| 1085 |
+
|
| 1086 |
+
# get prodiction model that outputs token probabilities
|
| 1087 |
+
log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
|
| 1088 |
+
|
| 1089 |
+
# check dimensions match
|
| 1090 |
+
assert change_prob_t.ndim == log_p_x0.ndim
|
| 1091 |
+
|
| 1092 |
+
# compute reverse diffusion probability of being unmasked at timestep s
|
| 1093 |
+
# (sigma_s - sigma_t)*x_theta
|
| 1094 |
+
q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
|
| 1095 |
+
|
| 1096 |
+
# compute reverse diffusion probability of remaining masked at timestep s
|
| 1097 |
+
# (1 - sigma_s)*m
|
| 1098 |
+
q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 1099 |
+
|
| 1100 |
+
# sample sequence at timestep s from categorical distribution of q_zs
|
| 1101 |
+
z_changed = _sample_categorical(q_zs)
|
| 1102 |
+
|
| 1103 |
+
copy_flag = (zt != self.mask_index).to(zt.dtype)
|
| 1104 |
+
return (copy_flag * zt) + ((1 - copy_flag) * z_changed)"""
|
| 1105 |
+
|
| 1106 |
+
def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
    """One reverse-diffusion step with optional caching of the model output.

    Unmasked positions are carried over unchanged; masked positions are
    resampled from the reverse transition. The model forward pass is skipped
    when a cached p_x0 from a previous step is supplied.

    Args:
        x: (B, L) current token ids.
        t: current timestep, (B,) or (B, 1).
        dt: scalar step size.
        p_x0: optional cached (B, L, V) token probabilities from a prior call.
        attn_mask: optional attention mask forwarded to the model.

    Returns:
        Tuple (p_x0, x_next): the (possibly newly computed) probabilities and
        the updated token ids.
    """
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)

    if t.ndim > 1:
        t = t.squeeze(-1)
    assert t.ndim == 1

    # under the log-linear schedule, the mask probability equals t directly
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]

    assert change_prob_t.ndim == 3, change_prob_t.shape

    if p_x0 is None:
        p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()

    assert change_prob_t.ndim == p_x0.ndim

    # unnormalized probability of unmasking into each token at timestep s
    q_xs = p_x0 * (change_prob_t - change_prob_s)

    # probability of remaining masked at timestep s
    q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]

    x_changed = _sample_categorical(q_xs)

    # carry-over: keep already-unmasked tokens unchanged
    copy_flag = (x != self.mask_index).to(x.dtype)

    return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
|
| 1134 |
+
|
| 1135 |
+
# first step in expansion
|
| 1136 |
+
def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
|
| 1137 |
+
"""
|
| 1138 |
+
Generates batch_size different samples from the same starting point for the
|
| 1139 |
+
first expansion step of MCTS
|
| 1140 |
+
"""
|
| 1141 |
+
|
| 1142 |
+
assert self.config.noise.type == 'loglinear'
|
| 1143 |
+
sigma_t, _ = self.noise(t)
|
| 1144 |
+
|
| 1145 |
+
if t.ndim > 1:
|
| 1146 |
+
t = t.squeeze(-1)
|
| 1147 |
+
assert t.ndim == 1
|
| 1148 |
+
|
| 1149 |
+
change_prob_t = t[:, None, None]
|
| 1150 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 1151 |
+
|
| 1152 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 1153 |
+
|
| 1154 |
+
if token_array.dim() == 1:
|
| 1155 |
+
token_array = token_array.unsqueeze(0)
|
| 1156 |
+
#token_array = token_array.repeat(batch_size, 1)
|
| 1157 |
+
|
| 1158 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 1159 |
+
|
| 1160 |
+
if p_x0 is None:
|
| 1161 |
+
p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
|
| 1162 |
+
|
| 1163 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 1164 |
+
|
| 1165 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 1166 |
+
|
| 1167 |
+
# zero-masking probability
|
| 1168 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 1169 |
+
|
| 1170 |
+
# repeat the parent token along the first dimension which will be unmasked into distinct sequences
|
| 1171 |
+
token_array = token_array.repeat(batch_size, 1)
|
| 1172 |
+
|
| 1173 |
+
if self.config.mcts.sampling == 0:
|
| 1174 |
+
x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
|
| 1175 |
+
else:
|
| 1176 |
+
x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
|
| 1177 |
+
|
| 1178 |
+
copy_flag = (token_array != self.mask_index).to(token_array.dtype)
|
| 1179 |
+
|
| 1180 |
+
return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
|
| 1181 |
+
|
| 1182 |
+
def _process_sigma(self, sigma):
|
| 1183 |
+
if sigma.ndim > 1:
|
| 1184 |
+
sigma = sigma.squeeze(-1)
|
| 1185 |
+
if not self.time_conditioning:
|
| 1186 |
+
sigma = torch.zeros_like(sigma)
|
| 1187 |
+
assert sigma.ndim == 1, sigma.shape
|
| 1188 |
+
return sigma
|
| 1189 |
+
|
| 1190 |
+
def forward(self, zt, attn_mask, sigma):
    """
    Predicts the token log-probabilities from zt at time t with noise schedule sigma.

    Args:
        zt: (B, L) partially masked token ids.
        attn_mask: (B, L) attention mask passed to the backbone.
        sigma: noise level; squeezed / zeroed by _process_sigma.

    Returns:
        (B, L, V) log-probabilities after SUBS parameterization.

    Raises:
        ValueError: if zt contains out-of-range token ids or its length
            exceeds the backbone's max_position_embeddings.
    """
    sigma = self._process_sigma(sigma)

    # ====== INPUT VALIDATION (CPU-side) ======
    # Check 1: Token IDs must be in valid range [0, vocab_size - 1]
    zt_min = zt.min().item()
    zt_max = zt.max().item()
    if zt_min < 0 or zt_max >= self.vocab_size:
        raise ValueError(
            f"Invalid token IDs in zt: min={zt_min}, max={zt_max}, "
            f"vocab_size={self.vocab_size}. Token IDs must be in [0, {self.vocab_size-1}]"
        )

    # Check 2: Sequence length must not exceed model's max_position_embeddings
    seq_len = zt.shape[1]
    max_pos = getattr(self.backbone.model.config, 'max_position_embeddings', 512)
    if seq_len > max_pos:
        raise ValueError(
            f"Sequence length {seq_len} exceeds model's max_position_embeddings {max_pos}. "
            f"Input shape: {zt.shape}"
        )

    # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch
    # (use torch.amp.autocast('cuda', ...)); with dtype=torch.float32 this
    # effectively disables mixed precision -- confirm intended.
    with torch.cuda.amp.autocast(dtype=torch.float32):
        logits = self.backbone.forward(input_ids=zt, attn_mask=attn_mask).to(self.device)

    return self.subs_parameterization(logits, zt)
|
| 1219 |
+
|
| 1220 |
+
def subs_parameterization(self, logits, zt):
    """
    Updates reverse diffusion logits based on SUBS parameterization:
    - zero masking probabilities: the mask token gets -infinity logit, so
      reverse diffusion never assigns mass to re-masking a position
    - carry-over unmasking: positions already unmasked in zt keep their
      current token with probability 1

    Args:
        logits: (B, L, V) raw token logits for unmasking masked tokens.
        zt: (B, L) partially unmasked sequence at the current timestep.

    Returns:
        (B, L, V) normalized log-probabilities with SUBS constraints applied.
    """
    # forbid transitions into the mask token
    # [sequence index, current token, next token]
    logits[:, :, self.mask_index] += self.neg_infinity

    # normalize to log-probabilities
    logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)

    # carry-over unmasking: for already-unmasked positions, set everything
    # to -inf first, then put all probability mass on the current token.
    # (CLEANUP: removed dead torch.where/batch_idx/seq_idx/tokens locals that
    # were computed on every call but never used.)
    unmasked_indices = (zt != self.mask_index).to(self.device)
    logits[unmasked_indices] = self.neg_infinity

    # Clip token indices to the valid vocab range to prevent index out of
    # bounds -- can happen with variable-length sequences or corrupted tokens.
    tokens_for_indexing = zt[unmasked_indices]
    valid_token_mask = tokens_for_indexing < logits.shape[-1]

    if not valid_token_mask.all():
        # Log a warning about invalid tokens before clamping
        import logging
        logger = logging.getLogger(__name__)
        invalid_count = (~valid_token_mask).sum().item()
        max_invalid_token = tokens_for_indexing[~valid_token_mask].max().item() if invalid_count > 0 else 0
        logger.warning(f"Found {invalid_count} invalid token indices (max={max_invalid_token}, vocab_size={logits.shape[-1]}). Clipping to valid range.")

        # Clip to valid range
        tokens_for_indexing = torch.clamp(tokens_for_indexing, 0, logits.shape[-1] - 1)

    # set only the specific current-token positions to log-prob 0 (prob 1)
    logits[unmasked_indices, tokens_for_indexing] = 0
    # return logits with SUBS parameterization
    return logits.to(self.device)
|
| 1271 |
+
|
| 1272 |
+
"""SAMPLING"""
|
| 1273 |
+
@torch.no_grad()
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
    """
    Generate samples by iterating the configured reverse-diffusion sampler.

    Args:
        num_steps: number of reverse steps; defaults to config.sampling.steps.
        eps: smallest timestep (keeps t > 0).
        x_input: optional dict with 'input_ids' / 'attention_mask' to resume
            denoising from a partially masked batch instead of pure noise.

    Returns:
        (B, L) tensor of generated token ids.
    """
    batch_size_per_gpu = self.config.eval.perplexity_batch_size

    if num_steps is None:
        num_steps = self.config.sampling.steps

    if x_input is not None:
        x = x_input['input_ids'].to(self.device)
        attn_mask = x_input['attention_mask'].to(self.device)
    else:
        x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
        attn_mask = torch.ones_like(x).to(self.device)


    timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None
    generation_history = []  # used to track which tokens are unmasked (currently never appended to)

    for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
        if self.sampler == 'ddpm':
            # NOTE(review): single_reverse_step only appears commented out in
            # this file -- confirm it is defined elsewhere before using 'ddpm'
            x = self.single_reverse_step(x, t, dt).to(self.device)

        elif self.sampler == 'ddpm_cache':
            p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
            if (not torch.allclose(x_next, x) or self.time_conditioning):
                # Disable caching
                p_x0_cache = None
            x = x_next.to(self.device)
            #print(self.tokenizer.decode(x.squeeze()))
        else:
            x = self._analytic_update(x, t, dt, attn_mask).to(self.device)

    if self.config.sampling.noise_removal:
        # final cleanup pass at the smallest timestep
        t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
        if self.sampler == 'analytic':
            x = self._denoiser_update(x, t).to(self.device)
        else:
            # greedy argmax decode of any remaining masked positions
            time_conditioning = self.noise(t)[0].to(self.device)
            x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
            #print(self.tokenizer.decode(x.squeeze()))
    return x.to(self.device)
|
| 1320 |
+
|
| 1321 |
+
|
| 1322 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Generate samples from the model.

    Switches the backbone and noise schedule to eval mode for sampling,
    then restores training mode before returning.
    """
    self.backbone.eval()
    self.noise.eval()
    generated = self._sample(num_steps=num_steps, eps=eps)
    self.backbone.train()
    self.noise.train()
    return generated
|
| 1330 |
+
|
| 1331 |
+
def get_score(self, zt, sigma, attn_mask=None):
    """Compute the per-token score ratios p_t(y) / p_t(x) for analytic sampling.

    Args:
        zt: (B, L) current token ids.
        sigma: noise level at the current timestep.
        attn_mask: optional attention mask forwarded to the model.

    Returns:
        (B, L, V) tensor of exponentiated scores.
    """
    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)

    # case 1: x = masked
    #   (i) y = unmasked
    #     log score(x, t) = log p_\theta(x)|_y + log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    #   (ii) y = masked
    #     log score(x, t) = 0

    # case 2: x = unmasked
    #   (i) y != masked, y != x
    #     log score(x_i, t) = - inf
    #   (ii) y = x
    #     log score(x_i, t) = 0
    #   (iii) y = masked token
    #     log score(x_i, t) = - log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))

    model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)

    # log k = -log(exp(sigma) - 1)
    log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1

    # case 1: scores for currently-masked positions
    masked_score = model_output + log_k[:, None, None]
    masked_score[:, :, self.mask_index] = 0

    # case 2: scores for currently-unmasked positions
    unmasked_score = self.neg_infinity * torch.ones_like(model_output)
    unmasked_score = torch.scatter(
        unmasked_score, -1,
        zt[..., None],
        torch.zeros_like(unmasked_score[..., :1]))

    unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt))

    # blend the two cases per position
    masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None]

    model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))

    return model_output.exp()
|
| 1373 |
+
|
| 1374 |
+
def _staggered_score(self, score, dsigma):
|
| 1375 |
+
score = score.clone()
|
| 1376 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 1377 |
+
score *= dsigma.exp()[:, None]
|
| 1378 |
+
score[..., self.mask_index] += extra_const
|
| 1379 |
+
return score
|
| 1380 |
+
|
| 1381 |
+
def _analytic_update(self, x, t, step_size, attn_mask=None):
    """One reverse step of the analytic sampler from t to t - step_size.

    Args:
        x: (B, L) current token ids.
        t: current timestep.
        step_size: timestep decrement.
        attn_mask: optional attention mask forwarded to the model.

    Returns:
        (B, L) resampled token ids.
    """
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    # BUGFIX: arguments were previously passed as (x, attn_mask, curr_sigma),
    # which fed the attention mask into get_score's sigma parameter
    # (signature is get_score(zt, sigma, attn_mask=None)).
    score = self.get_score(x, curr_sigma, attn_mask)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs)
|
| 1389 |
+
|
| 1390 |
+
def _denoiser_update(self, x, t):
    """Final denoising step: resample every position, forbidding the mask token."""
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    # a fully denoised sample may never contain the mask token
    probs[..., self.mask_index] = 0
    return _sample_categorical(probs)
|
| 1398 |
+
|
| 1399 |
+
def _transp_transition(self, i, sigma):
    """Transposed transition row for tokens ``i`` at noise level ``sigma``.

    Builds exp(-sigma) * one_hot(i) and, for positions where i is the mask
    token, adds the residual (1 - exp(-sigma)) across the vocab dimension.
    NOTE(review): `unsqueeze` here is a project helper (not torch.unsqueeze);
    confirm its broadcasting contract against the shapes used by callers.
    """
    sigma = unsqueeze(sigma, reference=i[..., None])
    edge = torch.exp(-sigma) * F.one_hot(
        i, num_classes=self.vocab_size)
    edge += torch.where(i == self.mask_index,
                        1 - torch.exp(-sigma).squeeze(-1),
                        0)[..., None]
    return edge
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
"""TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
|
| 1410 |
+
|
| 1411 |
+
def on_train_epoch_start(self):
    """Put the backbone and noise schedule into training mode; free cached GPU memory."""
    self.backbone.train()
    self.noise.train()
    torch.cuda.empty_cache()
|
| 1415 |
+
|
| 1416 |
+
|
| 1417 |
+
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and log loss + token throughput."""
    step_start = time.time()

    # SMILES vocabularies carry a bond mask that modulates the noise schedule
    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='train')

    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    # tokens processed per second for this step
    elapsed = time.time() - step_start
    tokens_per_sec = batch['input_ids'].numel() / elapsed

    self.log(name='trainer/throughput',
             value=tokens_per_sec,
             on_step=True,
             on_epoch=False,
             sync_dist=True)

    return loss
|
| 1444 |
+
|
| 1445 |
+
|
| 1446 |
+
def on_load_checkpoint(self, checkpoint):
    """Restore fast-forward counters from the checkpoint's fit-loop progress."""
    fit_loop = checkpoint['loops']['fit_loop']
    self.fast_forward_epochs = fit_loop['epoch_progress']['current']['completed']
    self.fast_forward_batches = fit_loop['epoch_loop.batch_progress']['current']['completed']
|
| 1449 |
+
|
| 1450 |
+
### VALIDATION ###
|
| 1451 |
+
def on_validation_epoch_start(self):
    """Switch to eval mode, reclaim memory, and sanity-check fresh metrics."""
    self.backbone.eval()
    self.noise.eval()
    gc.collect()
    torch.cuda.empty_cache()
    # validation metrics must have been reset before the epoch starts
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0
|
| 1458 |
+
|
| 1459 |
+
def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch."""
    # SMILES vocabularies carry a bond mask that modulates the noise schedule
    if self.config.vocab in ('old_smiles', 'new_smiles'):
        loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
    else:
        loss = self._compute_loss(batch, prefix='val')

    self.log(name='trainer/val_loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             sync_dist=True)
    return loss
|
| 1472 |
+
|
| 1473 |
+
def on_validation_epoch_end(self):
    """Release Python and cached CUDA memory after each validation epoch."""
    gc.collect()
    torch.cuda.empty_cache()
|
| 1476 |
+
|
| 1477 |
+
### OPTIMIZATION ###
|
| 1478 |
+
|
| 1479 |
+
def optimizer_step(self, *args, **kwargs):
    """Run Lightning's optimizer step, then aggressively reclaim memory."""
    super().optimizer_step(*args, **kwargs)
    gc.collect()
    torch.cuda.empty_cache()
|
| 1484 |
+
|
| 1485 |
+
def configure_optimizers(self):
    """Build the AdamW optimizer and a per-step cosine-warmup LR schedule.

    Returns the ([optimizer], [scheduler_dict]) pair Lightning expects;
    the schedule steps every optimizer step, not every epoch.
    """
    optim_cfg = self.config.optim
    trainable = itertools.chain(self.backbone.parameters(), self.noise.parameters())
    optimizer = torch.optim.AdamW(
        trainable,
        lr=optim_cfg.lr,
        betas=(optim_cfg.beta1, optim_cfg.beta2),
        eps=optim_cfg.eps,
        weight_decay=optim_cfg.weight_decay,
    )

    self.total_steps = self.config.trainer.max_steps
    scheduler = CosineWarmup(
        optimizer,
        warmup_steps=self.config.lr_scheduler.num_warmup_steps,
        total_steps=self.total_steps,
    )

    scheduler_dict = {
        'scheduler': scheduler,
        'interval': 'step',
        'frequency': 1,
        'monitor': 'val/loss',
        'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]
|
| 1508 |
+
|
| 1509 |
+
@torch.no_grad()
def compute_masked_perplexity(self, generated_ids, input_ids):
    """Compute the pseudo-perplexity of generated sequences at masked positions.

    For every generated sequence, the backbone scores the (fixed) masked
    input and the NLL is accumulated only at positions where ``input_ids``
    equals the mask token.

    Args:
        generated_ids: Iterable of token-id sequences (ground-truth fills).
        input_ids: Token ids of the masked template; assumed to be a flat
            sequence (or 1xL) matching each generated sequence — TODO confirm
            against callers.

    Returns:
        float: exp(total NLL / total non-padding tokens), also pushed into
        ``self.gen_ppl_metric``.
    """
    total_nll = 0.0
    total_tokens = 0

    input_ids = torch.tensor(input_ids).to(self.device)
    # Full attention over the template; loop-invariant, so built once.
    attn_mask = torch.ones_like(input_ids)

    for sequence in generated_ids:
        gt_ids = torch.tensor(sequence).to(self.device)

        # Single unconditional forward pass. The original branched on
        # self.config.mode but both branches were identical, and any other
        # mode left `outputs` unbound (NameError) — merged for robustness.
        outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)

        # (seq_length, vocab_size) logits against flattened ground truth.
        logits = outputs.view(-1, outputs.size(-1))
        gt_ids = gt_ids.view(-1)

        # Non-masked positions get target -100, which is cross_entropy's
        # default ignore_index, so only masked tokens contribute NLL.
        targets = gt_ids.where(input_ids.view(-1) == self.mask_index,
                               torch.full_like(gt_ids, -100))
        loss = F.cross_entropy(logits, targets, reduction='sum')

        total_nll += loss.item()
        # Normalize by all non-padding tokens (BOS/EOS included).
        total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item()

    pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
    self.gen_ppl_metric.update(pseudo_perplexity)

    return pseudo_perplexity.item()
|
| 1568 |
+
|
| 1569 |
+
|
| 1570 |
+
def unsqueeze(x, reference):
    """Append trailing singleton dims to ``x`` until it matches ``reference``'s rank."""
    pad = (1,) * (len(reference.shape) - len(x.shape))
    return x.view(*x.shape, *pad)
|
| 1572 |
+
|
| 1573 |
+
class CosineWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay down to ``eta_ratio`` of the base LR."""

    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        # Ratio of the minimum (end-of-schedule) LR to the maximum (base) LR.
        self.eta_ratio = eta_ratio
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        # Linear ramp from 0 to base_lr over the warmup window.
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        # Cosine anneal from base_lr down to eta_ratio * base_lr.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        scale = self.eta_ratio + (1 - self.eta_ratio) * 0.5 * (1 + np.cos(np.pi * progress))
        return [scale * base_lr for base_lr in self.base_lrs]
|
distributed_utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal distributed training utilities."""
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def setup_distributed(rank: int, world_size: int, backend: str = "nccl") -> None:
    """Join the torch.distributed process group; no-op for a single process."""
    if world_size > 1:
        # Single-node defaults; honored only if the caller hasn't set them.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        if torch.cuda.is_available():
            # Pin this process to the GPU matching its rank.
            torch.cuda.set_device(rank)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def cleanup_distributed() -> None:
    """Tear down the process group, if one was ever initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def is_main_process() -> bool:
    """Return True on rank 0, or always when distributed mode is inactive."""
    return (not dist.is_initialized()) or dist.get_rank() == 0
|
env.yml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: td3b
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
- conda-forge
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.10
|
| 8 |
+
- pip
|
| 9 |
+
- pytorch
|
| 10 |
+
- torchvision
|
| 11 |
+
- pytorch-cuda=12.1
|
| 12 |
+
- rdkit
|
| 13 |
+
- numpy
|
| 14 |
+
- pandas
|
| 15 |
+
- scikit-learn
|
| 16 |
+
- jupyterlab
|
| 17 |
+
- matplotlib-base
|
| 18 |
+
- seaborn
|
| 19 |
+
- tqdm
|
| 20 |
+
- pyyaml
|
| 21 |
+
- pip:
|
| 22 |
+
- pytorch-lightning==2.5.5
|
| 23 |
+
- lightning==2.5.5
|
| 24 |
+
- fair-esm==2.0.0
|
| 25 |
+
- transformers==4.56.2
|
| 26 |
+
- SmilesPE==0.0.3
|
| 27 |
+
- scipy==1.13.1
|
| 28 |
+
- wandb==0.22.0
|
| 29 |
+
- hydra-core==1.3.2
|
| 30 |
+
- hydra-submitit-launcher==1.2.0
|
| 31 |
+
- pathos==0.3.4
|
| 32 |
+
- matplotlib==3.10.1
|
| 33 |
+
- pandas==2.2.2
|
| 34 |
+
- seaborn==0.13.2
|
| 35 |
+
- timm==1.0.20
|
| 36 |
+
- xgboost==3.0.5
|
| 37 |
+
- loguru==0.7.3
|
finetune_multi_target.py
ADDED
|
@@ -0,0 +1,1061 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Target TD3B Fine-Tuning Script
|
| 3 |
+
|
| 4 |
+
Trains TD3B on multiple protein targets with random sampling strategy.
|
| 5 |
+
Uses the GPCR directional oracle for direction-aware gating.
|
| 6 |
+
|
| 7 |
+
Architecture: Transition-Directed Discrete Diffusion for Binders (TD3B)
|
| 8 |
+
Training: Random K-target sampling + MCTS-guided trajectory optimization + contrastive learning
|
| 9 |
+
|
| 10 |
+
Key Features:
|
| 11 |
+
- Random K targets sampled per MCTS round
|
| 12 |
+
- Small-batch training to prevent OOM
|
| 13 |
+
- Periodic validation on held-out targets
|
| 14 |
+
- Checkpoint saving with validation metrics
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import argparse
|
| 20 |
+
import logging
|
| 21 |
+
import warnings
|
| 22 |
+
from typing import List, Tuple, Dict, Optional
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import numpy as np
|
| 29 |
+
import pandas as pd
|
| 30 |
+
import wandb
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
|
| 33 |
+
# Add project root to path
|
| 34 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 35 |
+
|
| 36 |
+
from diffusion import Diffusion
|
| 37 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 38 |
+
from utils.app import PeptideAnalyzer
|
| 39 |
+
from scoring.functions.binding import MultiTargetBindingAffinity, TargetSpecificBindingAffinity
|
| 40 |
+
from td3b.data_utils import peptide_seq_to_smiles, smiles_token_length
|
| 41 |
+
|
| 42 |
+
# TD3B imports
|
| 43 |
+
from td3b.td3b_losses import TD3BTotalLoss
|
| 44 |
+
from td3b.td3b_finetune import (
|
| 45 |
+
extract_embeddings_from_mdlm,
|
| 46 |
+
add_td3b_sampling_to_model
|
| 47 |
+
)
|
| 48 |
+
from td3b.direction_oracle import DirectionalOracle
|
| 49 |
+
|
| 50 |
+
# Import shared configuration classes
|
| 51 |
+
from configs.finetune_config import (
|
| 52 |
+
RoFormerConfig,
|
| 53 |
+
NoiseConfig,
|
| 54 |
+
TrainingConfig,
|
| 55 |
+
SamplingConfig,
|
| 56 |
+
EvalConfig,
|
| 57 |
+
OptimConfig,
|
| 58 |
+
MCTSConfig,
|
| 59 |
+
DiffusionConfig
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Import shared utilities
|
| 63 |
+
from finetune_utils import (
|
| 64 |
+
load_tokenizer,
|
| 65 |
+
initialize_device,
|
| 66 |
+
create_output_directory,
|
| 67 |
+
save_model,
|
| 68 |
+
setup_wandb,
|
| 69 |
+
cleanup_wandb,
|
| 70 |
+
create_mcts_instance,
|
| 71 |
+
create_reward_function,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Configure logging
|
| 75 |
+
logging.basicConfig(
|
| 76 |
+
level=logging.INFO,
|
| 77 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 78 |
+
)
|
| 79 |
+
logger = logging.getLogger(__name__)
|
| 80 |
+
|
| 81 |
+
# Suppress warnings
|
| 82 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 83 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 84 |
+
|
| 85 |
+
# Constants
|
| 86 |
+
SEPARATOR_LINE = "=" * 80
|
| 87 |
+
eps = 1e-5
|
| 88 |
+
|
| 89 |
+
class TargetDataset:
    """Dataset handler for multi-target training.

    Groups a (target, ligand, label) CSV by protein target and precomputes,
    per target, the median binder length for each direction so that
    generation can match the length distribution of the provided data.
    """

    def __init__(self, csv_path: str, tokenizer: Optional[SMILES_SPE_Tokenizer] = None):
        """
        Load target dataset from CSV.

        Args:
            csv_path: Path to CSV file with columns:
                - Target_Sequence: Protein target sequence
                - Ligand_Sequence: Binder sequence (for length reference)
                - label: 'agonist' or 'antagonist'
            tokenizer: Tokenizer used to compute SMILES token length; when
                None, the raw SMILES string length is used instead.
        """
        self.df = pd.read_csv(csv_path)
        logger.info(f"Loaded {len(self.df)} samples from {csv_path}")
        self.tokenizer = tokenizer

        # Group by target: one entry per unique protein sequence.
        self.targets = {}
        for target_seq in self.df['Target_Sequence'].unique():
            target_df = self.df[self.df['Target_Sequence'] == target_seq]

            # Get binder lengths for each direction
            agonist_binders = target_df[target_df['label'] == 'agonist']['Ligand_Sequence'].tolist()
            antagonist_binders = target_df[target_df['label'] == 'antagonist']['Ligand_Sequence'].tolist()

            # Store actual sequence lengths (in SMILES tokens when a
            # tokenizer is available, otherwise in characters).
            agonist_lengths = [self._binder_length(seq) for seq in agonist_binders] if agonist_binders else []
            antagonist_lengths = [self._binder_length(seq) for seq in antagonist_binders] if antagonist_binders else []

            # Use most common length for each direction, or average if tied
            # This ensures we generate sequences similar to the provided data
            if agonist_lengths:
                agonist_target_length = int(np.median(agonist_lengths))
            else:
                # Default to antagonist length if no agonist, or 50 if neither
                agonist_target_length = int(np.median(antagonist_lengths)) if antagonist_lengths else 50

            if antagonist_lengths:
                antagonist_target_length = int(np.median(antagonist_lengths))
            else:
                # Default to agonist length if no antagonist, or 50 if neither
                antagonist_target_length = int(np.median(agonist_lengths)) if agonist_lengths else 50

            self.targets[target_seq] = {
                'sequence': target_seq,
                'agonist_length': agonist_target_length,  # Target length for agonist generation
                'antagonist_length': antagonist_target_length,  # Target length for antagonist generation
                'agonist_count': len(agonist_binders),
                'antagonist_count': len(antagonist_binders)
            }

        logger.info(f"Found {len(self.targets)} unique targets")

    def _binder_length(self, binder_seq: str) -> int:
        """Length of a binder in SMILES tokens (chars if no tokenizer is set)."""
        smiles = peptide_seq_to_smiles(binder_seq)
        if self.tokenizer is None:
            return len(smiles)
        return smiles_token_length(smiles, self.tokenizer)

    def sample_targets(self, k: int, random_state: Optional[int] = None) -> List[str]:
        """
        Randomly sample K targets.

        Args:
            k: Number of targets to sample (clamped to the number available)
            random_state: Random seed for reproducibility

        Returns:
            List of target sequences

        NOTE(review): seeding here mutates NumPy's *global* RNG state, which
        affects all other np.random users in the process — confirm that is
        intended before relying on reproducibility elsewhere.
        """
        if random_state is not None:
            np.random.seed(random_state)

        target_seqs = list(self.targets.keys())
        k = min(k, len(target_seqs))
        return np.random.choice(target_seqs, size=k, replace=False).tolist()

    def get_target_info(self, target_seq: str) -> Dict:
        """Get information for a specific target."""
        return self.targets[target_seq]

    def get_sequence_length(self, target_seq: str, direction: str) -> int:
        """
        Get the target sequence length for generation.

        Args:
            target_seq: Target protein sequence
            direction: 'agonist' or 'antagonist' (callers may also pass the
                numeric code 1.0 or the string '+1' for agonist)

        Returns:
            Target binder sequence length
        """
        target_info = self.targets[target_seq]
        # Anything that is not an agonist marker falls through to antagonist.
        if direction == 'agonist' or direction == 1.0 or direction == '+1':
            return target_info['agonist_length']
        else:  # antagonist
            return target_info['antagonist_length']

    def get_all_targets(self) -> List[str]:
        """Get all target sequences."""
        return list(self.targets.keys())
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def run_validation(
    policy_model: Diffusion,
    multi_target_affinity: MultiTargetBindingAffinity,
    directional_oracle: DirectionalOracle,
    tokenizer: SMILES_SPE_Tokenizer,
    val_dataset: TargetDataset,
    args: argparse.Namespace,
    epoch: int,
    device: torch.device,
    protein_token_cache: Optional[Dict[str, torch.Tensor]] = None
) -> Dict:
    """
    Run validation on all targets in the validation dataset.

    For every target, sequences are sampled for both directions
    (agonist, d*=+1 and antagonist, d*=-1), scored for binding affinity
    and direction, and aggregated into summary statistics. A per-sequence
    CSV is written to ``args.save_path``.

    Args:
        policy_model: Trained diffusion model
        multi_target_affinity: Multi-target binding affinity predictor
        directional_oracle: Directional oracle
        tokenizer: Tokenizer
        val_dataset: Validation dataset
        args: Training arguments (seq_length is mutated temporarily per
            direction and restored after sampling)
        epoch: Current epoch
        device: Device
        protein_token_cache: Optional cache of encoded protein tokens,
            filled in as new targets are encountered

    Returns:
        Dictionary with validation metrics
    """
    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info(f"Running validation at epoch {epoch}")
    logger.info(f"{SEPARATOR_LINE}")

    policy_model.eval()

    all_sequences = []
    all_affinities = []
    all_gated_rewards = []
    all_directions = []
    all_target_directions = []  # d* for each sequence
    all_valid_fractions = []
    all_valid_fractions_per_sample = []
    all_target_names = []

    val_targets = val_dataset.get_all_targets()

    if protein_token_cache is None:
        protein_token_cache = {}

    with torch.no_grad():
        for target_seq in tqdm(val_targets, desc="Validating targets"):
            # NOTE(review): target_info is fetched but never used below.
            target_info = val_dataset.get_target_info(target_seq)
            target_protein_tokens = protein_token_cache.get(target_seq)
            if target_protein_tokens is None:
                target_protein_tokens = directional_oracle.encode_protein(target_seq)
                protein_token_cache[target_seq] = target_protein_tokens

            # Generate for both agonist and antagonist
            for direction_name, d_star in [('agonist', 1.0), ('antagonist', -1.0)]:
                # Get the target sequence length for this direction
                target_length = val_dataset.get_sequence_length(target_seq, direction_name)

                # Temporarily set args.seq_length for this generation.
                # NOTE(review): not wrapped in try/finally — if sampling
                # raises, args.seq_length stays clobbered.
                original_seq_length = args.seq_length
                args.seq_length = target_length

                # Create target-specific affinity predictor for this target
                target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)

                # Create reward model for this target+direction
                reward_model = create_reward_function(
                    affinity_predictor=target_affinity,
                    directional_oracle=directional_oracle,
                    target_direction=d_star,
                    target_protein_tokens=target_protein_tokens,
                    tokenizer=tokenizer,
                    device=device,
                    min_affinity_threshold=args.min_affinity_threshold,
                    use_confidence_weighting=True,
                    temperature=args.sigmoid_temperature
                )

                # Sample sequences with the correct length
                x_eval, eval_metrics = policy_model.sample_finetuned_td3b(
                    args,
                    reward_model,
                    batch_size=args.val_samples_per_target,
                    dataframe=False
                )

                # Restore original seq_length
                args.seq_length = original_seq_length

                # Decode sequences
                sequences = tokenizer.batch_decode(x_eval)

                # Get metrics
                affinities = eval_metrics.get('affinity', [])
                gated_rewards = eval_metrics.get('gated_reward', [])
                directions = eval_metrics.get('direction_predictions', [])
                valid_fraction = eval_metrics.get('valid_fraction', 0.0)

                # CRITICAL FIX: Metrics are only computed for valid sequences
                # So we should extend based on the length of metrics arrays, not all sequences
                num_valid = len(affinities)  # Number of valid sequences with metrics

                # Filter to only valid sequences (metrics are only for valid ones)
                # NOTE(review): PeptideAnalyzer is already imported at module
                # top; this in-loop import and per-iteration instantiation are
                # redundant. The [:num_valid] truncation assumes the metric
                # arrays are ordered the same as the valid sequences — confirm
                # against sample_finetuned_td3b.
                from utils.app import PeptideAnalyzer
                analyzer = PeptideAnalyzer()
                valid_sequences = [seq for seq in sequences if analyzer.is_peptide(seq)][:num_valid]

                # Store (all arrays must have the same length = num_valid)
                all_sequences.extend(valid_sequences)  # Only valid sequences
                all_affinities.extend(affinities)
                all_gated_rewards.extend(gated_rewards)
                all_directions.extend(directions)
                all_target_directions.extend([d_star] * num_valid)
                all_valid_fractions.append(valid_fraction)
                all_valid_fractions_per_sample.extend([valid_fraction] * num_valid)
                all_target_names.extend([target_seq[:20]] * num_valid)

    # Compute validation metrics
    all_affinities = np.array(all_affinities)
    all_gated_rewards = np.array(all_gated_rewards)
    all_directions = np.array(all_directions)
    all_target_directions = np.array(all_target_directions)

    # Direction is "correct" when the oracle's score lands on the same side
    # of 0.5 as the requested direction d*.
    if all_directions.size == 0:
        direction_correct = np.array([], dtype=np.float32)
    else:
        direction_correct = np.where(
            all_target_directions > 0,
            all_directions >= 0.5,
            all_directions < 0.5
        ).astype(np.float32)

    # Consistency rewards: d* × (f_φ - 0.5)
    consistency_rewards = all_target_directions * (all_directions - 0.5)  # range from -1 to 1.
    success_rates = direction_correct * np.array(all_valid_fractions_per_sample, dtype=np.float32)

    # Separate by direction
    agonist_mask = all_target_directions == 1.0
    antagonist_mask = all_target_directions == -1.0

    consistency_agonist = consistency_rewards[agonist_mask]
    consistency_antagonist = consistency_rewards[antagonist_mask]

    val_metrics = {
        'affinity_mean': np.mean(all_affinities),
        'affinity_std': np.std(all_affinities),
        'gated_reward_mean': np.mean(all_gated_rewards),
        'gated_reward_std': np.std(all_gated_rewards),
        'direction_oracle_mean': np.mean(all_directions),
        'direction_oracle_std': np.std(all_directions),
        'consistency_reward_mean': np.mean(consistency_rewards),
        'consistency_reward_std': np.std(consistency_rewards),
        'consistency_agonist_mean': np.mean(consistency_agonist) if len(consistency_agonist) > 0 else 0.0,
        'consistency_agonist_std': np.std(consistency_agonist) if len(consistency_agonist) > 0 else 0.0,
        'consistency_antagonist_mean': np.mean(consistency_antagonist) if len(consistency_antagonist) > 0 else 0.0,
        'consistency_antagonist_std': np.std(consistency_antagonist) if len(consistency_antagonist) > 0 else 0.0,
        'valid_fraction_mean': np.mean(all_valid_fractions),
        'valid_fraction_std': np.std(all_valid_fractions),
        'direction_accuracy_mean': np.mean(direction_correct) if direction_correct.size else 0.0,
        'direction_accuracy_std': np.std(direction_correct) if direction_correct.size else 0.0,
        'success_rate_mean': np.mean(success_rates) if success_rates.size else 0.0,
        'success_rate_std': np.std(success_rates) if success_rates.size else 0.0
    }

    # Log validation metrics
    logger.info(f"\nValidation Results (Epoch {epoch}):")
    logger.info(f"  Affinity: {val_metrics['affinity_mean']:.4f} ± {val_metrics['affinity_std']:.4f}")
    logger.info(f"  Gated Reward: {val_metrics['gated_reward_mean']:.4f} ± {val_metrics['gated_reward_std']:.4f}")
    logger.info(f"  Direction Oracle: {val_metrics['direction_oracle_mean']:.4f} ± {val_metrics['direction_oracle_std']:.4f}")
    logger.info(f"  Consistency Reward: {val_metrics['consistency_reward_mean']:.4f} ± {val_metrics['consistency_reward_std']:.4f}")
    logger.info(f"  Consistency (d*=+1): {val_metrics['consistency_agonist_mean']:.4f} ± {val_metrics['consistency_agonist_std']:.4f}")
    logger.info(f"  Consistency (d*=-1): {val_metrics['consistency_antagonist_mean']:.4f} ± {val_metrics['consistency_antagonist_std']:.4f}")
    logger.info(f"  Valid Fraction: {val_metrics['valid_fraction_mean']:.4f} ± {val_metrics['valid_fraction_std']:.4f}")
    logger.info(f"  Direction Accuracy: {val_metrics['direction_accuracy_mean']:.4f} ± {val_metrics['direction_accuracy_std']:.4f}")
    logger.info(f"  Success Rate: {val_metrics['success_rate_mean']:.4f} ± {val_metrics['success_rate_std']:.4f}")

    # Save validation sequences to file
    val_df = pd.DataFrame({
        'target': all_target_names,
        'sequence': all_sequences,
        'target_direction': all_target_directions,
        'affinity': all_affinities,
        'gated_reward': all_gated_rewards,
        'direction_oracle': all_directions,
        'consistency_reward': consistency_rewards,
        'direction_accuracy': direction_correct,
        'success_rate': success_rates
    })

    val_output_path = os.path.join(args.save_path, f'validation_epoch_{epoch}.csv')
    val_df.to_csv(val_output_path, index=False)
    logger.info(f"Validation sequences saved to {val_output_path}")

    policy_model.train()

    return val_metrics
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for multi-target TD3B fine-tuning.

    Builds the full CLI (paths, sampling, training, MCTS, loss, validation,
    architecture, optimization, logging, and directional-oracle options),
    then post-processes the parsed namespace:

    - resolves unset directional-oracle paths relative to
      ``<base_path>/tr2d2-pep``,
    - attaches derived attributes required by the MCTS code
      (``time_conditioning``, ``num_obj``, ``scalarization``),
    - creates a timestamped output directory and stores it in
      ``args.save_path``.

    Returns:
        argparse.Namespace: fully resolved configuration for this run.
    """
    parser = argparse.ArgumentParser(description='Multi-Target TD3B Fine-Tuning')

    # Paths
    path_group = parser.add_argument_group('Paths')
    path_group.add_argument('--base_path', type=str, required=True,
                            help='Base path for TR2-D2 project')
    path_group.add_argument('--train_csv', type=str, required=True,
                            help='Path to training CSV file')
    path_group.add_argument('--val_csv', type=str, default=None,
                            help='Path to validation CSV file (optional)')
    path_group.add_argument('--pretrained_checkpoint', type=str, required=True,
                            help='Path to pretrained diffusion model checkpoint')
    path_group.add_argument('--run_name', type=str, required=True,
                            help='Name for this training run')
    path_group.add_argument('--device', type=str, default='cuda',
                            help='Device to use (cuda or cpu)')

    # Multi-target sampling
    target_group = parser.add_argument_group('Multi-Target Sampling')
    target_group.add_argument('--targets_per_mcts', type=int, default=5,
                              help='Number of targets to sample per MCTS round (K)')
    target_group.add_argument('--resample_targets_every', type=int, default=1,
                              help='Resample targets every N epochs')

    # Training hyperparameters
    train_group = parser.add_argument_group('Training')
    train_group.add_argument('--num_epochs', type=int, default=200,
                             help='Total number of training epochs')
    train_group.add_argument('--learning_rate', type=float, default=3e-4,
                             help='Learning rate for optimizer')
    train_group.add_argument('--train_batch_size', type=int, default=16,
                             help='Batch size for training (small to prevent OOM)')
    train_group.add_argument('--gradient_accumulation_steps', type=int, default=4,
                             help='Accumulate gradients over N steps')
    train_group.add_argument('--resample_every_n_step', type=int, default=10,
                             help='Resample MCTS every N epochs')
    train_group.add_argument('--save_every_n_epochs', type=int, default=20,
                             help='Save checkpoint every N epochs')
    train_group.add_argument('--validate_every_n_epochs', type=int, default=20,
                             help='Run validation every N epochs')
    train_group.add_argument('--num_epoch_for_sampling', type=int, default=5,
                             help='Run evaluation sampling every N epochs (set <=0 to disable)')
    train_group.add_argument('--reset_every_n_step', type=int, default=50,
                             help='Reset MCTS tree every N epochs')

    # MCTS hyperparameters
    mcts_group = parser.add_argument_group('MCTS')
    mcts_group.add_argument('--num_iter', type=int, default=50,
                            help='MCTS iterations per resample (v1 default: 50, reduce for multi-target)')
    mcts_group.add_argument('--num_children', type=int, default=30,
                            help='Children per MCTS expansion')
    mcts_group.add_argument('--buffer_size', type=int, default=50,
                            help='Pareto buffer size (v1 default: 50)')
    mcts_group.add_argument('--replay_buffer_size', type=int, default=0,
                            help='Max replay buffer size across resamples (0 disables replay)')
    mcts_group.add_argument('--replay_buffer_strategy', type=str, default='fifo',
                            choices=['fifo', 'random'],
                            help='Replay buffer eviction strategy when full')
    mcts_group.add_argument('--alpha', type=float, default=0.1,
                            help='Temperature for importance weighting')
    mcts_group.add_argument('--exploration', type=float, default=1.0,
                            help='UCB exploration constant')

    # TD3B loss hyperparameters
    loss_group = parser.add_argument_group('TD3B Loss')
    loss_group.add_argument('--contrastive_weight', type=float, default=0.1,
                            help='Weight for contrastive loss (v1 default: 0.1)')
    loss_group.add_argument('--contrastive_margin', type=float, default=1.0,
                            help='Margin for contrastive loss')
    loss_group.add_argument('--contrastive_type', type=str, default='triplet',
                            choices=['triplet', 'ntxent', 'supcon'],
                            help='Type of contrastive loss')
    loss_group.add_argument('--kl_beta', type=float, default=0.1,
                            help='KL divergence regularization coefficient (v1 default: 0.1)')
    loss_group.add_argument('--min_affinity_threshold', type=float, default=0.0,
                            help='Minimum affinity threshold for allosteric control (CRITICAL)')
    loss_group.add_argument('--sigmoid_temperature', type=float, default=0.1,
                            help='Temperature for sigmoid gating')

    # Validation
    val_group = parser.add_argument_group('Validation')
    val_group.add_argument('--val_samples_per_target', type=int, default=20,
                           help='Number of sequences to generate per target during validation')

    # Architecture
    arch_group = parser.add_argument_group('Architecture')
    arch_group.add_argument('--seq_length', type=int, default=200,
                            help='Maximum sequence length')
    arch_group.add_argument('--embedding_pool_method', type=str, default='cls',
                            choices=['cls', 'mean', 'max'],
                            help='Pooling method for embeddings')
    arch_group.add_argument('--hidden_dim', type=int, default=768,
                            help='Hidden dimension size')
    arch_group.add_argument('--num_layers', type=int, default=8,
                            help='Number of transformer layers (v1 default: 8)')
    arch_group.add_argument('--num_heads', type=int, default=8,
                            help='Number of attention heads (v1 default: 8)')
    arch_group.add_argument('--sampling_eps', type=float, default=1e-3,
                            help='Sampling epsilon (v1 default: 1e-3)')
    arch_group.add_argument('--total_num_steps', type=int, default=128,
                            help='Total number of diffusion steps (v1 default: 128)')

    # Optimization
    opt_group = parser.add_argument_group('Optimization')
    opt_group.add_argument('--grad_clip', action='store_true',
                           help='Enable gradient clipping')
    opt_group.add_argument('--gradnorm_clip', type=float, default=1.0,
                           help='Gradient norm clipping threshold')
    opt_group.add_argument('--wdce_num_replicates', type=int, default=16,
                           help='Number of replicates for WDCE loss (v1 default: 16)')
    opt_group.add_argument('--centering', action='store_true',
                           help='Enable centering in WDCE loss')

    # Logging
    log_group = parser.add_argument_group('Logging')
    log_group.add_argument('--wandb_project', type=str, default='TD3B-multi-target',
                           help='W&B project name')
    log_group.add_argument('--wandb_entity', type=str, default='phos_zj',
                           help='W&B entity name')

    # Directional oracle
    oracle_group = parser.add_argument_group('Directional Oracle')
    oracle_group.add_argument('--direction_oracle_ckpt', type=str, default=None,
                              help='Path to directional oracle checkpoint')
    oracle_group.add_argument('--direction_oracle_tr2d2_checkpoint', type=str, default=None,
                              help='Path to TR2D2 checkpoint used by the oracle')
    oracle_group.add_argument('--direction_oracle_tokenizer_vocab', type=str, default=None,
                              help='Path to SMILES tokenizer vocab for oracle')
    oracle_group.add_argument('--direction_oracle_tokenizer_splits', type=str, default=None,
                              help='Path to SMILES tokenizer splits for oracle')
    oracle_group.add_argument('--direction_oracle_esm_name', type=str,
                              default='facebook/esm2_t33_650M_UR50D',
                              help='ESM model name or local path')
    oracle_group.add_argument('--direction_oracle_esm_cache_dir', type=str, default=None,
                              help='Optional cache directory for ESM model')
    oracle_group.add_argument('--direction_oracle_esm_local_files_only', action='store_true',
                              help='Load ESM from local cache only (no network)')
    oracle_group.add_argument('--direction_oracle_max_ligand_length', type=int, default=768,
                              help='Max SMILES token length for oracle')
    oracle_group.add_argument('--direction_oracle_max_protein_length', type=int, default=1024,
                              help='Max protein token length for oracle')
    oracle_group.add_argument('--direction_oracle_d_model', type=int, default=256,
                              help='Oracle hidden dimension (must match checkpoint)')
    oracle_group.add_argument('--direction_oracle_n_heads', type=int, default=4,
                              help='Oracle attention heads (must match checkpoint)')
    oracle_group.add_argument('--direction_oracle_n_self_attn_layers', type=int, default=1,
                              help='Oracle self-attention layers (must match checkpoint)')
    oracle_group.add_argument('--direction_oracle_n_bmca_layers', type=int, default=2,
                              help='Oracle cross-attention layers (must match checkpoint)')
    oracle_group.add_argument('--direction_oracle_dropout', type=float, default=0.3,
                              help='Oracle dropout (must match checkpoint)')

    args = parser.parse_args()

    # Resolve default oracle paths relative to base_path
    # (any oracle path left unset falls back to the tr2d2-pep layout)
    base_tr2d2_path = os.path.join(args.base_path, 'tr2d2-pep')
    if args.direction_oracle_ckpt is None:
        args.direction_oracle_ckpt = os.path.join(
            base_tr2d2_path, 'best_model_tr2d2_gpcr_fixed.pt'
        )
    if args.direction_oracle_tr2d2_checkpoint is None:
        args.direction_oracle_tr2d2_checkpoint = os.path.join(
            base_tr2d2_path, 'pretrained', 'peptune-pretrained.ckpt'
        )
    if args.direction_oracle_tokenizer_vocab is None:
        args.direction_oracle_tokenizer_vocab = os.path.join(
            base_tr2d2_path, 'tokenizer', 'new_vocab.txt'
        )
    if args.direction_oracle_tokenizer_splits is None:
        args.direction_oracle_tokenizer_splits = os.path.join(
            base_tr2d2_path, 'tokenizer', 'new_splits.txt'
        )

    # Add derived attributes (required by MCTS)
    args.time_conditioning = False
    args.num_obj = 5  # Must match padded score vector size
    args.scalarization = "sum"

    # Create save path
    args.save_path = create_output_directory(
        args.base_path,
        args.run_name,
        add_timestamp=True
    )

    return args
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def main():
    """Entry point for multi-target TD3B fine-tuning.

    Pipeline:
      1. Parse CLI args and set up device, W&B logging, and tokenizer.
      2. Load train/val target datasets and build policy + frozen reference
         diffusion models from a pretrained checkpoint.
      3. Build the multi-target affinity predictor, directional oracle,
         and the TD3B loss (WDCE + contrastive + KL-to-reference).
      4. Alternate MCTS sequence generation (per target, per direction)
         with mini-batch gradient updates from the collected buffer.
      5. Periodically validate, log to W&B, and checkpoint.

    Bug fix vs. the original: the KL-loss noise schedule used an undefined
    name ``eps`` (``-torch.log1p(-(1 - eps) * lamda)``), which raised a
    NameError on the first training batch. It now uses
    ``args.sampling_eps``, consistent with the epsilon passed to
    ``TrainingConfig`` / ``SamplingConfig`` above.
    """
    args = parse_args()

    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Multi-Target TD3B Fine-Tuning")
    logger.info(f"{SEPARATOR_LINE}\n")

    # Set device
    device = initialize_device(args.device)

    # Initialize W&B
    setup_wandb(
        project=args.wandb_project,
        name=args.run_name,
        config=vars(args),
        entity=args.wandb_entity
    )

    # Tokenizer
    tokenizer = load_tokenizer(args.base_path)

    # Load datasets
    logger.info("\n[1/6] Loading datasets...")
    train_dataset = TargetDataset(args.train_csv, tokenizer=tokenizer)
    val_dataset = TargetDataset(args.val_csv, tokenizer=tokenizer) if args.val_csv else None

    # Load models
    logger.info("\n[2/6] Loading models...")

    # Create diffusion config
    config = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=args.hidden_dim,
            n_layers=args.num_layers,
            n_heads=args.num_heads
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=args.sampling_eps),
        sampling=SamplingConfig(
            steps=args.total_num_steps,
            sampling_eps=args.sampling_eps
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=args.learning_rate),
        mcts=MCTSConfig()
    )

    # Policy model (the model being fine-tuned)
    policy_model = Diffusion(
        config=config,
        tokenizer=tokenizer,
        device=device
    ).to(device)

    # Load pretrained checkpoint
    checkpoint = torch.load(args.pretrained_checkpoint, map_location=device, weights_only=False)

    # Handle different checkpoint formats (like v1): try wrapper keys first,
    # otherwise assume the file is already a raw state_dict.
    CHECKPOINT_KEYS = ('state_dict', 'model_state_dict')
    state_dict = None
    for key in CHECKPOINT_KEYS:
        if key in checkpoint:
            state_dict = checkpoint[key]
            logger.info(f"Loading checkpoint from key: {key}")
            break

    if state_dict is None:
        # Assume checkpoint is already a state_dict
        state_dict = checkpoint
        logger.info("Loading checkpoint as direct state_dict")

    policy_model.load_state_dict(state_dict, strict=False)
    logger.info(f"Loaded pretrained checkpoint from {args.pretrained_checkpoint}")

    # Reference model (frozen copy of the pretrained weights, used for the
    # KL regularizer so the policy does not drift too far from the prior)
    reference_model = Diffusion(
        config=config,
        tokenizer=tokenizer,
        device=device
    ).to(device)
    reference_model.load_state_dict(state_dict, strict=False)
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
    logger.info("Created reference model (frozen)")

    # Add TD3B sampling method, fix bugs, sampling sequences with w(t) as condition
    policy_model = add_td3b_sampling_to_model(policy_model)

    # Multi-target affinity predictor
    multi_target_affinity = MultiTargetBindingAffinity(
        tokenizer=tokenizer,
        base_path=args.base_path,
        device=device,
        emb_model=policy_model.backbone  # Use backbone Roformer model (matches v1)
    )
    logger.info("Created multi-target binding affinity predictor")

    # Directional oracle (GPCR classifier) — fail fast on missing files
    for path_label, path in [
        ("direction_oracle_ckpt", args.direction_oracle_ckpt),
        ("direction_oracle_tr2d2_checkpoint", args.direction_oracle_tr2d2_checkpoint),
        ("direction_oracle_tokenizer_vocab", args.direction_oracle_tokenizer_vocab),
        ("direction_oracle_tokenizer_splits", args.direction_oracle_tokenizer_splits),
    ]:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Missing {path_label}: {path}")

    directional_oracle = DirectionalOracle(
        model_ckpt=args.direction_oracle_ckpt,
        tr2d2_checkpoint=args.direction_oracle_tr2d2_checkpoint,
        tokenizer_vocab=args.direction_oracle_tokenizer_vocab,
        tokenizer_splits=args.direction_oracle_tokenizer_splits,
        esm_name=args.direction_oracle_esm_name,
        d_model=args.direction_oracle_d_model,
        n_heads=args.direction_oracle_n_heads,
        n_self_attn_layers=args.direction_oracle_n_self_attn_layers,
        n_bmca_layers=args.direction_oracle_n_bmca_layers,
        dropout=args.direction_oracle_dropout,
        max_ligand_length=args.direction_oracle_max_ligand_length,
        max_protein_length=args.direction_oracle_max_protein_length,
        device=device,
        esm_cache_dir=args.direction_oracle_esm_cache_dir,
        esm_local_files_only=args.direction_oracle_esm_local_files_only
    )
    directional_oracle.eval()

    # Cache protein encodings so each target sequence is embedded only once.
    protein_token_cache: Dict[str, torch.Tensor] = {}

    def get_protein_tokens(target_seq: str) -> torch.Tensor:
        """Return cached oracle protein tokens for *target_seq*, encoding on miss."""
        cached = protein_token_cache.get(target_seq)
        if cached is None:
            cached = directional_oracle.encode_protein(target_seq)
            protein_token_cache[target_seq] = cached
        return cached

    # Loss function
    logger.info("\n[3/6] Creating loss function...")
    td3b_loss_fn = TD3BTotalLoss(
        contrastive_weight=args.contrastive_weight,
        contrastive_margin=args.contrastive_margin,
        kl_beta=args.kl_beta,
        reference_model=reference_model,
        adaptive_margin=True
    )

    # WDCE loss
    from finetune_utils import loss_wdce

    logger.info("\n[4/6] Setting up training...")
    policy_model.train()
    torch.set_grad_enabled(True)
    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)

    # Training logs (per-epoch averages)
    batch_losses = []
    batch_wdce_losses = []
    batch_contrastive_losses = []
    batch_kl_losses = []

    # Multi-target buffer: sequences from all sampled targets/directions.
    # Each entry is a dict with keys x, log_rnd, reward, directional_label, confidence.
    buffer_sequences = []
    current_targets = []

    def trim_replay_buffer(items, max_size, strategy):
        """Evict entries when the replay buffer exceeds *max_size*.

        'fifo' keeps the newest entries; 'random' keeps a uniform subsample.
        A max_size <= 0 disables trimming.
        """
        if max_size <= 0 or len(items) <= max_size:
            return items
        if strategy == "fifo":
            return items[-max_size:]
        indices = np.random.choice(len(items), size=max_size, replace=False)
        return [items[i] for i in indices]

    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Starting Training")
    logger.info(f"{SEPARATOR_LINE}\n")

    # Training loop
    pbar = tqdm(range(args.num_epochs))

    for epoch in pbar:
        # Sample new targets if needed (seeded by epoch for reproducibility)
        if epoch % args.resample_targets_every == 0 or len(current_targets) == 0:
            current_targets = train_dataset.sample_targets(
                k=args.targets_per_mcts,
                random_state=epoch
            )
            logger.info(f"\nEpoch {epoch}: Sampled {len(current_targets)} targets for training")

        # MCTS sampling phase (less frequent) - this is when we regenerate sequences
        if epoch % args.resample_every_n_step == 0:
            if args.replay_buffer_size <= 0:
                # Clear buffer only when regenerating with new MCTS if replay is disabled
                buffer_sequences = []
            else:
                logger.info(
                    f"Epoch {epoch}: Replay buffer enabled, keeping {len(buffer_sequences)} sequences before refresh"
                )
            logger.info(f"Epoch {epoch}: Running MCTS for {len(current_targets)} targets...")
            mcts_valid_total = 0
            mcts_run_count = 0
            mcts_empty_runs = 0

            with torch.no_grad():
                for target_seq in current_targets:
                    target_info = train_dataset.get_target_info(target_seq)

                    # Sample both agonist and antagonist
                    for direction_name, d_star in [('agonist', 1.0), ('antagonist', -1.0)]:
                        # Get the target sequence length for this direction
                        target_length = train_dataset.get_sequence_length(target_seq, direction_name)

                        # Temporarily set args.seq_length for this generation
                        original_seq_length = args.seq_length
                        args.seq_length = target_length

                        # Create target-specific affinity predictor for this target
                        target_affinity = TargetSpecificBindingAffinity(multi_target_affinity, target_seq)

                        # Create reward model for this target
                        reward_model = create_reward_function(
                            affinity_predictor=target_affinity,
                            directional_oracle=directional_oracle,
                            target_direction=d_star,
                            target_protein_tokens=get_protein_tokens(target_seq),
                            tokenizer=tokenizer,
                            device=device,
                            min_affinity_threshold=args.min_affinity_threshold,
                            use_confidence_weighting=True,
                            temperature=args.sigmoid_temperature
                        )

                        # Create MCTS using shared utility
                        mcts = create_mcts_instance(
                            args=args,
                            policy_model=policy_model,
                            reward_function=reward_model,
                            tokenizer=tokenizer,
                            buffer_size=args.buffer_size
                        )

                        # Run MCTS
                        reset_tree = (epoch % args.reset_every_n_step == 0)
                        results = mcts.forward(resetTree=reset_tree)

                        # Restore original seq_length
                        args.seq_length = original_seq_length

                        # Unpack results
                        if len(results) == 7:
                            x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = results

                            # Skip if MCTS returned empty buffer (no valid sequences found)
                            if len(x_final) == 0:
                                logger.warning(f"MCTS returned empty buffer for target={target_seq[:20]}, direction={direction_name}")
                                mcts_run_count += 1
                                mcts_empty_runs += 1
                                continue
                            mcts_run_count += 1
                            mcts_valid_total += len(sequences)

                            # Add to buffer
                            for i in range(len(x_final)):
                                buffer_sequences.append({
                                    'x': x_final[i],
                                    'log_rnd': log_rnd[i],
                                    'reward': final_rewards[i],
                                    'directional_label': d_star,
                                    'confidence': confidences[i] if isinstance(confidences, np.ndarray) else 1.0
                                })

                        # Keep the replay buffer bounded after every MCTS run
                        if args.replay_buffer_size > 0:
                            buffer_sequences = trim_replay_buffer(
                                buffer_sequences,
                                args.replay_buffer_size,
                                args.replay_buffer_strategy
                            )

            logger.info(
                f"Epoch {epoch}: MCTS runs={mcts_run_count}, "
                f"valid_sequences={mcts_valid_total}, empty_runs={mcts_empty_runs}"
            )
            logger.info(f"Epoch {epoch}: Buffer size: {len(buffer_sequences)} sequences")

        # Training phase: sample mini-batches from buffer
        if len(buffer_sequences) == 0:
            logger.warning(f"Epoch {epoch}: Buffer is empty, skipping training")
            continue

        # Shuffle buffer
        np.random.shuffle(buffer_sequences)

        # Mini-batch training
        num_batches = max(1, len(buffer_sequences) // args.train_batch_size)
        epoch_loss = 0.0
        epoch_wdce_loss = 0.0
        epoch_contrastive_loss = 0.0
        epoch_kl_loss = 0.0

        optimizer.zero_grad()

        for batch_idx in range(num_batches):
            start_idx = batch_idx * args.train_batch_size
            end_idx = min(start_idx + args.train_batch_size, len(buffer_sequences))
            batch_data = buffer_sequences[start_idx:end_idx]

            # Pad sequences to the same length (efficient batching for variable-length sequences)
            # Use padding to handle different sequence lengths from different targets
            x_list = [item['x'] for item in batch_data]
            log_rnd_list = [item['log_rnd'] for item in batch_data]  # Scalars, not vectors!

            # Pad x_batch: pad with mask_index (typically 0 or a special token)
            mask_index = policy_model.mask_index if hasattr(policy_model, 'mask_index') else 0
            max_len = max(x.shape[0] for x in x_list)

            # Create padded tensors
            x_batch = torch.full(
                (len(x_list), max_len),
                fill_value=mask_index,
                dtype=x_list[0].dtype,
                device=device
            )

            # Create attention mask: 1 for real tokens, 0 for padding
            # This tells the model which positions are valid vs padded
            attn_mask = torch.zeros(
                (len(x_list), max_len),
                dtype=torch.long,
                device=device
            )

            # Fill in the real sequences and mark valid positions
            for i, x in enumerate(x_list):
                seq_len = x.shape[0]
                x_batch[i, :seq_len] = x.to(device)
                attn_mask[i, :seq_len] = 1  # Mark valid positions

            # log_rnd is a SCALAR per sequence, not a vector - just stack them
            log_rnd_batch = torch.stack([lr.to(device) if isinstance(lr, torch.Tensor) else torch.tensor(lr, device=device) for lr in log_rnd_list])

            directional_labels_batch = torch.tensor(
                [item['directional_label'] for item in batch_data],
                dtype=torch.float32,
                device=device
            )

            # WDCE loss (with attention mask to handle variable-length sequences)
            wdce_loss = loss_wdce(
                policy_model,
                log_rnd_batch,
                x_batch,
                num_replicates=args.wdce_num_replicates,
                centering=args.centering,
                attn_mask=attn_mask  # Pass attention mask to avoid computing loss on padding
            )

            # KL loss: mask a random fraction lamda of tokens and regularize
            # the policy towards the frozen reference on the perturbed batch.
            mask_index = policy_model.mask_index
            lamda = torch.rand(x_batch.shape[0], device=device)
            # FIX: original used undefined name `eps` (NameError); the epsilon
            # here is the sampling epsilon used everywhere else in the config.
            sigma_kl = -torch.log1p(-(1 - args.sampling_eps) * lamda)
            masked_index = torch.rand(*x_batch.shape, device=device) < lamda[..., None]
            perturbed_batch = torch.where(masked_index, mask_index, x_batch)
            # Use the actual attention mask (not all ones) to handle variable-length sequences
            attn_mask_kl = attn_mask.to(device)

            kl_loss = td3b_loss_fn.compute_kl_loss(
                policy_model,
                perturbed_batch,
                attn_mask_kl,
                sigma_kl
            )

            # Contrastive loss (if we have multiple directions)
            if len(torch.unique(directional_labels_batch)) > 1:
                embeddings = extract_embeddings_from_mdlm(
                    policy_model,
                    x_batch,
                    pool_method=args.embedding_pool_method
                )

                # Debug early epochs, or whenever the contrastive term collapses
                debug_mode = (epoch < 3) or (epoch > 0 and batch_contrastive_losses and batch_contrastive_losses[-1] < 1e-6)

                total_loss, loss_dict = td3b_loss_fn.compute_loss(
                    wdce_loss,
                    embeddings,
                    directional_labels_batch,
                    kl_loss=kl_loss,
                    debug=debug_mode
                )
            else:
                # Only WDCE + KL if no contrastive
                total_loss = wdce_loss + args.kl_beta * kl_loss
                loss_dict = {
                    'total_loss': total_loss.item(),
                    'wdce_loss': wdce_loss.item(),
                    'contrastive_loss': 0.0,
                    'kl_loss': kl_loss.item()
                }

            # Scale loss for gradient accumulation
            scaled_loss = total_loss / args.gradient_accumulation_steps
            scaled_loss.backward()

            # Accumulate losses
            epoch_loss += loss_dict['total_loss']
            epoch_wdce_loss += loss_dict['wdce_loss']
            epoch_contrastive_loss += loss_dict['contrastive_loss']
            epoch_kl_loss += loss_dict['kl_loss']

            # Gradient accumulation: step on schedule, and always on the last batch
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0 or (batch_idx + 1) == num_batches:
                if args.grad_clip:
                    torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
                optimizer.step()
                optimizer.zero_grad()

        # Average losses
        epoch_loss /= num_batches
        epoch_wdce_loss /= num_batches
        epoch_contrastive_loss /= num_batches
        epoch_kl_loss /= num_batches

        batch_losses.append(epoch_loss)
        batch_wdce_losses.append(epoch_wdce_loss)
        batch_contrastive_losses.append(epoch_contrastive_loss)
        batch_kl_losses.append(epoch_kl_loss)

        # Validation
        if val_dataset is not None and (epoch + 1) % args.validate_every_n_epochs == 0:
            val_metrics = run_validation(
                policy_model,
                multi_target_affinity,
                directional_oracle,
                tokenizer,
                val_dataset,
                args,
                epoch,
                device,
                protein_token_cache=protein_token_cache
            )

            # Log to W&B
            wandb.log({
                "epoch": epoch,
                "val/affinity_mean": val_metrics['affinity_mean'],
                "val/affinity_std": val_metrics['affinity_std'],
                "val/gated_reward_mean": val_metrics['gated_reward_mean'],
                "val/gated_reward_std": val_metrics['gated_reward_std'],
                "val/direction_oracle_mean": val_metrics['direction_oracle_mean'],
                "val/direction_oracle_std": val_metrics['direction_oracle_std'],
                "val/consistency_reward_mean": val_metrics['consistency_reward_mean'],
                "val/consistency_reward_std": val_metrics['consistency_reward_std'],
                "val/consistency_agonist_mean": val_metrics['consistency_agonist_mean'],
                "val/consistency_antagonist_mean": val_metrics['consistency_antagonist_mean'],
                "val/valid_fraction_mean": val_metrics['valid_fraction_mean'],
                "val/direction_accuracy_mean": val_metrics['direction_accuracy_mean'],
                "val/direction_accuracy_std": val_metrics['direction_accuracy_std'],
                "val/success_rate_mean": val_metrics['success_rate_mean'],
                "val/success_rate_std": val_metrics['success_rate_std']
            })

        # Save checkpoint
        if (epoch + 1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_epoch_{epoch}.ckpt')
            save_model(policy_model, model_path, config=vars(args), epoch=epoch)

    # Final save
    final_model_path = os.path.join(args.save_path, 'model_final.ckpt')
    save_model(policy_model, final_model_path, config=vars(args))

    cleanup_wandb()
    logger.info(f"\n{SEPARATOR_LINE}")
    logger.info("Training completed!")
    logger.info(f"{SEPARATOR_LINE}\n")
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
# Script entry point: run the full multi-target TD3B fine-tuning pipeline
# only when executed directly (not when imported as a module).
if __name__ == '__main__':
    main()
|
finetune_utils.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for TD3B finetuning and sampling."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import wandb
|
| 15 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from diffusion import Diffusion
|
| 19 |
+
from td3b.td3b_mcts import create_td3b_mcts
|
| 20 |
+
from td3b.td3b_scoring import TD3BRewardFunction
|
| 21 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 22 |
+
from utils.utils import sample_categorical_logits
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
# Standard checkpoint keys to try when loading
|
| 27 |
+
CHECKPOINT_KEYS = ("state_dict", "model_state_dict")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_one_hot(x_idx, num_classes=4):
    """Convert integer class indices to a float32 one-hot encoding.

    Args:
        x_idx: Tensor of integer class indices (any shape).
        num_classes: Size of the one-hot dimension appended at the end.

    Returns:
        Float tensor of shape ``x_idx.shape + (num_classes,)``.
    """
    encoded = F.one_hot(x_idx.long(), num_classes=num_classes)
    return encoded.to(torch.float32)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def rnd(model, reward_model, batch_size, scale=1, device="cuda:0"):
    r"""
    Run random order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory
    reward_model: r(X)

    return:
        - x: the final samples, [B, D]
        - log_rnd: the log RND along this trajectory, [B]
    """
    # Unwrap DDP-style wrappers so attributes like `length`/`vocab_size` resolve.
    if hasattr(model, "module"):
        model = model.module

    # Start fully masked; the mask token is assumed to be id `vocab_size - 1`.
    x = torch.full((batch_size, model.length), model.vocab_size - 1).to(device=device, dtype=torch.int64)
    batch_arange = torch.arange(batch_size, device=device)
    # Random per-sample decoding order (any-order autoregressive sampling).
    jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
    # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
    # jump_times: Unif[0,1] in increasing order
    # jump_pos: random permutation of range(D)
    log_rnd = torch.zeros(batch_size, device=device)  # [B]
    for d in range(model.length - 1, -1, -1):
        # jump at time jump_times[:, d] at position jump_pos[:, d]
        logits = model(x)[:, :, :-1]  # [B, D, N-1]; last column (mask token) excluded
        update = sample_categorical_logits(logits[batch_arange, jump_pos[:, d]])  # [B]
        if torch.is_grad_enabled():  # avoid issues with in-place operations
            x = x.clone()
        x[batch_arange, jump_pos[:, d]] = update
        # Accumulate log(uniform reference prob) - log(policy prob) at the
        # decoded position.
        # NOTE(review): this indexes `logits` as if they were log-probabilities;
        # confirm the model output is log-softmaxed, otherwise the ratio is biased.
        log_rnd += -np.log(model.vocab_size - 1) - logits[batch_arange, jump_pos[:, d], update]
    # Terminal reward term scale * r(X) on the fully decoded samples.
    log_rnd += scale * reward_model(x)  # [B]
    return x, log_rnd
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
def sampling(model, batch_size, rounds=1, device="cuda:0"):
    """Any order autoregressive sampling"""
    # Unwrap DDP-style wrappers so attributes resolve on the raw module.
    if hasattr(model, "module"):
        model = model.module
    batch_arange = torch.arange(batch_size, device=device)
    all_samples = []
    for _ in tqdm(range(rounds), leave=False):
        # Start fully masked; the mask token is assumed to be id `vocab_size - 1`.
        x = torch.full((batch_size, model.length), model.vocab_size - 1).to(device=device, dtype=torch.int64)
        # Random per-sample decoding order.
        jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
        # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
        # jump_times: Unif[0,1] in increasing order
        # jump_pos: random permutation of range(D)
        for d in tqdm(range(model.length - 1, -1, -1), leave=False):
            # jump at time jump_times[:, d] at position jump_pos[:, d]
            logits = model.logits(x)[:, :, :-1]  # [B, D, N-1], not log-softmaxed but fine
            update = sample_categorical_logits(logits[batch_arange, jump_pos[:, d]])  # [B]
            x[batch_arange, jump_pos[:, d]] = update
        all_samples.append(x)
    return torch.cat(all_samples)  # (rounds * B, L)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def loss_ce(log_rnd):
    """Cross entropy loss KL(P^*||P^u).

    Each trajectory's log-RND is weighted by its (gradient-detached)
    softmax importance weight before summing.
    """
    importance = torch.softmax(log_rnd.detach(), dim=-1)
    return torch.sum(log_rnd * importance)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def loss_lv(log_rnd):
    r"""Log variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}"""
    # Unbiased sample variance over the batch dimension.
    return torch.var(log_rnd)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def loss_re_rf(log_rnd, const=0):
    r"""Relative entropy loss KL(P^u||P^*) with REINFORCE trick"""
    # The detached advantage keeps the REINFORCE estimator unbiased.
    advantage = const - log_rnd.detach()
    return torch.mean(-log_rnd * advantage)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def loss_wdce(
    policy_model,
    log_rnd,
    x,
    num_replicates=16,
    weight_func=lambda l: 1 / l,
    eps=1e-3,
    centering=False,
    attn_mask=None,
):
    r"""
    Weighted denoising cross entropy loss
    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)

    log_rnd: [B]; x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: w(lambda) for each sample, 1/lambda by default
    eps: epsilon of the sigma(t) time-conditioning schedule
    centering: subtract the mean of the softmax weights (variance reduction)
    attn_mask: [B, L] attention mask (1 for real tokens, 0 for padding) - IMPORTANT for variable-length sequences
    """
    # Unwrap DDP *before* reading attributes: DistributedDataParallel does not
    # proxy custom attributes such as `mask_index`, so the previous read of
    # `policy_model.mask_index` before unwrapping raised AttributeError on a
    # wrapped model.
    if hasattr(policy_model, "module"):
        policy_model = policy_model.module
    mask_index = policy_model.mask_index

    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]

    # Non-mutating detach: the previous in-place `detach_()` silently modified
    # the caller's log_rnd tensor.
    batch_weights = log_rnd.detach().softmax(dim=-1)  # [B]
    if centering:
        batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)

    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    lamda = torch.rand(batch.shape[0], device=batch.device)  # [B*R]
    lamda_weights = weight_func(lamda).clamp(max=1e5)  # [B*R]

    # Independently mask each token with probability lamda.
    masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None]  # [B*R, D]
    perturbed_batch = torch.where(masked_index, mask_index, batch)

    # add time conditioning: sigma_t = -log(1 - (1 - eps) * t) with t = lamda
    t = lamda
    sigma_t = -torch.log1p(-(1 - eps) * t)

    # Use provided attention mask or create default (all ones for fixed-length)
    if attn_mask is not None:
        attn_mask = attn_mask.repeat_interleave(num_replicates, dim=0).to(policy_model.device)
    else:
        attn_mask = torch.ones_like(perturbed_batch).to(policy_model.device)

    # compute logits and gather the log-likelihood of the true token at each
    # masked position
    logits = policy_model(perturbed_batch, attn_mask=attn_mask, sigma=sigma_t)
    losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype)  # [B*R, D]
    losses[masked_index] = torch.gather(
        input=logits[masked_index], dim=-1, index=batch[masked_index][..., None]
    ).squeeze(-1)

    # Apply attention mask to exclude padding tokens from loss computation.
    losses = losses * attn_mask

    return -((losses.sum(dim=-1) * lamda_weights * batch_weights).mean())
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def loss_dce(model, x, weight_func=lambda l: 1 / l):
    r"""
    Denoising cross entropy loss, x [B, D] are ground truth samples
    weight_func: w(lambda) for each sample, 1/lambda by default
    """
    # Per-sample masking fraction and its importance weight.
    mask_frac = torch.rand(x.shape[0], device=x.device)  # [B]
    sample_weights = weight_func(mask_frac).clamp(max=1e5)  # [B]
    # Independently replace each token with the mask id (vocab_size - 1).
    is_masked = torch.rand(*x.shape, device=x.device) < mask_frac[..., None]  # [B, D]
    corrupted = torch.where(is_masked, model.vocab_size - 1, x)
    logits = model(corrupted)
    # Log-likelihood of the ground-truth token, only at masked positions.
    per_token = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype)  # [B, D]
    per_token[is_masked] = torch.gather(
        input=logits[is_masked], dim=-1, index=x[is_masked][..., None]
    ).squeeze(-1)
    return -((per_token.sum(dim=-1) * sample_weights).mean())
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def load_tokenizer(base_path: str) -> SMILES_SPE_Tokenizer:
|
| 182 |
+
"""
|
| 183 |
+
Load the peptide tokenizer from the standard location.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
base_path: Base directory path (e.g., 'To Be Added')
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Loaded SMILES_SPE_Tokenizer instance
|
| 190 |
+
|
| 191 |
+
Example:
|
| 192 |
+
>>> tokenizer = load_tokenizer('To Be Added')
|
| 193 |
+
"""
|
| 194 |
+
base_path = Path(base_path)
|
| 195 |
+
vocab_path = base_path / "tr2d2-pep" / "tokenizer" / "new_vocab.txt"
|
| 196 |
+
spe_path = base_path / "tr2d2-pep" / "tokenizer" / "new_splits.txt"
|
| 197 |
+
|
| 198 |
+
if not vocab_path.exists():
|
| 199 |
+
raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")
|
| 200 |
+
if not spe_path.exists():
|
| 201 |
+
raise FileNotFoundError(f"SPE splits file not found: {spe_path}")
|
| 202 |
+
|
| 203 |
+
tokenizer = SMILES_SPE_Tokenizer(str(vocab_path), str(spe_path))
|
| 204 |
+
logger.info("Loaded tokenizer with vocab_size=%s", tokenizer.vocab_size)
|
| 205 |
+
|
| 206 |
+
return tokenizer
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    device: torch.device,
    strict: bool = True,
) -> Dict[str, Any]:
    """
    Load model weights from checkpoint with automatic key detection.

    Handles different checkpoint formats:
    - {'state_dict': ...}
    - {'model_state_dict': ...}
    - Direct state_dict

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model to load weights into
        device: Device to load checkpoint onto
        strict: Whether to strictly enforce state_dict keys match

    Returns:
        Full checkpoint dictionary (for accessing metadata like epoch, config, etc.)

    Raises:
        FileNotFoundError: If checkpoint file doesn't exist
        RuntimeError: If checkpoint loading fails (original error is chained)

    Example:
        >>> checkpoint = load_checkpoint('model.ckpt', model, device, strict=False)
        >>> if 'epoch' in checkpoint:
        >>>     print(f"Loaded from epoch {checkpoint['epoch']}")
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info("Loading checkpoint from: %s", checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Try to find state_dict in standard checkpoint keys
    state_dict = None
    for key in CHECKPOINT_KEYS:
        if key in checkpoint:
            state_dict = checkpoint[key]
            logger.info("Found state_dict at checkpoint key: '%s'", key)
            break

    # If not found in standard keys, assume checkpoint IS the state_dict
    if state_dict is None:
        state_dict = checkpoint
        logger.info("Loading checkpoint as direct state_dict")

    # Load state dict into model
    try:
        incompatible_keys = model.load_state_dict(state_dict, strict=strict)
        if not strict and (incompatible_keys.missing_keys or incompatible_keys.unexpected_keys):
            logger.warning("Incompatible keys when loading checkpoint:")
            if incompatible_keys.missing_keys:
                logger.warning(" Missing keys: %s...", incompatible_keys.missing_keys[:5])
            if incompatible_keys.unexpected_keys:
                logger.warning(" Unexpected keys: %s...", incompatible_keys.unexpected_keys[:5])
        else:
            logger.info("Checkpoint loaded successfully")
    except Exception as exc:
        # Chain the cause so the original stack trace is preserved for debugging.
        raise RuntimeError(f"Failed to load checkpoint: {exc}") from exc

    return checkpoint
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def initialize_device(device_str: str = "cuda") -> torch.device:
    """
    Resolve the requested compute device, falling back to CPU when CUDA is
    unavailable or the request is invalid.

    Args:
        device_str: Requested device ('cuda', 'cuda:0', 'cpu', or 'auto')

    Returns:
        Torch device object
    """
    cuda_ok = torch.cuda.is_available() and torch.cuda.device_count() > 0

    # 'auto' (or None) picks the best available backend.
    if device_str is None or str(device_str).lower() == "auto":
        device_str = "cuda:0" if cuda_ok else "cpu"

    try:
        device = torch.device(device_str)
    except Exception as exc:
        logger.warning("Invalid device '%s': %s. Falling back to CPU.", device_str, exc)
        return torch.device("cpu")

    # Non-CUDA devices need no further validation.
    if device.type != "cuda":
        logger.info("Using device: %s", device)
        return device

    if not cuda_ok:
        logger.warning("CUDA requested but not available, falling back to CPU")
        return torch.device("cpu")

    # Clamp an out-of-range CUDA index to the first visible GPU.
    index = 0 if device.index is None else device.index
    if not (0 <= index < torch.cuda.device_count()):
        logger.warning(
            "CUDA device %s requested but only %d visible; using cuda:0",
            index,
            torch.cuda.device_count(),
        )
        device = torch.device("cuda:0")

    logger.info("Using device: %s (%s)", device, torch.cuda.get_device_name(device.index or 0))
    return device
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def create_output_directory(base_path: str, run_name: str, add_timestamp: bool = True) -> str:
    """
    Create (if needed) and return the results directory for a training run.

    The directory is placed under ``<base_path>/tr2d2-pep/results/``.

    Args:
        base_path: Base directory.
        run_name: Name for this training run.
        add_timestamp: Whether to append a ``YYYYmmdd_HHMMSS`` suffix.

    Returns:
        Path to the created output directory.
    """
    dir_name = run_name
    if add_timestamp:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        dir_name = f"{run_name}_{stamp}"

    output_dir = os.path.join(base_path, "tr2d2-pep", "results", dir_name)
    os.makedirs(output_dir, exist_ok=True)

    logger.info("Created output directory: %s", output_dir)
    return output_dir
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def save_model(
    model: torch.nn.Module,
    save_path: str,
    config: Optional[Dict[str, Any]] = None,
    epoch: Optional[int] = None,
    optimizer_state: Optional[Dict] = None,
) -> None:
    """
    Save a model checkpoint with optional metadata.

    Args:
        model: Model to save.
        save_path: Path to save the checkpoint to.
        config: Optional configuration dictionary to embed.
        epoch: Optional epoch number.
        optimizer_state: Optional optimizer state dict.
    """
    payload: Dict[str, Any] = {"model_state_dict": model.state_dict()}

    # Only include metadata that was actually supplied (epoch=0 is kept).
    optional_entries = {
        "config": config,
        "epoch": epoch,
        "optimizer_state_dict": optimizer_state,
    }
    for key, value in optional_entries.items():
        if value is not None:
            payload[key] = value

    torch.save(payload, save_path)
    logger.info("Model saved: %s", save_path)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def setup_wandb(project: str, name: str, config: Dict[str, Any], entity: Optional[str] = None) -> None:
    """
    Initialize Weights & Biases logging for this run.

    Args:
        project: W&B project name.
        name: Run name.
        config: Configuration dictionary to log.
        entity: Optional W&B team/entity name (omitted when falsy).
    """
    init_kwargs: Dict[str, Any] = {"project": project, "name": name, "config": config}
    if entity:
        init_kwargs["entity"] = entity

    wandb.init(**init_kwargs)
    logger.info("Initialized W&B: project=%s, run=%s", project, name)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def cleanup_wandb() -> None:
    """Finish W&B logging session."""
    # Flushes pending metrics and closes the active run (no-op if none).
    wandb.finish()
    logger.info("Finished W&B logging")
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def get_mask_index(tokenizer: SMILES_SPE_Tokenizer) -> int:
|
| 416 |
+
"""
|
| 417 |
+
Get mask token index from tokenizer.
|
| 418 |
+
|
| 419 |
+
Args:
|
| 420 |
+
tokenizer: Peptide tokenizer
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
Mask token ID
|
| 424 |
+
|
| 425 |
+
Note:
|
| 426 |
+
Standardizes mask index retrieval across different code paths.
|
| 427 |
+
"""
|
| 428 |
+
if hasattr(tokenizer, "mask_token_id"):
|
| 429 |
+
return tokenizer.mask_token_id
|
| 430 |
+
return tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def create_mcts_instance(
    args,
    policy_model: Diffusion,
    reward_function: TD3BRewardFunction,
    tokenizer: SMILES_SPE_Tokenizer,
    buffer_size: Optional[int] = None,
) -> Any:
    """
    Create TD3B MCTS instance with standardized configuration.

    Args:
        args: Training arguments
        policy_model: Diffusion policy model
        reward_function: TD3B reward function
        tokenizer: Peptide tokenizer
        buffer_size: Optional buffer size (uses args.buffer_size if None)

    Returns:
        TD3B_MCTS instance, or None when MCTS is disabled via --no_mcts.
    """
    if getattr(args, "no_mcts", False):
        logger.info("MCTS disabled (--no_mcts flag)")
        return None

    # Get mask index using standardized method
    mask_index = get_mask_index(tokenizer)

    # Use provided buffer_size or fall back to args
    if buffer_size is None:
        buffer_size = getattr(args, "buffer_size", 50)

    # Resolve alpha once: the previous code used getattr(args, "alpha", 0.1)
    # in the constructor but read `args.alpha` directly in the log call,
    # which raised AttributeError whenever args lacked the field.
    alpha = getattr(args, "alpha", 0.1)

    mcts = create_td3b_mcts(
        args=args,
        diffusion_model=policy_model,
        td3b_reward_function=reward_function,
        alpha=alpha,
        mask_index=mask_index,
        buffer_size=buffer_size,
        tokenizer=tokenizer,
    )

    logger.info("Created TD3B MCTS (buffer_size=%s, alpha=%s)", buffer_size, alpha)
    return mcts
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def create_reward_function(
    affinity_predictor,
    directional_oracle,
    target_direction: float,
    target_protein_tokens: torch.Tensor,
    tokenizer: SMILES_SPE_Tokenizer,
    device: torch.device,
    min_affinity_threshold: float = 0.0,
    use_confidence_weighting: bool = True,
    temperature: float = 0.1,
) -> TD3BRewardFunction:
    """
    Build a TD3B reward function with standardized parameters.

    Args:
        affinity_predictor: Binding affinity prediction model.
        directional_oracle: Directional prediction oracle.
        target_direction: Target direction (1.0 agonist, -1.0 antagonist).
        target_protein_tokens: Protein target tokens.
        tokenizer: Peptide tokenizer.
        device: Compute device.
        min_affinity_threshold: Minimum affinity for allosteric control.
        use_confidence_weighting: Whether to use confidence weighting.
        temperature: Temperature for sigmoid gating.

    Returns:
        TD3BRewardFunction instance.
    """
    reward_kwargs = dict(
        affinity_predictor=affinity_predictor,
        directional_oracle=directional_oracle,
        target_direction=target_direction,
        target_protein_tokens=target_protein_tokens,
        peptide_tokenizer=tokenizer,
        device=device,
        min_affinity_threshold=min_affinity_threshold,
        use_confidence_weighting=use_confidence_weighting,
        temperature=temperature,
    )
    reward_func = TD3BRewardFunction(**reward_kwargs)

    logger.info(
        "Created TD3B reward function (d*=%s, threshold=%s)", target_direction, min_affinity_threshold
    )
    return reward_func
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def log_gpu_memory(stage: str = "") -> None:
    """
    Log current GPU memory usage; no-op when CUDA is unavailable.

    Args:
        stage: Optional stage description for logging context.
    """
    if not torch.cuda.is_available():
        return

    gib = 1024 ** 3
    suffix = f" [{stage}]" if stage else ""
    logger.info(
        "GPU Memory%s: %.2fGB allocated, %.2fGB reserved",
        suffix,
        torch.cuda.memory_allocated() / gib,
        torch.cuda.memory_reserved() / gib,
    )
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Count total and trainable parameters in a model.

    Args:
        model: PyTorch model.

    Returns:
        Tuple of (total_params, trainable_params).
    """
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, grad in sizes if grad)
    return total, trainable
|
inference.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
TD3B Inference Script
|
| 4 |
+
Generate directional binders for target proteins using a finetuned TD3B model.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python inference.py \
|
| 8 |
+
--ckpt_path checkpoints/td3b.ckpt \
|
| 9 |
+
--val_csv data/test.csv \
|
| 10 |
+
--save_path results/ \
|
| 11 |
+
--seed 42
|
| 12 |
+
"""
|
| 13 |
+
import argparse
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Dict, List, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 24 |
+
if ROOT_DIR not in sys.path:
|
| 25 |
+
sys.path.insert(0, ROOT_DIR)
|
| 26 |
+
|
| 27 |
+
from diffusion import Diffusion
|
| 28 |
+
from configs.finetune_config import (
|
| 29 |
+
DiffusionConfig, RoFormerConfig, NoiseConfig,
|
| 30 |
+
TrainingConfig, SamplingConfig, EvalConfig, OptimConfig, MCTSConfig,
|
| 31 |
+
)
|
| 32 |
+
from finetune_utils import load_tokenizer, create_reward_function
|
| 33 |
+
from td3b.direction_oracle import DirectionalOracle
|
| 34 |
+
from td3b.td3b_scoring import create_td3b_reward_function
|
| 35 |
+
from utils.app import PeptideAnalyzer
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 39 |
+
|
| 40 |
+
# ─── Defaults ─────────────────────────────────────────────────────────────────
|
| 41 |
+
DEFAULTS = dict(
|
| 42 |
+
seq_length=200,
|
| 43 |
+
sampling_eps=1e-3,
|
| 44 |
+
total_num_steps=128,
|
| 45 |
+
hidden_dim=768,
|
| 46 |
+
num_layers=8,
|
| 47 |
+
num_heads=8,
|
| 48 |
+
alpha=0.1,
|
| 49 |
+
min_affinity_threshold=0.0,
|
| 50 |
+
sigmoid_temperature=0.1,
|
| 51 |
+
num_pool=32,
|
| 52 |
+
val_samples_per_target=8,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_model(ckpt_path: str, device: torch.device):
    """Load a finetuned TD3B diffusion model from a checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file. Accepts either a wrapped
            checkpoint ({'model_state_dict': ...} / {'state_dict': ...}) or a
            raw state_dict.
        device: Device to place the model on.

    Returns:
        Tuple of (model, tokenizer): the Diffusion model in eval mode and the
        peptide tokenizer attached to it.
    """
    # NOTE: weights_only=False deserializes arbitrary pickled objects - only
    # load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = ckpt.get("model_state_dict") or ckpt.get("state_dict") or ckpt
    config = ckpt.get("config") or {}

    tokenizer = load_tokenizer(ROOT_DIR)

    # Architecture fallbacks come from the module-level DEFAULTS dict so the
    # numbers live in one place instead of being duplicated here.
    cfg = DiffusionConfig(
        roformer=RoFormerConfig(
            hidden_size=config.get("hidden_dim", DEFAULTS["hidden_dim"]),
            n_layers=config.get("num_layers", DEFAULTS["num_layers"]),
            n_heads=config.get("num_heads", DEFAULTS["num_heads"]),
        ),
        noise=NoiseConfig(),
        training=TrainingConfig(sampling_eps=DEFAULTS["sampling_eps"]),
        sampling=SamplingConfig(
            steps=DEFAULTS["total_num_steps"], sampling_eps=DEFAULTS["sampling_eps"]
        ),
        eval_cfg=EvalConfig(),
        optim=OptimConfig(lr=3e-4),
        mcts=MCTSConfig(),
    )

    model = Diffusion(config=cfg, tokenizer=tokenizer, device=device).to(device)
    # strict=False tolerates checkpoints saved with extra or missing buffers.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    model.tokenizer = tokenizer
    return model, tokenizer
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def sample_sequences(model, batch_size: int, seq_length: int, num_steps: int, eps: float = 1e-5):
    """Draw ``batch_size`` token sequences by running the reverse diffusion chain.

    Starts from the model's masked prior and applies ``num_steps`` reverse
    steps on a linear time grid from t=1 down to t=eps; any tokens still
    masked afterwards get one extra denoising pass.
    """
    device = model.device
    x = model.sample_prior(batch_size, seq_length).to(device, dtype=torch.long)
    # num_steps + 1 grid points so timesteps[-2] is the last time actually stepped.
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.tensor((1 - eps) / num_steps, device=device)

    for t_val in timesteps[:-1]:
        t = t_val * torch.ones(x.shape[0], 1, device=device)
        _, x = model.single_reverse_step(x, t=t, dt=dt)
        x = x.to(device)

    # Force-resolve any remaining mask tokens with a final noise-removal step.
    if (x == model.mask_index).any():
        t = timesteps[-2] * torch.ones(x.shape[0], 1, device=device)
        _, x = model.single_noise_removal(x, t=t, dt=dt)
        x = x.to(device)

    return x
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def score_sequences(reward_model, sequences: List[str]):
    """Evaluate ``sequences`` with the TD3B reward function.

    Returns four numpy arrays:
    (gated rewards, affinities, direction scores, confidences).

    When the reward model returns a bare array (no info dict), affinities
    default to the rewards, directions to zeros, and confidences to ones.
    """
    output = reward_model(sequences)
    if not isinstance(output, tuple):
        rewards = np.asarray(output)
        return rewards, rewards, np.zeros_like(rewards), np.ones_like(rewards)

    rewards, info = output
    affinities = np.asarray(info.get("affinities", rewards))
    directions = np.asarray(info.get("directions", np.zeros_like(rewards)))
    confidences = np.asarray(info.get("confidences", np.ones_like(rewards)))
    return np.asarray(rewards), affinities, directions, confidences
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
    """CLI entry point for TD3B inference.

    For each target in the validation CSV and each direction (agonist /
    antagonist): build a reward function, sample a pool of candidate
    sequences, score them, resample by softmax-weighted reward
    (Algorithm 2), and save the valid resampled picks to a CSV.
    """
    parser = argparse.ArgumentParser(description="TD3B Inference")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Path to TD3B checkpoint")
    parser.add_argument("--val_csv", type=str, required=True, help="CSV with Target_Sequence, Ligand_Sequence, label columns")
    parser.add_argument("--save_path", type=str, default="results", help="Output directory")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_pool", type=int, default=32, help="Pool size for candidate generation")
    parser.add_argument("--val_samples_per_target", type=int, default=8, help="Samples to keep per target-direction")
    parser.add_argument("--resample_alpha", type=float, default=0.1, help="Temperature for weighted resampling")
    parser.add_argument("--direction_oracle_ckpt", type=str, default=None)
    parser.add_argument("--direction_oracle_tr2d2_checkpoint", type=str, default=None)
    args = parser.parse_args()

    # Setup: seed everything for reproducibility and make the output dir.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.save_path, exist_ok=True)

    analyzer = PeptideAnalyzer()

    # Load model
    logger.info(f"Loading model from {args.ckpt_path}")
    model, tokenizer = load_model(args.ckpt_path, device)

    # Load targets
    logger.info(f"Loading targets from {args.val_csv}")
    df = pd.read_csv(args.val_csv)
    targets = []
    for _, row in df.iterrows():
        targets.append({
            "target_seq": row["Target_Sequence"],
            "target_uid": row.get("Target_UniProt_ID", ""),
            "binder_seq": row.get("Ligand_Sequence", ""),
            "label": row.get("label", ""),
            # Generate at the reference ligand's SMILES length, capped at 200.
            "seq_length": min(len(row.get("Ligand_SMILES", "x" * 200)), 200),
        })

    # Build reward function for each target
    logger.info("Building reward functions...")
    oracle_ckpt = args.direction_oracle_ckpt or os.path.join(ROOT_DIR, "checkpoints", "direction_oracle.pt")
    oracle_tr2d2 = args.direction_oracle_tr2d2_checkpoint or os.path.join(ROOT_DIR, "checkpoints", "pretrained.ckpt")

    records = []

    for tidx, target in enumerate(targets):
        for d_star, d_name in [(1.0, "agonist"), (-1.0, "antagonist")]:
            logger.info(f"[{tidx+1}/{len(targets)}] Target {target['target_uid']} direction={d_name}")

            # Create reward function.
            # BUGFIX: this called `create_reward_function`, which is not the
            # name imported from td3b.td3b_scoring; the resulting NameError
            # was silently swallowed by the except below, skipping every
            # target. Use the imported `create_td3b_reward_function`.
            try:
                reward_model = create_td3b_reward_function(
                    base_path=ROOT_DIR,
                    tokenizer=tokenizer,
                    target_protein_seq=target["target_seq"],
                    target_direction="agonist" if d_star > 0 else "antagonist",
                    device=device,
                    emb_model=model.backbone,
                    directional_oracle_checkpoint=oracle_ckpt,
                    direction_oracle_tr2d2_checkpoint=oracle_tr2d2,
                )
            except Exception as e:
                logger.warning(f"Failed to create reward for {target['target_uid']}: {e}")
                continue

            # Generate pool of candidates and decode to strings.
            target_length = target.get("seq_length", 200)
            x_pool = sample_sequences(model, args.num_pool, target_length, 128)
            sequences = tokenizer.batch_decode(x_pool)

            # Check validity (is each decoded string a well-formed peptide?).
            valid_mask = np.array([analyzer.is_peptide(seq) for seq in sequences])

            # Score all candidates.
            gated_rewards, affinities, directions, confidences = score_sequences(reward_model, sequences)
            # Direction accuracy: oracle score on the intended side of 0.5.
            direction_accuracy = ((directions > 0.5).astype(float) if d_star > 0
                                  else (directions < 0.5).astype(float))

            # Weighted resampling (Algorithm 2): softmax over finite rewards,
            # then sample with replacement.
            finite = np.isfinite(gated_rewards)
            if finite.any():
                rewards_t = torch.as_tensor(gated_rewards[finite], device=device)
                alpha = max(args.resample_alpha, 1e-6)  # guard against division by zero
                weights = torch.softmax(rewards_t / alpha, dim=0)
                idx = torch.multinomial(weights, num_samples=args.val_samples_per_target, replacement=True)
                valid_idx = np.where(finite)[0]
                chosen = valid_idx[idx.cpu().numpy()]
            else:
                # No finite reward: fall back to the first candidates.
                chosen = np.arange(min(args.val_samples_per_target, len(sequences)))

            # Save only VALID resampled samples
            for i in chosen:
                is_valid = bool(valid_mask[i]) if valid_mask.size else False
                if not is_valid:
                    continue  # Skip invalid samples

                records.append({
                    "target": target["target_seq"][:20],
                    "target_uid": target["target_uid"],
                    "sequence": sequences[i],
                    "target_direction": d_star,
                    "direction_name": d_name,
                    "is_valid": True,
                    "affinity": float(affinities[i]),
                    "gated_reward": float(gated_rewards[i]),
                    "direction_oracle": float(directions[i]),
                    "direction_accuracy": float(direction_accuracy[i]),
                })

    # Save results
    out_df = pd.DataFrame(records)
    out_path = os.path.join(args.save_path, f"td3b_results_seed{args.seed}.csv")
    out_df.to_csv(out_path, index=False)

    # Print summary, split by intended direction.
    if len(out_df) > 0:
        dp = out_df[out_df["target_direction"] == 1.0]
        dm = out_df[out_df["target_direction"] == -1.0]
        logger.info(f"\n{'='*60}")
        logger.info(f"Results saved to {out_path} ({len(out_df)} valid samples)")
        logger.info(f"  Aff(d*=+1) = {dp['affinity'].mean():.2f}" if len(dp) else "  No agonist samples")
        logger.info(f"  Aff(d*=-1) = {dm['affinity'].mean():.2f}" if len(dm) else "  No antagonist samples")
        logger.info(f"  DA(d*=+1) = {dp['direction_accuracy'].mean():.3f}" if len(dp) else "")
        logger.info(f"  DA(d*=-1) = {dm['direction_accuracy'].mean():.3f}" if len(dm) else "")
        logger.info(f"  Gated Reward = {out_df['gated_reward'].mean():.2f}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No valid samples generated.")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if __name__ == "__main__":
|
| 253 |
+
main()
|
launch_multi_target.sh
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

# Multi-Target TD3B Training Launch Script
# Trains TD3B on multiple protein targets with random sampling strategy.
# Edit the Configuration section below, then run this script directly.

# ============================================================================
# Configuration
# ============================================================================

# Paths — update these to your local paths
BASE_PATH="/path/to/TD3B"
PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
TRAIN_CSV="${BASE_PATH}/data/train.csv"
VAL_CSV="${BASE_PATH}/data/test.csv" # Optional: create validation split

# Run configuration
RUN_NAME="multi_target_td3b" # Timestamp will be added automatically
DEVICE="cuda:0"
# Multi-target sampling
TARGETS_PER_MCTS=2 # Number of targets sampled per MCTS round (K)
RESAMPLE_TARGETS_EVERY=1 # Resample targets every N epochs

# Training hyperparameters
NUM_EPOCHS=200
LEARNING_RATE=3e-4
TRAIN_BATCH_SIZE=1 # Small batch size to prevent OOM
GRADIENT_ACCUMULATION_STEPS=32 # Effective batch size = TRAIN_BATCH_SIZE * this = 1 * 32 = 32
RESAMPLE_EVERY=10 # Run MCTS every N epochs
SAVE_EVERY=20
VALIDATE_EVERY=20
RESET_TREE_EVERY=50

# MCTS hyperparameters (aligned with v1, but can reduce for multi-target)
NUM_ITER=20 # MCTS iterations per resample (v1 default: 50, reduced for multi-target)
NUM_CHILDREN=16 # Children per MCTS expansion
BUFFER_SIZE=50 # Pareto buffer size (v1 default: 50)
REPLAY_BUFFER_SIZE=1000 # Recommended range: 500-5000 (0 disables replay)
REPLAY_BUFFER_STRATEGY="fifo" # fifo or random
ALPHA=0.1 # Temperature for importance weighting
EXPLORATION=1.0 # UCB exploration constant

# TD3B hyperparameters (aligned with v1 defaults)
CONTRASTIVE_WEIGHT=0.1 # v1 default: 0.1
CONTRASTIVE_MARGIN=1.0
KL_BETA=0.1 # v1 default: 0.1
MIN_AFFINITY_THRESHOLD=0.0 # CRITICAL: minimum affinity for allosteric control
SIGMOID_TEMPERATURE=0.1

# Validation
VAL_SAMPLES_PER_TARGET=20 # Number of sequences per target during validation

# Directional oracle (GPCR classifier)
ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
ORACLE_TR2D2_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
ORACLE_TOKENIZER_VOCAB="${BASE_PATH}/tokenizer/new_vocab.txt"
ORACLE_TOKENIZER_SPLITS="${BASE_PATH}/tokenizer/new_splits.txt"
ORACLE_ESM_NAME="facebook/esm2_t33_650M_UR50D"
ORACLE_ESM_CACHE_DIR="" # Optional: set to a cache dir path
ORACLE_ESM_LOCAL_FILES_ONLY=0 # Set to 1 to avoid network access
ORACLE_MAX_LIGAND_LENGTH=768
ORACLE_MAX_PROTEIN_LENGTH=1024
ORACLE_D_MODEL=256
ORACLE_N_HEADS=4
ORACLE_N_SELF_ATTN_LAYERS=1
ORACLE_N_BMCA_LAYERS=2
ORACLE_DROPOUT=0.3

# Collect optional oracle flags only when they are configured above.
EXTRA_ORACLE_ARGS=""
if [ -n "$ORACLE_ESM_CACHE_DIR" ]; then
    EXTRA_ORACLE_ARGS="$EXTRA_ORACLE_ARGS --direction_oracle_esm_cache_dir $ORACLE_ESM_CACHE_DIR"
fi
if [ "$ORACLE_ESM_LOCAL_FILES_ONLY" -eq 1 ]; then
    EXTRA_ORACLE_ARGS="$EXTRA_ORACLE_ARGS --direction_oracle_esm_local_files_only"
fi

# W&B (optional)
WANDB_PROJECT="tr2d2-multi-target"
WANDB_ENTITY="phos_zj"

# ============================================================================
# Launch Training
# ============================================================================

cd ${BASE_PATH}

echo "============================================================================"
echo "Multi-Target TD3B Training"
echo "============================================================================"
echo "Configuration:"
echo "  - Targets per MCTS: ${TARGETS_PER_MCTS}"
echo "  - Training batch size: ${TRAIN_BATCH_SIZE}"
echo "  - Gradient accumulation: ${GRADIENT_ACCUMULATION_STEPS}"
echo "  - Effective batch size: $((TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS))"
echo "  - Epochs: ${NUM_EPOCHS}"
echo "  - MCTS iterations: ${NUM_ITER}"
echo "  - MCTS children: ${NUM_CHILDREN}"
echo "  - Buffer size: ${BUFFER_SIZE}"
echo "  - Replay buffer size: ${REPLAY_BUFFER_SIZE} (${REPLAY_BUFFER_STRATEGY})"
echo "============================================================================"
echo ""

# Build command (blank "\" lines group related flags for readability)
CMD="python finetune_multi_target.py \
    --base_path ${BASE_PATH} \
    --train_csv ${TRAIN_CSV} \
    --pretrained_checkpoint ${PRETRAINED_CHECKPOINT} \
    --run_name ${RUN_NAME} \
    --device ${DEVICE} \
    \
    --targets_per_mcts ${TARGETS_PER_MCTS} \
    --resample_targets_every ${RESAMPLE_TARGETS_EVERY} \
    \
    --num_epochs ${NUM_EPOCHS} \
    --learning_rate ${LEARNING_RATE} \
    --train_batch_size ${TRAIN_BATCH_SIZE} \
    --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \
    --resample_every_n_step ${RESAMPLE_EVERY} \
    --save_every_n_epochs ${SAVE_EVERY} \
    --validate_every_n_epochs ${VALIDATE_EVERY} \
    --reset_every_n_step ${RESET_TREE_EVERY} \
    \
    --num_iter ${NUM_ITER} \
    --num_children ${NUM_CHILDREN} \
    --buffer_size ${BUFFER_SIZE} \
    --replay_buffer_size ${REPLAY_BUFFER_SIZE} \
    --replay_buffer_strategy ${REPLAY_BUFFER_STRATEGY} \
    --alpha ${ALPHA} \
    --exploration ${EXPLORATION} \
    \
    --contrastive_weight ${CONTRASTIVE_WEIGHT} \
    --contrastive_margin ${CONTRASTIVE_MARGIN} \
    --kl_beta ${KL_BETA} \
    --min_affinity_threshold ${MIN_AFFINITY_THRESHOLD} \
    --sigmoid_temperature ${SIGMOID_TEMPERATURE} \
    \
    --direction_oracle_ckpt ${ORACLE_CKPT} \
    --direction_oracle_tr2d2_checkpoint ${ORACLE_TR2D2_CHECKPOINT} \
    --direction_oracle_tokenizer_vocab ${ORACLE_TOKENIZER_VOCAB} \
    --direction_oracle_tokenizer_splits ${ORACLE_TOKENIZER_SPLITS} \
    --direction_oracle_esm_name ${ORACLE_ESM_NAME} \
    --direction_oracle_max_ligand_length ${ORACLE_MAX_LIGAND_LENGTH} \
    --direction_oracle_max_protein_length ${ORACLE_MAX_PROTEIN_LENGTH} \
    --direction_oracle_d_model ${ORACLE_D_MODEL} \
    --direction_oracle_n_heads ${ORACLE_N_HEADS} \
    --direction_oracle_n_self_attn_layers ${ORACLE_N_SELF_ATTN_LAYERS} \
    --direction_oracle_n_bmca_layers ${ORACLE_N_BMCA_LAYERS} \
    --direction_oracle_dropout ${ORACLE_DROPOUT} \
    ${EXTRA_ORACLE_ARGS} \
    \
    --val_samples_per_target ${VAL_SAMPLES_PER_TARGET} \
    \
    --grad_clip \
    --gradnorm_clip 1.0 \
    --wandb_project ${WANDB_PROJECT}"

# Add validation CSV if it exists
if [ -f "${VAL_CSV}" ]; then
    CMD="${CMD} --val_csv ${VAL_CSV}"
    echo "Validation CSV: ${VAL_CSV}"
else
    echo "No validation CSV found (${VAL_CSV})"
    echo "Skipping validation during training"
fi

# Add W&B entity if specified
if [ -n "${WANDB_ENTITY}" ]; then
    CMD="${CMD} --wandb_entity ${WANDB_ENTITY}"
fi

echo ""
echo "Launching training..."
echo ""

# Execute (eval so EXTRA_ORACLE_ARGS expands into separate words)
eval $CMD
|
noise_schedule.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
torch._C._jit_set_profiling_mode(False)
|
| 7 |
+
torch._C._jit_set_profiling_executor(False)
|
| 8 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
| 9 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 10 |
+
|
| 11 |
+
def get_noise(config, dtype=torch.float32):
    """Instantiate the noise schedule named by ``config.noise.type``.

    Supported names: geometric, loglinear, cosine, cosinesqr, linear.
    Raises ValueError for any other name.
    """
    noise_type = config.noise.type
    if noise_type == 'geometric':
        return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
    if noise_type == 'loglinear':
        return LogLinearNoise()
    if noise_type == 'cosine':
        return CosineNoise()
    if noise_type == 'cosinesqr':
        return CosineSqrNoise()
    if noise_type == 'linear':
        return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
    raise ValueError(f'{noise_type} is not a valid noise')
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def binary_discretization(z):
    """Straight-through sign discretization.

    Forward value equals ``sign(z)``; gradients flow through the
    L2-normalized ``z`` because the hard/soft difference is detached.
    """
    soft = z / torch.norm(z, dim=-1, keepdim=True)
    hard = torch.sign(z)
    return soft + (hard - soft).detach()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Noise(abc.ABC, nn.Module):
    """
    Base class for noise schedules.

    Subclasses provide ``total_noise(t)`` (integrated noise up to time t)
    and ``rate_noise(t)`` (its instantaneous rate); ``forward`` returns both.
    """
    def forward(self, t):
        # Time is assumed to run from 0 to 1.
        total = self.total_noise(t)
        rate = self.rate_noise(t)
        return total, rate
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CosineNoise(Noise):
    """Cosine schedule: total noise is -log(eps + (1 - eps) * cos(pi*t/2))."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # Derivative of total_noise, expressed via sin/cos of the half-pi-scaled time.
        half_pi_t = t * torch.pi / 2
        numerator = (1 - self.eps) * torch.sin(half_pi_t)
        denominator = (1 - self.eps) * torch.cos(half_pi_t) + self.eps
        return (torch.pi / 2) * numerator / denominator

    def total_noise(self, t):
        return -torch.log(self.eps + (1 - self.eps) * torch.cos(t * torch.pi / 2))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CosineSqrNoise(Noise):
    """Cosine-squared schedule: total noise is -log(eps + (1 - eps) * cos(pi*t/2)^2)."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        # sin(pi*t) arises from differentiating cos(pi*t/2)^2 (up to the pi/2 factor).
        sin_term = (1 - self.eps) * torch.sin(t * torch.pi)
        cos_sq_term = (1 - self.eps) * (torch.cos(t * torch.pi / 2) ** 2)
        return (torch.pi / 2) * sin_term / (cos_sq_term + self.eps)

    def total_noise(self, t):
        cos_sq = torch.cos(t * torch.pi / 2) ** 2
        return -torch.log(self.eps + (1 - self.eps) * cos_sq)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Linear(Noise):
    """Linear schedule: total noise grows linearly from sigma_min to sigma_max."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t=None):
        """Constant noise rate (derivative of total_noise w.r.t. t).

        BUGFIX: previously took no ``t`` argument, so the base-class
        ``Noise.forward(t)`` call ``self.rate_noise(t)`` raised TypeError.
        ``t`` is accepted (and ignored, since the rate is constant) with a
        default so existing zero-argument callers keep working.
        """
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        # Same log1p/exp warp as LogLinearNoise.importance_sampling_transformation,
        # mapped back to the [sigma_min, sigma_max] range.
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class GeometricNoise(Noise):
    """Geometric interpolation between sigma_min and sigma_max."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        # Stored as a length-2 tensor: [sigma_min, sigma_max].
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        # d/dt of sigma_min^(1-t) * sigma_max^t.
        interpolated = self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
        return interpolated * (self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and ~1 as t
    varies from 0 to 1. Total noise is -log(1 - (1 - eps) * t), so the
    sigma is (1 - eps) * t.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Derivative of total_noise with respect to t.
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        # Warp a uniform t through the schedule's log1p/exp transform.
        f_T = torch.log1p(-torch.exp(-self.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * f_T + (1 - t) * f_0))
        return -torch.expm1(-sigma_t) / (1 - self.eps)
|
| 133 |
+
|
| 134 |
+
class LogPolyNoise(Noise):
    """
    Log Polynomial (cubic) noise schedule for slower masking of peptide
    bond tokens.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        # Eps-regularized derivative of -log(1 - t^3) (w = 3); note the
        # numerator subtracts eps rather than scaling by (1 - eps), so this
        # is an approximation of the exact derivative of total_noise.
        numerator = 3 * (t ** 2) - self.eps
        denominator = 1 - (1 - self.eps) * (t ** 3)
        return numerator / denominator

    def total_noise(self, t):
        # -log(1 - (1 - eps) * t^3)
        return -torch.log1p(-(1 - self.eps) * (t ** 3))
|
peptide_mcts.py
ADDED
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random as rd
|
| 6 |
+
from utils.app import PeptideAnalyzer
|
| 7 |
+
from utils.timer import StepTimer
|
| 8 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 9 |
+
|
| 10 |
+
import noise_schedule
|
| 11 |
+
|
| 12 |
+
### for peptide multi-objective ###
|
| 13 |
+
def dominates(a, b):
    """Return True iff ``a`` Pareto-dominates ``b`` (>= everywhere, > somewhere)."""
    a_vec = np.asarray(a)
    b_vec = np.asarray(b)
    no_worse = np.all(a_vec >= b_vec)
    strictly_better = np.any(a_vec > b_vec)
    return no_worse and strictly_better
|
| 16 |
+
|
| 17 |
+
def dominated_by(a, b):
    """Return True iff ``a`` is Pareto-dominated by ``b``."""
    return dominates(b, a)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def updateParetoFront(paretoFront, node, scoreVector, totalSize=None, eps=1e-12):
    """
    Maintain a non-dominated set (Pareto front) of (node -> scoreVector).

    - Accept 'node' iff it is NOT dominated by any node in the set.
    - Remove any incumbent nodes that ARE dominated by 'node'.
    - If totalSize is given and the archive exceeds it, drop the item
      with the smallest sum(scoreVector) as a simple tie-breaker.

    Note: a candidate whose scores equal an incumbent's (within eps)
    neither dominates nor is dominated, so it is inserted under its own
    key. (The previous docstring claimed equal points were skipped, but
    that logic was commented out; the dead code has been removed and the
    docstring now matches actual behavior.)

    Args:
        paretoFront (dict): {node: scoreVector}
        node: candidate node (used as dict key)
        scoreVector (array-like): candidate scores (to be maximized)
        totalSize (int|None): optional max size for the archive
        eps (float): tolerance for inequality checks

    Returns:
        dict: the input dict unchanged when the candidate is rejected,
        otherwise a new dict with float arrays as values.
    """
    candidate = np.asarray(scoreVector, dtype=float)

    def _dominates(a, b):
        # a >= b in all coords and > in at least one (with tolerance).
        # Deliberately distinct from the module-level dominates(): this
        # variant is tolerance-aware.
        return np.all(a >= b - eps) and np.any(a > b + eps)

    # Reject if the candidate is dominated by any incumbent.
    for incumbent in paretoFront.values():
        if _dominates(np.asarray(incumbent, dtype=float), candidate):
            return paretoFront  # no change

    # Keep only incumbents that the candidate does not dominate.
    survivors = {}
    for key, value in paretoFront.items():
        value_arr = np.asarray(value, dtype=float)
        if _dominates(candidate, value_arr):
            continue  # drop dominated incumbent
        survivors[key] = value_arr

    # Insert the candidate.
    survivors[node] = candidate

    # Enforce the optional size cap by evicting the smallest-sum entry
    # (first such entry in iteration order on ties).
    if totalSize is not None and totalSize > 0 and len(survivors) > totalSize:
        worst_key = min(survivors, key=lambda k: float(np.sum(survivors[k])))
        del survivors[worst_key]

    return survivors
|
| 83 |
+
|
| 84 |
+
### BEGINNING OF NODE CLASS ###
|
| 85 |
+
|
| 86 |
+
class Node:
    """A single MCTS tree node holding a partially unmasked token sequence.

    Attributes:
        parentNode: Node one unmasking step earlier (None for the root).
        childNodes: Nodes produced by sampling distinct unmasking schemes.
        totalReward: cumulative K-objective reward vector backed up from rollouts.
        visits: number of MCTS iterations that passed through this node.
        timestep: diffusion step index in [0, args.total_num_steps].
        tokens: dict with 'seqs' (token tensor) and 'attention_mask'.
        log_rnd / log_policy_step / log_pretrained_step: log-RND of the path to
            this node and the log-probs of the last step under policy/pretrained.
    """

    def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
        self.args = args
        self.parentNode = parentNode
        # avoid a shared mutable default: each node gets its own child list
        self.childNodes = [] if childNodes is None else childNodes

        # log-RND of the path up to (and including) this node
        self.log_rnd = log_rnd
        # log-prob of the last unmasking step under the current policy
        self.log_policy_step = log_policy_step
        # log-prob of the last unmasking step under the pretrained model
        self.log_pretrained_step = log_pretrained_step

        # cumulative reward vector; defaults to zeros over the K objectives
        self.totalReward = totalReward if totalReward is not None else np.zeros(self.args.num_obj)

        # every node starts with one visit
        self.visits = 1
        self.timestep = timestep
        self.tokens = tokens

    def selectNode(self):
        """Pick a child on the Pareto front of select scores.

        Returns (node, status). For a legal non-leaf node (status 3) a child is
        drawn uniformly from the Pareto-optimal children; otherwise the node
        itself is returned unchanged.
        """
        status = self.getExpandStatus()
        if status != 3:
            # terminal or expandable leaf: nothing to descend into
            return self, status

        # build the Pareto front over expandable / non-leaf children
        front = {}
        for child in self.childNodes:
            childStatus = child.getExpandStatus()
            if childStatus in (2, 3):
                front = updateParetoFront(front, child, child.calcSelectScore())

        chosen = rd.choice(list(front.keys()))
        return chosen, chosen.getExpandStatus()

    def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
        """Create a child one timestep later, attach it, and return it.

        log_rnd is the path log-RND up to the child; the two *_step arguments
        are the scalar log-probs of the sampled step under policy/pretrained.
        """
        child = Node(args=self.args,
                     tokens=tokens,
                     log_rnd=log_rnd,
                     log_policy_step=log_policy_step,
                     log_pretrained_step=log_pretrained_step,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep + 1)
        self.childNodes.append(child)
        return child

    def update_logrnd(self, log_policy_step, log_rnd):
        """Refresh the cached step log-prob and path log-RND under the current policy."""
        self.log_policy_step = log_policy_step
        self.log_rnd = log_rnd

    def updateNode(self, rewards):
        """Backpropagation update: add a descendant's reward vector and count the visit."""
        self.visits += 1
        self.totalReward += rewards

    def calcSelectScore(self):
        """Return the K-dimensional select score for this node.

        Mean per-visit reward plus an exploration bonus scaled by the step
        log-prob and the parent's visit count (UCT-like term).
        """
        # weight of the exploration term
        scaling = 0.1
        mean_rewards = self.totalReward / self.visits
        exploration = scaling * self.log_policy_step.detach().cpu().item() * np.sqrt(self.parentNode.visits) / self.visits
        return mean_rewards + exploration

    def getExpandStatus(self):
        """Classify the node: 1 = terminal, 2 = expandable leaf, 3 = expanded non-leaf."""
        if self.timestep == self.args.total_num_steps:
            return 1
        if self.timestep < self.args.total_num_steps and not self.childNodes:
            return 2
        return 3
|
| 211 |
+
|
| 212 |
+
### END OF NODE CLASS ###
|
| 213 |
+
|
| 214 |
+
### BEGINNING OF MCTS CLASS ###
|
| 215 |
+
|
| 216 |
+
class MCTS:
    """Monte-Carlo tree search over unmasking trajectories of a masked diffusion
    language model.

    The root is the fully masked sequence; each tree edge is one reverse-diffusion
    (unmasking) step sampled from the policy model. Expanded children are rolled
    out to completion, scored by the multi-objective reward functions, and
    Pareto-filtered into a replay buffer for fine-tuning.
    """

    def __init__(
        self,
        args,
        config,
        policy_model,
        pretrained,
        score_func_names=None,
        prot_seqs=None,
        rootNode=None,
        reward_func=None,
        num_obj=None,
    ):
        self.timer = StepTimer(policy_model.device)

        self.device = policy_model.device

        self.args = args
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = args.time_conditioning

        # resolve the number of objectives: explicit arg > reward_func.num_obj >
        # number of named scoring functions
        if score_func_names is None:
            score_func_names = []
        if num_obj is None:
            num_obj = getattr(reward_func, "num_obj", None)
        self.num_obj = num_obj if num_obj is not None else len(score_func_names)

        # the root holds a fully masked sequence of length args.seq_length
        self.mask_index = policy_model.mask_index
        masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
        masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
        if rootNode is None:
            self.rootNode = Node(self.args, tokens = masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=np.zeros(self.num_obj), timestep=0)
        else:
            self.rootNode = rootNode  # stores the root node of the tree

        # replay buffer: list of dicts with keys
        #   "x_final": final unmasked token tensor
        #   "log_rnd": trajectory log-RND (reward-tilted)
        #   "final_reward": scalarized reward
        #   "score_vector": per-objective scores
        #   "seq": decoded string sequence
        self.buffer = []  # List[Dict[str, Any]]

        self.buffer_size = args.buffer_size

        self.num_steps = args.total_num_steps
        #self.num_sequences = args.num_sequences

        # frozen pretrained model used as the reference for the RND weights
        self.pretrained = pretrained

        # the policy model that we want to finetune
        self.policy_model = policy_model
        #self.tokenizer = policy_model.tokenizer
        # NOTE(review): self.device was already set above; this reassignment is redundant
        self.device = policy_model.device

        self.sequence_length = args.seq_length

        self.num_iter = args.num_iter

        self.num_children = args.num_children

        # multi-objective scoring: build from names unless an explicit callable is given
        if reward_func is None:
            self.rewardFunc = ScoringFunctions(score_func_names, prot_seqs, device=args.device)
        else:
            self.rewardFunc = reward_func

        self.iter_num = 0

        self.reward_log = []  # stores scalarized total rewards per expansion
        self.logrnd_log = []
        # per-objective score logs, appended once per expansion
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        # both models are used for sampling only; keep them in eval mode
        self.policy_model.eval()
        self.pretrained.eval()

        # peptide-specific validity checker and the shared tokenizer
        self.analyzer = PeptideAnalyzer()
        self.tokenizer = policy_model.tokenizer

    def reset(self, resetTree):
        """Clear the iteration counter, buffer, and logs; optionally rebuild the tree."""
        self.iter_num = 0
        self.buffer = []
        self.reward_log = []
        self.logrnd_log = []

        # reset logs for each objective
        self.valid_fraction_log = []
        self.affinity1_log = []
        self.affinity2_log = []
        self.permeability_log = []
        self.sol_log = []
        self.hemo_log = []
        self.nf_log = []

        # option to continue with the same tree: only rebuild the root on request
        if resetTree:
            masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
            masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
            self.rootNode = Node(self.args, tokens = masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=np.zeros(self.num_obj), timestep=0)

    def forward(self, resetTree=False):
        """Run num_iter select/expand iterations and return the buffer contents.

        Returns:
            final_x (B, L) token tensor, log_rnd (B,), final_rewards (B,),
            score_vectors (B, K), and the decoded sequences.
        """
        self.reset(resetTree)

        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            # traverse the tree from the root node until a leaf node
            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards, score_vectors, sequences = self.consolidateBuffer()
        # return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        return final_x, log_rnd, final_rewards, score_vectors, sequences

    def _debug_buffer_decision(self, sv, reason, extra=None):
        """Debug helper: print why a candidate was inserted into / rejected from the buffer."""
        if extra is None: extra = {}
        print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
              f"buf_len={len(self.buffer)} extra={extra}")

    def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
        """Pareto-insert finished trajectories into the replay buffer.

        A candidate is rejected if dominated by any buffered item; buffered items
        it dominates are evicted. When full, the item with the smallest score sum
        is replaced. Returns the reward-tilted log-RNDs and scalar rewards for
        ALL candidates (including rejected ones).

        NOTE(review): if args.scalarization is "normalized" or "weighted" the
        branches are `pass`, so `scalar_reward` is unbound on the first child
        (NameError) or stale from the previous loop iteration — confirm these
        modes are unused or implement them.
        NOTE(review): `childSequences[i]` is indexed with the *valid-only* index
        used for x_final/score_vectors; when some children were filtered out as
        invalid peptides the stored "seq" may not match the stored tokens/scores.
        """
        B = x_final.shape[0]
        traj_log_rnds, scalar_rewards = [], []

        for i in range(B):
            sv = np.asarray(score_vectors[i], dtype=float)

            # determine how to scalarize the multi-objective rewards
            if self.args.scalarization == "normalized":
                pass
            elif self.args.scalarization == "weighted":
                pass
            else:
                scalar_reward = float(np.sum(sv))

            # reward-tilt the trajectory weight (temperature alpha)
            traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha)  # scale down by alpha

            item = {
                "x_final": x_final[i].clone(),  # clone so buffer items outlive the batch tensor
                "log_rnd": traj_log_rnd.clone(),
                "final_reward": scalar_reward,
                "score_vector": sv.copy(),
                "seq": childSequences[i],
            }

            # Drop if dominated by any existing
            if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
                # for debugging
                self._debug_buffer_decision(sv, "rejected_dominated")
                continue

            # Remove any existing that this candidate dominates
            keep = []
            for bi in self.buffer:
                if not dominates(sv, bi["score_vector"]):
                    keep.append(bi)
            self.buffer = keep

            # Insert with capacity rule
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
            else:
                # tie-breaker: replace the worst by a simple heuristic (min sum)
                worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
                self.buffer[worst_i] = item

            # for debugging
            self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

            traj_log_rnds.append(traj_log_rnd)
            scalar_rewards.append(scalar_reward)

        traj_log_rnds = torch.stack(traj_log_rnds, dim=0) if traj_log_rnds else torch.empty(0)
        scalar_rewards = np.asarray(scalar_rewards, dtype=float)
        return traj_log_rnds, scalar_rewards

    def consolidateBuffer(self):
        """Stack the buffer into tensors/arrays.

        Returns x_final (B, L), log_rnd (B,), final_rewards (B,), score_vectors
        (B, K), and the list of decoded sequences.
        NOTE(review): torch.stack / np.stack raise on an empty buffer — confirm
        the buffer is always non-empty when this is called.
        """
        x_final = []
        log_rnd = []
        final_rewards = []
        score_vectors = []
        sequences = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])
            score_vectors.append(item["score_vector"])
            sequences.append(item["seq"])

        x_final = torch.stack(x_final, dim=0)  # (B, L)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32)  # (B)
        final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
        score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)

        return x_final, log_rnd, final_rewards, score_vectors, sequences

    def isPathEnd(self, path, maxDepth):
        """Return True if the last state in `path` is fully unmasked or the path
        has reached `maxDepth`."""
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """Traverse the tree from `currNode` until reaching a non-expanded node.

        Along the way, re-evaluate each step's log-prob under the *current*
        policy (it may have changed since the node was created) and refresh the
        cached path log-RND on every visited node.
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            currNode, nodeStatus = currNode.selectNode()

            if currNode.parentNode is not None:
                # compute new log_policy for the edge parent -> currNode
                child_tokens = currNode.tokens['seqs'].to(self.device)
                attn_mask = currNode.tokens['attention_mask'].to(self.device)
                parent = currNode.parentNode
                parent_tokens = parent.tokens['seqs'].to(self.device)
                # NOTE(review): t is fixed at 1 here rather than the node's own
                # timestep on the linspace schedule — confirm this is intended
                t = torch.ones(1, device = self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
                                                                                      child_tokens,
                                                                                      t=t, dt=dt)
                updated_log_rnd += updated_log_policy_step

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd)  # update log_rnd

            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5):
        """Expand `parentNode` with num_children sampled unmasking steps, roll each
        child out to a full sequence under the policy, score valid peptides, add
        them to the buffer/tree, and backpropagate the mean child reward.
        """

        num_children = self.num_children

        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # time grid from t=1 down to t=eps over num_steps+1 points
        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        # initialize x and attn_mask from the parent
        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd  # stores the log_rnd up to parent node

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device = self.device)

        # sample M child sequences and compute their log probabilities
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                _, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        # per-child path log-RND after the first expansion step (num_children, 1)
        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children

        traj_log_rnd = child_log_rnd  # initialize log_rnd for entire rolled out trajectory

        # rollout under the policy and accumulate the log ratio at each step
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device = self.device)

                with torch.no_grad():
                    _, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                # add the rollout step
                traj_log_rnd += log_pretrained_step - log_policy_step

                x_rollout = x_next

        # if mask tokens remain after the rollout, force a final unmasking pass
        mask_positions = (x_rollout == self.mask_index)  # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item()  # true if mask remains
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step

            x_rollout = x_next

        # decode to string sequences for reward evaluation
        with self.timer.section("expand.decode"):
            childSequences = self.tokenizer.batch_decode(x_rollout)

        ## FOR PEPTIDES ONLY ##
        # split children into valid peptides (scored, buffered) and invalid ones
        # (still added to the tree with zero reward so they can be revisited)
        valid_x_children = []
        valid_x_final = []
        validSequences = []
        valid_traj_log_rnd = []

        with self.timer.section("expand.filter_is_peptide"):
            for i in range(num_children):
                # string sequence
                childSeq = childSequences[i]

                # check if the peptide is valid
                if self.analyzer.is_peptide(childSeq):
                    valid_x_children.append(x_children[i])
                    valid_x_final.append(x_rollout[i])
                    validSequences.append(childSeq)
                    valid_traj_log_rnd.append(traj_log_rnd[i])
                else:
                    childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
                    parentNode.addChildNode(tokens=childTokens,
                                            log_rnd=child_log_rnd[i],
                                            log_policy_step=child_log_policy_step[i],
                                            log_pretrained_step=child_log_pretrained_step[i],
                                            totalReward=np.zeros(self.num_obj))

        del traj_log_rnd

        # per-objective log lists, in the order scores are emitted by rewardFunc
        log_targets = [
            self.affinity1_log,
            self.sol_log,
            self.hemo_log,
            self.nf_log,
            self.permeability_log,
        ]

        if len(validSequences) != 0:
            # score the valid sequences and append per-objective scores to the logs
            with self.timer.section("expand.scoring_functions"):
                score_vectors = np.asarray(self.rewardFunc(input_seqs=validSequences))

            if score_vectors.ndim == 1:
                score_vectors = score_vectors[:, None]

            average_scores = score_vectors.T
            num_scores = average_scores.shape[0]
            score_len = average_scores.shape[1]

            for idx, log_list in enumerate(log_targets):
                if idx < num_scores:
                    log_list.append(average_scores[idx])
                else:
                    # objective not produced by rewardFunc: log zeros to keep shapes aligned
                    log_list.append(np.zeros(score_len, dtype=np.float32))
        else:
            # set the values added to log as 0s if there are no valid sequences
            empty = np.zeros(self.num_children, dtype=np.float32)
            for log_list in log_targets:
                log_list.append(empty)

        # bail out gracefully if no child survived the validity filter
        if len(valid_x_final) == 0:
            # log and bail out gracefully for this expansion
            self.valid_fraction_log.append(0.0)
            return

        valid_x_final = torch.stack(valid_x_final, dim=0)
        valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
        # update buffer and get rewards
        with self.timer.section("expand.update_buffer"):
            traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, childSequences)

        allChildReward = np.zeros_like(score_vectors[0])

        for i in range(len(score_vectors)):
            reward = score_vectors[i]

            # add to all child reward vector for backprop
            allChildReward += reward  # (num_objectives,)

            # create node for sequence and add to the children node of parent
            childTokens = {'seqs': valid_x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[i],
                                    log_policy_step=child_log_policy_step[i],
                                    log_pretrained_step=child_log_pretrained_step[i],
                                    totalReward=reward)

        ### END OF FOR PEPTIDES ONLY ###

        valid_fraction = len(validSequences) / num_children
        self.valid_fraction_log.append(valid_fraction)

        # debugging
        print(f"[EXPAND] iter={self.iter_num} parent_t={parentNode.timestep} "
              f"num_children={num_children} valid={len(validSequences)} any_mask={any_mask_global}")
        if score_vectors is not None:
            print(f"[SCORES] min={np.min(score_vectors,0)} max={np.max(score_vectors,0)} "
                  f"nan_any={np.isnan(score_vectors).any()}")
        # end debugging

        self.reward_log.append(scalar_rewards)
        self.logrnd_log.append(traj_log_rnds.detach().cpu().numpy())

        allChildReward = allChildReward / len(validSequences)  # normalize by number of valid children
        # backpropagate the mean child reward up the tree
        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)

    def backprop(self, node, allChildReward):
        """Propagate a reward vector from `node` up to the root, updating visits."""
        # backpropagate rewards through the path leading to the leaf node from the root
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
|
roformer.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import RoFormerConfig, RoFormerForMaskedLM
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
class Roformer(nn.Module):
    """Thin wrapper around HuggingFace's RoFormerForMaskedLM.

    Builds the RoFormer config from `config.roformer.*`, moves the model to the
    chosen device, and exposes freeze/unfreeze helpers plus save/load.
    """

    def __init__(self, config, tokenizer, device=None):
        super(Roformer, self).__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        # default to CUDA when available, otherwise CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        """Disable gradients for every model parameter."""
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        """Enable gradients for every model parameter."""
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        """Unfreeze the query/key projection weights of the last `n` encoder layers.

        Fix: the layer count was previously hard-coded to 8, which silently
        unfroze the wrong layers (or none) whenever config.roformer.n_layers != 8;
        use the actual number of encoder layers instead.
        """
        layers = self.model.roformer.encoder.layer
        num_layers = len(layers)

        for i, layer in enumerate(layers):
            # finetune final n layers only
            if i >= num_layers - n:
                # unfreeze query weights
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                # unfreeze key weights
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, input_ids, attn_mask):
        """Run the masked-LM and return raw logits of shape (B, L, vocab_size)."""
        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # RoFormerForMaskedLM returns a MaskedLMOutput; expose only the logits
        logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
        return logits.logits

    def save_model(self, save_dir):
        """Persist both the model weights and the tokenizer to `save_dir`."""
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        """Alternate constructor: build the wrapper, then swap in saved weights."""
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
|
scoring/functions/binding.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import esm
|
| 8 |
+
from transformers import AutoModelForMaskedLM
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _sanitize_token_ids(input_ids: torch.Tensor, vocab_size: int, unk_id: int) -> torch.Tensor:
|
| 12 |
+
if vocab_size <= 0 or input_ids.numel() == 0:
|
| 13 |
+
return input_ids
|
| 14 |
+
if torch.any(input_ids >= vocab_size) or torch.any(input_ids < 0):
|
| 15 |
+
# Replace out-of-range IDs with UNK to avoid embedding OOB.
|
| 16 |
+
unk = torch.tensor(unk_id, device=input_ids.device, dtype=input_ids.dtype)
|
| 17 |
+
input_ids = torch.where((input_ids >= vocab_size) | (input_ids < 0), unk, input_ids)
|
| 18 |
+
return input_ids
|
| 19 |
+
|
| 20 |
+
class ImprovedBindingPredictor(nn.Module):
    """Cross-attention model predicting protein-peptide binding affinity.

    Produces a scalar regression score plus 3-way classification logits
    (tight / medium / weak binding). All registered submodule attribute names
    are kept stable so the published checkpoint's state_dict still loads.
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Affinity thresholds separating the three binding classes.
        self.tight_threshold = 7.5  # Kd/Ki/IC50 <= ~30nM
        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1uM

        # Project both modalities into a shared hidden space.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        def make_block():
            # One cross-attention block: attention + feed-forward network,
            # each followed by a residual LayerNorm.
            return nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim),
                ),
                'norm2': nn.LayerNorm(hidden_dim),
            })

        self.cross_attention_layers = nn.ModuleList(
            make_block() for _ in range(n_layers)
        )

        # Shared trunk over the concatenated pooled representations.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Scalar affinity regression output.
        self.regression_head = nn.Linear(hidden_dim, 1)

        # 3-way class logits: 0 = tight, 1 = medium, 2 = weak.
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Map affinity value(s) to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
        """
        if isinstance(affinity, torch.Tensor):
            # Start everything at "medium", then overwrite the two extremes;
            # the tight/weak masks are disjoint since tight_threshold > weak_threshold.
            labels = torch.ones_like(affinity, dtype=torch.long)
            labels[affinity >= self.tight_threshold] = 0
            labels[affinity < self.weak_threshold] = 2
            return labels
        if affinity >= self.tight_threshold:
            return 0  # tight binding
        if affinity < self.weak_threshold:
            return 2  # weak binding
        return 1  # medium binding

    def forward(self, protein_emb, smiles_emb):
        """Return (regression_score, classification_logits) for one pair.

        Both inputs are sequence-first embeddings; each stream is pooled over
        dim 0 after the cross-attention stack, then the concatenation feeds
        the shared trunk and the two heads.
        """
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        for block in self.cross_attention_layers:
            attn = block['attention']
            # Protein queries attend over the SMILES tokens.
            protein = block['norm1'](protein + attn(protein, smiles, smiles)[0])
            protein = block['norm2'](protein + block['ffn'](protein))
            # SMILES queries attend over the (already updated) protein tokens.
            smiles = block['norm1'](smiles + attn(smiles, protein, protein)[0])
            smiles = block['norm2'](smiles + block['ffn'](smiles))

        # Mean-pool each stream over the sequence axis and fuse.
        fused = torch.cat([protein.mean(dim=0), smiles.mean(dim=0)], dim=-1)
        shared = self.shared_head(fused)
        return self.regression_head(shared), self.classification_head(shared)
|
| 128 |
+
|
| 129 |
+
class BindingAffinity:
    """Score peptide sequences for binding affinity against one fixed protein target.

    The target protein is embedded once with ESM-2 (layer 33, mean-pooled over
    residues) at construction time; ``forward`` then embeds each peptide with
    the PeptideCLM RoFormer and feeds both embeddings through a pretrained
    ``ImprovedBindingPredictor`` head.
    """

    def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
        """Load all models and pre-compute the target protein embedding.

        Args:
            prot_seq: Amino-acid sequence of the fixed protein target.
            tokenizer: Peptide (SMILES) tokenizer.
            base_path: Root directory containing the tr2d2-pep checkpoints.
            device: Torch device; auto-detects CUDA when None.
            emb_model: Optional pre-loaded peptide embedding model.
        """
        super().__init__()  # no-op for a plain object base; kept as-is
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # peptide embeddings
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer
        # Resolve the UNK token id; fall back to the vocab mapping (or 0) when
        # the tokenizer does not expose unk_token_id directly.
        self.unk_id = getattr(self.pep_tokenizer, "unk_token_id", None)
        if self.unk_id is None:
            self.unk_id = self.pep_tokenizer.vocab.get(self.pep_tokenizer.unk_token, 0)
        self.pep_vocab_size = None
        self.max_pep_len = None
        # The embedding model may be a custom wrapper (with .model.roformer),
        # a bare roformer, or any HF model exposing get_input_embeddings().
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            self.pep_vocab_size = self.pep_model.model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "roformer"):
            self.pep_vocab_size = self.pep_model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "get_input_embeddings"):
            self.pep_vocab_size = self.pep_model.get_input_embeddings().num_embeddings
            self.max_pep_len = getattr(self.pep_model.config, "max_position_embeddings", None)

        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        self.model.eval()

        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer

        data = [("target", prot_seq)]
        # get tokenized protein
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)
        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2
            prot_emb = results["representations"][33]

        # Mean-pool residue embeddings into a single (1, esm_dim) target vector.
        self.prot_emb = prot_emb[0].to(self.device)
        self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)


    def forward(self, input_seqs):
        """Return one predicted affinity score per peptide sequence.

        Args:
            input_seqs: Iterable of peptide sequences (SMILES strings).

        Returns:
            List of float regression scores from ImprovedBindingPredictor.
        """
        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(
                    seq,
                    return_tensors='pt',
                    padding=True,
                    truncation=self.max_pep_len is not None,
                    max_length=self.max_pep_len,
                )

                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
                # Clamp out-of-vocabulary token ids to UNK before embedding.
                pep_tokens["input_ids"] = _sanitize_token_ids(
                    pep_tokens["input_ids"], int(self.pep_vocab_size or 0), int(self.unk_id)
                )

                with torch.no_grad():  # redundant inside the outer no_grad, but harmless
                    # Check if using custom Roformer wrapper or standard model
                    if hasattr(self.pep_model, 'model'):
                        # Custom roformer.Roformer wrapper - get hidden states from inner model
                        emb = self.pep_model.model.roformer(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
                    else:
                        # Standard AutoModelForMaskedLM
                        emb = self.pep_model(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(self.prot_emb, pep_emb)
                scores.append(score.item())
            return scores

    def __call__(self, input_seqs: list):
        """Callable alias for forward()."""
        return self.forward(input_seqs)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class MultiTargetBindingAffinity:
    """
    Binding affinity predictor that can handle multiple protein targets dynamically.

    Unlike BindingAffinity which pre-computes a single target's embedding,
    this class can switch between different protein targets on-the-fly.
    Protein embeddings are memoized in ``prot_emb_cache`` keyed by sequence.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Initialize multi-target binding affinity predictor.

        Args:
            tokenizer: Peptide tokenizer
            base_path: Base path for model files
            device: Device for computation (default: auto-detect)
            emb_model: Optional pre-loaded embedding model
        """
        super().__init__()  # no-op for a plain object base; kept as-is
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        # Peptide embeddings
        if emb_model is not None:
            self.pep_model = emb_model.to(self.device).eval()
        else:
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.pep_tokenizer = tokenizer
        # Resolve the UNK token id; fall back to the vocab mapping (or 0) when
        # the tokenizer does not expose unk_token_id directly.
        self.unk_id = getattr(self.pep_tokenizer, "unk_token_id", None)
        if self.unk_id is None:
            self.unk_id = self.pep_tokenizer.vocab.get(self.pep_tokenizer.unk_token, 0)
        self.pep_vocab_size = None
        self.max_pep_len = None
        # The embedding model may be a custom wrapper (.model.roformer), a bare
        # roformer, or any HF model exposing get_input_embeddings().
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            self.pep_vocab_size = self.pep_model.model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "roformer"):
            self.pep_vocab_size = self.pep_model.roformer.embeddings.word_embeddings.num_embeddings
            self.max_pep_len = self.pep_model.roformer.config.max_position_embeddings
        elif hasattr(self.pep_model, "get_input_embeddings"):
            self.pep_vocab_size = self.pep_model.get_input_embeddings().num_embeddings
            self.max_pep_len = getattr(self.pep_model.config, "max_position_embeddings", None)

        # Binding affinity prediction model
        self.model = ImprovedBindingPredictor().to(self.device)
        checkpoint = torch.load(f'{base_path}/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
                                map_location=self.device,
                                weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Protein (ESM) model
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model = self.esm_model.to(self.device).eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        # Cache for protein embeddings (target_seq -> embedding)
        self.prot_emb_cache = {}

    def get_protein_embedding(self, prot_seq: str):
        """
        Get protein embedding, using cache if available.

        Args:
            prot_seq: Protein amino acid sequence

        Returns:
            Protein embedding tensor of shape (1, esm_dim): ESM-2 layer-33
            representations mean-pooled over residues.
        """
        # Check cache first
        if prot_seq in self.prot_emb_cache:
            return self.prot_emb_cache[prot_seq]

        # Compute embedding
        data = [("target", prot_seq)]
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(self.device)

        with torch.no_grad():
            results = self.esm_model.forward(prot_tokens, repr_layers=[33])
            prot_emb = results["representations"][33]

        # Mean-pool residue embeddings into a single (1, esm_dim) vector.
        prot_emb = prot_emb[0].to(self.device)
        prot_emb = torch.mean(prot_emb, dim=0, keepdim=True)

        # Cache for future use
        self.prot_emb_cache[prot_seq] = prot_emb

        return prot_emb

    def forward(self, input_seqs, prot_seq: str):
        """
        Predict binding affinity for peptide-protein pairs.

        Args:
            input_seqs: List of peptide sequences
            prot_seq: Protein target sequence

        Returns:
            List of binding affinity scores
        """
        # Get protein embedding (cached if previously computed)
        prot_emb = self.get_protein_embedding(prot_seq)

        with torch.no_grad():
            scores = []
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(
                    seq,
                    return_tensors='pt',
                    padding=True,
                    truncation=self.max_pep_len is not None,
                    max_length=self.max_pep_len,
                )
                pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
                # Clamp out-of-vocabulary token ids to UNK before embedding.
                pep_tokens["input_ids"] = _sanitize_token_ids(
                    pep_tokens["input_ids"], int(self.pep_vocab_size or 0), int(self.unk_id)
                )

                with torch.no_grad():  # redundant inside the outer no_grad, but harmless
                    # Check if using custom Roformer wrapper or standard model
                    if hasattr(self.pep_model, 'model'):
                        # Custom roformer.Roformer wrapper - get hidden states from inner model
                        emb = self.pep_model.model.roformer(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
                    else:
                        # Standard AutoModelForMaskedLM
                        emb = self.pep_model(
                            input_ids=pep_tokens['input_ids'],
                            attention_mask=pep_tokens.get('attention_mask'),
                            output_hidden_states=True
                        )
                        pep_emb = emb.last_hidden_state.squeeze(0)
                        pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, logits = self.model.forward(prot_emb, pep_emb)
                scores.append(score.item())

            return scores

    def forward_from_probs(
        self,
        token_probs: torch.Tensor,
        attention_mask: torch.Tensor,
        prot_seq: str,
    ) -> torch.Tensor:
        """
        Differentiable binding affinity from token probabilities.

        The peptide is embedded as a probability-weighted mixture of token
        embeddings (token_probs @ embedding matrix), so gradients can flow
        back into ``token_probs``. Returns one score per batch row.
        """
        if token_probs.dim() == 2:
            token_probs = token_probs.unsqueeze(0)  # promote to (1, seq, vocab)
        token_probs = token_probs.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Locate the inner roformer (if any) and its input embedding matrix.
        roformer = None
        if hasattr(self.pep_model, "model") and hasattr(self.pep_model.model, "roformer"):
            roformer = self.pep_model.model.roformer
            emb_weight = roformer.embeddings.word_embeddings.weight
        elif hasattr(self.pep_model, "roformer"):
            roformer = self.pep_model.roformer
            emb_weight = roformer.embeddings.word_embeddings.weight
        else:
            emb_weight = self.pep_model.get_input_embeddings().weight

        if token_probs.size(-1) != emb_weight.size(0):
            raise ValueError(
                f"Token vocab mismatch: probs={token_probs.size(-1)} vs model={emb_weight.size(0)}"
            )

        # Soft embedding lookup: expected embedding under the distribution.
        inputs_embeds = token_probs @ emb_weight
        if roformer is not None:
            outputs = roformer(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state
        else:
            outputs = self.pep_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = outputs.hidden_states[-1]

        # Masked mean pooling over the sequence axis.
        mask = attention_mask.to(hidden.dtype).unsqueeze(-1)
        pep_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

        # Broadcast the single protein embedding across the peptide batch.
        prot_emb = self.get_protein_embedding(prot_seq).to(self.device)
        prot_emb = prot_emb.expand(pep_emb.size(0), -1).unsqueeze(0)
        pep_emb = pep_emb.unsqueeze(0)

        score, _ = self.model.forward(prot_emb, pep_emb)
        return score.squeeze(-1)

    def __call__(self, input_seqs: list, prot_seq: str):
        """
        Predict binding affinity for peptide-protein pairs.

        Args:
            input_seqs: List of peptide sequences
            prot_seq: Protein target sequence

        Returns:
            List of binding affinity scores
        """
        return self.forward(input_seqs, prot_seq)

    def clear_cache(self):
        """Clear the protein embedding cache to free memory."""
        self.prot_emb_cache = {}
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class TargetSpecificBindingAffinity:
    """
    Adapter that fixes one protein target on a MultiTargetBindingAffinity.

    Exposes the single-target BindingAffinity calling convention (peptide
    sequences only) by forwarding every call to the wrapped multi-target
    predictor together with the stored protein sequence.
    """

    def __init__(self, multi_target_predictor: MultiTargetBindingAffinity, prot_seq: str):
        """
        Bind ``prot_seq`` as the permanent target of ``multi_target_predictor``.

        Args:
            multi_target_predictor: The underlying multi-target predictor
            prot_seq: The protein target sequence to use
        """
        self.predictor = multi_target_predictor
        self.prot_seq = prot_seq

    def forward(self, input_seqs):
        """
        Score peptide sequences against the bound target.

        Args:
            input_seqs: List of peptide sequences

        Returns:
            List of binding affinity scores
        """
        return self.predictor.forward(input_seqs, self.prot_seq)

    def __call__(self, input_seqs: list):
        """Callable alias for forward(); same arguments and return value."""
        return self.forward(input_seqs)
|
scoring/functions/classifiers/hemolysis-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scoring/functions/classifiers/nonfouling-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scoring/functions/classifiers/permeability-xgboost.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
|
| 3 |
+
size 45249160
|
scoring/functions/classifiers/solubility-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scoring/functions/hemolysis.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
rdBase.DisableLog('rdApp.error')
|
| 10 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 11 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Hemolysis:
    """Score peptide SMILES for hemolysis safety.

    Returns, per input sequence, the predicted probability that the peptide
    is NOT hemolytic (1 - P(hemolytic) from an XGBoost classifier over
    mean-pooled PeptideCLM embeddings).
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: Peptide (SMILES) tokenizer.
            base_path: Root directory containing the tr2d2-pep classifier files.
            device: Torch device; auto-detects CUDA when None.
            emb_model: Optional pre-loaded embedding model (already on a device).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model
        else:
            # Fix: move the fallback model to self.device. The previous code
            # used the raw `device` argument, which is None under auto-detect,
            # leaving the model on CPU while inputs were sent to self.device.
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled PeptideCLM embeddings, shape (n, hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
                embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return P(not hemolytic) per input sequence as an np.ndarray.

        An empty input yields an empty array of ones (the neutral default).
        """
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Guard against NaN/inf embeddings before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(features))
        # The classifier outputs P(hemolytic); report the complement.
        return scores - probs

    def __call__(self, input_seqs: list):
        """Callable alias for get_scores()."""
        return self.get_scores(input_seqs)
|
| 53 |
+
|
| 54 |
+
def unittest():
    """Manual smoke test for the Hemolysis scorer on one peptide SMILES.

    NOTE(review): Hemolysis.__init__ requires `tokenizer` and `base_path`,
    so this zero-argument construction raises TypeError as written — supply
    real arguments before running. TODO confirm intended usage.
    """
    hemo = Hemolysis()
    seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(hemo.tokenizer.vocab_size)
    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
scoring/functions/nonfouling.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
rdBase.DisableLog('rdApp.error')
|
| 13 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 15 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 16 |
+
|
| 17 |
+
class Nonfouling:
    """Score peptide SMILES for nonfouling propensity.

    Returns, per input sequence, the XGBoost classifier's predicted
    probability that the peptide is nonfouling, computed over mean-pooled
    PeptideCLM embeddings.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: Peptide (SMILES) tokenizer.
            base_path: Root directory containing the tr2d2-pep classifier files.
            device: Torch device; auto-detects CUDA when None.
            emb_model: Optional pre-loaded embedding model (already on a device).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model
        else:
            # Fix: move the fallback model to self.device. The previous code
            # used the raw `device` argument, which is None under auto-detect,
            # leaving the model on CPU while inputs were sent to self.device.
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled PeptideCLM embeddings, shape (n, hidden)."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
                embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return P(nonfouling) per input sequence as an np.ndarray.

        An empty input yields an empty array of zeros (the neutral default).
        """
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Guard against NaN/inf embeddings before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        # Probability of the peptide being nonfouling (the old comment,
        # "not hemolytic", was copy-pasted from the Hemolysis scorer).
        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        """Callable alias for get_scores()."""
        return self.get_scores(input_seqs)
|
| 56 |
+
|
| 57 |
+
def unittest():
    """Manual smoke test for the Nonfouling scorer on one peptide SMILES.

    NOTE(review): Nonfouling.__init__ requires `tokenizer` and `base_path`,
    so this zero-argument construction raises TypeError as written — supply
    real arguments before running. TODO confirm intended usage.
    """
    nf = Nonfouling()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
scoring/functions/permeability.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
from typing import List
|
| 13 |
+
from transformers import AutoModelForMaskedLM
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
rdBase.DisableLog('rdApp.error')
|
| 17 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 20 |
+
|
| 21 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
    """Compute ECFP fingerprints for a list of SMILES strings.

    SMILES that RDKit cannot parse contribute an all-zero row. Returns a
    (len(smiles), size) array plus a parallel 0/1 validity list.
    """
    rows = []
    valid_mask = []
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        rows.append(fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size)))

    return np.concatenate(rows, axis=0), valid_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Return the Morgan (ECFP) fingerprint of one RDKit molecule.

    Uses a count-based hashed fingerprint when ``hashed`` is True, otherwise
    a binary bit vector; either way the result is a (1, size) numpy array.
    """
    if hashed:
        bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    arr = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(bits, arr)  # fills arr in place, resizing to `size`
    return arr.reshape(1, -1)
|
| 44 |
+
|
| 45 |
+
def getMolDescriptors(mol, missingVal=0):
    """Calculate the full list of RDKit descriptors for a molecule.

    Returns (values, names). Descriptors that raise are filled with
    ``missingVal``, so the output length is always
    ``len(Descriptors._descList) + 3`` regardless of the input.
    """
    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except Exception:  # narrowed from bare except: keep KeyboardInterrupt/SystemExit raising
            val = missingVal
        values.append(val)
        names.append(nm)

    # Extra Lipinski-style descriptors not included in Descriptors._descList.
    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except Exception:  # same narrowing as above
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names
|
| 69 |
+
|
| 70 |
+
def get_pep_dps_from_smi(smi):
    """Compute the descriptor vector for a single SMILES string.

    Unparseable SMILES fall through with mol=None, which makes every
    descriptor evaluate to the missing value inside getMolDescriptors,
    so the returned vector always has the full descriptor width.
    """
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:  # narrowed from bare except; MolFromSmiles usually returns None rather than raising
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_pep_dps(smi_list):
    """Descriptor matrix for a list of SMILES, shape (n, n_descriptors)."""
    if len(smi_list) == 0:
        # Width must match get_pep_dps_from_smi: all RDKit descriptors plus
        # the 3 custom Lipinski-style ones appended by getMolDescriptors.
        # Was a hard-coded 213, which drifts across RDKit versions.
        return np.zeros((0, len(Descriptors._descList) + 3))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 85 |
+
|
| 86 |
+
def check_smi_validity(smiles: list):
    """Filter a list of SMILES down to the ones RDKit can parse.

    Args:
        smiles: candidate SMILES strings (falsy entries are skipped).

    Returns:
        (valid_smi, valid_idx): parsable SMILES and their original indices.
    """
    valid_smi, valid_idx = [], []
    for position, candidate in enumerate(smiles):
        try:
            parsed = Chem.MolFromSmiles(candidate) if candidate else None
        except Exception:
            # Unparsable entry — skip it silently, matching prior behavior.
            continue
        if parsed:
            valid_smi.append(candidate)
            valid_idx.append(position)
    return valid_smi, valid_idx
|
| 98 |
+
|
| 99 |
+
class Permeability:
    """XGBoost-based membrane-permeability scorer for peptide SMILES.

    Features are mean-pooled PeptideCLM embeddings, optionally concatenated
    with Morgan fingerprints and/or RDKit descriptors.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: SMILES tokenizer compatible with the embedding model.
            base_path: root directory containing the tr2d2-pep checkout.
            device: torch device (defaults to cuda when available).
            emb_model: optional pre-loaded roformer embedding model to share.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            # BUGFIX: previously used `.to(device)`, which is a no-op when the
            # caller passed device=None; use the resolved self.device instead
            # (consistent with the Solubility scorer).
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled last-hidden-state embeddings for each sequence."""
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        """Build the feature matrix [fingerprints | descriptors | embeddings].

        When fps/dps are False the corresponding part is an empty
        (len(input_seqs), 0) block so the concatenation shape still matches.
        """
        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            fingerprints = torch.empty((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = torch.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)
        features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
        return features

    def get_scores(self, input_seqs: list):
        """Predict permeability scores.

        Returns a vector of -10 sentinels when no features could be built.
        """
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize before handing the matrix to XGBoost: drop NaNs and clamp
        # to the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)
        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 162 |
+
|
| 163 |
+
def unittest():
    """Smoke test for the permeability scorer on a single peptide SMILES.

    NOTE(review): ``Permeability()`` is called without the required
    ``tokenizer`` and ``base_path`` arguments, so this raises TypeError as
    written — the call needs updating to match Permeability.__init__.
    """
    permeability = Permeability()
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
|
scoring/functions/solubility.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import xgboost as xgb
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
import warnings
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
|
| 9 |
+
rdBase.DisableLog('rdApp.error')
|
| 10 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 11 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 12 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 13 |
+
|
| 14 |
+
class Solubility:
    """XGBoost solubility scorer driven by mean-pooled PeptideCLM embeddings."""

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        """
        Args:
            tokenizer: SMILES tokenizer compatible with the embedding model.
            base_path: root directory containing the tr2d2-pep checkout.
            device: torch device (defaults to cuda when available).
            emb_model: optional pre-loaded roformer embedding model to share.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        self.predictor = xgb.Booster(model_file=f'{base_path}/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json')
        if emb_model is None:
            backbone = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
            self.emb_model = backbone.to(self.device).eval()
        else:
            self.emb_model = emb_model.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return mean-pooled last-hidden-state embeddings, one per sequence."""
        pooled = []
        for seq in sequences:
            encoded = self.tokenizer(seq, return_tensors='pt')
            encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
            with torch.no_grad():
                out = self.emb_model(**encoded)
            # Average over the token dimension, then move to host memory.
            pooled.append(out.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy())
        return np.array(pooled)

    def get_scores(self, input_seqs: list):
        """Predict solubility scores; all-zero vector when there are no features."""
        defaults = np.zeros(len(input_seqs))
        feats = self.generate_embeddings(input_seqs)

        if len(feats) == 0:
            return defaults

        # Sanitize: replace NaNs and clamp to the float32 range for XGBoost.
        feats = np.nan_to_num(feats, nan=0.)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)

        return self.predictor.predict(xgb.DMatrix(feats))

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
|
| 55 |
+
|
| 56 |
+
def unittest():
    """Smoke test for the solubility scorer on a single peptide SMILES.

    NOTE(review): ``Solubility()`` is called without the required
    ``tokenizer`` and ``base_path`` arguments, so this raises TypeError as
    written — the call needs updating to match Solubility.__init__.
    """
    solubility = Solubility()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)

if __name__ == '__main__':
    unittest()
|
scoring/scoring_functions.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 4 |
+
from transformers import AutoModelForMaskedLM
|
| 5 |
+
from scoring.functions.binding import BindingAffinity
|
| 6 |
+
from scoring.functions.permeability import Permeability
|
| 7 |
+
from scoring.functions.solubility import Solubility
|
| 8 |
+
from scoring.functions.hemolysis import Hemolysis
|
| 9 |
+
from scoring.functions.nonfouling import Nonfouling
|
| 10 |
+
|
| 11 |
+
base_path = 'To Be Added'
|
| 12 |
+
|
| 13 |
+
def resolve_device(requested):
    """Map a user-supplied device spec to a usable torch.device.

    ``None`` or ``"auto"`` picks ``cuda:0`` when a GPU is visible, else CPU.
    Any unparsable spec, unavailable CUDA request, or out-of-range CUDA
    index degrades gracefully instead of raising.

    Args:
        requested: None, "auto", a device string, or a torch.device.

    Returns:
        A torch.device that is valid on this machine.
    """
    if requested is None or str(requested).lower() == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda:0")
        return torch.device("cpu")

    try:
        device = torch.device(requested)
    except Exception:  # unparsable spec (e.g. a typo) — fall back to CPU
        return torch.device("cpu")

    if device.type != "cuda":
        return device

    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        return torch.device("cpu")

    index = device.index if device.index is not None else 0
    # (removed dead `index is None` re-check: index is always an int here)
    if index < 0 or index >= torch.cuda.device_count():
        # Out-of-range CUDA index: clamp to the first visible GPU.
        return torch.device("cuda:0")

    return torch.device(f"cuda:{index}")
|
| 35 |
+
|
| 36 |
+
class ScoringFunctions:
    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Class for generating score vectors given generated sequences.

        Args:
            score_func_names: list of scoring function names to be evaluated;
                None means no scoring (unmasking based on peptide-bond
                validity only).
            prot_seqs: target protein sequences (0, 1, or 2 entries) used to
                build the binding-affinity scorers.
            device: device spec, resolved through resolve_device.
        """
        device = resolve_device(device)
        # One shared embedding backbone for all embedding-based scorers.
        emb_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all'
        ).roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/tr2d2-pep/tokenizer/new_vocab.txt',
                                         f'{base_path}/tr2d2-pep/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        if score_func_names is None:
            # just do unmasking based on validity of peptide bonds
            self.score_func_names = []
        else:
            self.score_func_names = score_func_names

        self.target_protein = prot_seqs

        # BUGFIX: membership tests previously used the raw `score_func_names`
        # argument, which raises TypeError when it is None; use the
        # normalized list instead. (Debug prints removed.)
        names = self.score_func_names
        if ('binding_affinity1' in names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif ('binding_affinity1' in names) and ('binding_affinity2' in names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        """Evaluate every configured scoring function on the sequences.

        Returns:
            np.ndarray of shape (num_sequences, num_functions), float32.
        """
        scores = [self.all_funcs[name](input_seqs=input_seqs)
                  for name in self.score_func_names]
        # convert to numpy arrays with shape (num_sequences, num_functions)
        return np.float32(scores).T

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
|
setup.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages

# Minimal packaging metadata for the TD3B module; packages are discovered
# automatically from the repository layout.
setup(
    name="td3b",
    version="0.1.0",
    description="TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation",
    packages=find_packages(),
    python_requires=">=3.10",
)
|
td3b/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B: Transition-Directed Discrete Diffusion for Binders
|
| 3 |
+
A module extending TR2-D2 with directional allosteric control.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .direction_oracle import DirectionalOracle
|
| 7 |
+
from .td3b_scoring import TD3BRewardFunction, TD3BConfidenceWeighting, create_td3b_reward_function
|
| 8 |
+
from .td3b_losses import ContrastiveLoss, InfoNCELoss, TD3BTotalLoss, extract_embeddings_from_mdlm
|
| 9 |
+
from .td3b_mcts import TD3B_MCTS, create_td3b_mcts
|
| 10 |
+
from .td3b_finetune import td3b_finetune, add_td3b_sampling_to_model
|
| 11 |
+
from .data_utils import TD3BDataset, load_td3b_data
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
'DirectionalOracle',
|
| 15 |
+
'TD3BRewardFunction',
|
| 16 |
+
'TD3BConfidenceWeighting',
|
| 17 |
+
'create_td3b_reward_function',
|
| 18 |
+
'ContrastiveLoss',
|
| 19 |
+
'InfoNCELoss',
|
| 20 |
+
'TD3BTotalLoss',
|
| 21 |
+
'extract_embeddings_from_mdlm',
|
| 22 |
+
'TD3B_MCTS',
|
| 23 |
+
'create_td3b_mcts',
|
| 24 |
+
'td3b_finetune',
|
| 25 |
+
'add_td3b_sampling_to_model',
|
| 26 |
+
'TD3BDataset',
|
| 27 |
+
'load_td3b_data',
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
__version__ = '0.1.0'
|
td3b/data_utils.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B Data Utilities
|
| 3 |
+
Handles loading and preprocessing of TD3B_data.csv for both oracle training and finetuning.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
except ImportError: # pragma: no cover - rdkit may be optional in some setups
|
| 16 |
+
Chem = None
|
| 17 |
+
|
| 18 |
+
sys.path.append('..')
|
| 19 |
+
|
| 20 |
+
# The 20 canonical amino-acid one-letter codes.
AA_SET = set("ACDEFGHIKLMNPQRSTVWY")


def is_amino_acid_sequence(seq: str) -> bool:
    """Return True when *seq* is a non-empty canonical amino-acid sequence.

    Surrounding whitespace is ignored and matching is case-insensitive.

    BUGFIX: a whitespace-only string previously stripped to "" and was
    accepted (``all()`` over an empty iterable is True); it is now rejected.
    """
    if not isinstance(seq, str):
        return False
    normalized = seq.strip().upper()
    if not normalized:
        return False  # empty or whitespace-only input is not a sequence
    return all(residue in AA_SET for residue in normalized)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def aa_sequence_to_smiles(seq: str) -> Optional[str]:
    """Convert a one-letter amino-acid sequence to canonical isomeric SMILES.

    Returns None when RDKit is unavailable, the input is not a valid
    amino-acid sequence, or the conversion fails.
    """
    if Chem is None or not is_amino_acid_sequence(seq):
        return None
    try:
        mol = Chem.MolFromSequence(seq)
    except Exception:
        return None
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def peptide_seq_to_smiles(seq: str) -> str:
    """Best-effort SMILES for a peptide; falls back to the raw input string."""
    converted = aa_sequence_to_smiles(seq)
    if converted is None:
        return seq
    return converted
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def smiles_token_length(smiles: str, tokenizer) -> int:
    """Token count of *smiles* under *tokenizer*.

    Without a tokenizer the character count is used as a cheap proxy.
    """
    if tokenizer is None:
        return len(smiles)
    ids = tokenizer(smiles, return_tensors="pt")["input_ids"][0]
    return int(ids.numel())
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TD3BDataset(Dataset):
    """
    Dataset for TD3B that loads peptide-protein pairs with directional labels.

    Supports both:
    1. Oracle training: uses all pairs for training f_φ
    2. Finetuning: provides target proteins for conditioning during RL

    Expects a CSV with at least the columns: Ligand_Sequence, Target_Sequence,
    Target_UniProt_ID, Ligand_UniProt_ID, label, Action.
    """

    def __init__(
        self,
        data_path: str,
        mode: str = 'oracle',  # 'oracle' or 'finetune'
        peptide_tokenizer=None,
        protein_tokenizer=None,
        max_peptide_length: int = 200,
        max_protein_length: int = 1000,
        target_protein_id: Optional[str] = None,  # For finetuning mode
        convert_peptide_to_smiles: bool = True,
    ):
        """
        Args:
            data_path: Path to TD3B_data.csv
            mode: 'oracle' for training f_φ, 'finetune' for RL conditioning
            peptide_tokenizer: Tokenizer for peptide sequences
            protein_tokenizer: Tokenizer for protein sequences (ESM-2)
            max_peptide_length: Maximum peptide sequence length
            max_protein_length: Maximum protein sequence length
            target_protein_id: UniProt ID for target protein (finetuning mode)
            convert_peptide_to_smiles: convert AA sequences to SMILES when possible
        """
        self.mode = mode
        self.data_path = data_path
        self.peptide_tokenizer = peptide_tokenizer
        self.protein_tokenizer = protein_tokenizer
        self.max_peptide_length = max_peptide_length
        self.max_protein_length = max_protein_length
        self.convert_peptide_to_smiles = convert_peptide_to_smiles

        # Load data
        self.data = pd.read_csv(data_path)
        print(f"Loaded {len(self.data)} peptide-protein pairs from {data_path}")

        # Filter by target protein if in finetune mode
        if mode == 'finetune' and target_protein_id is not None:
            self.data = self.data[self.data['Target_UniProt_ID'] == target_protein_id]
            print(f"Filtered to {len(self.data)} pairs for target {target_protein_id}")

        # Process labels (direction of allosteric effect).
        self.label_map = {
            'agonist': 1.0,
            'antagonist': -1.0,
            'neutral': 0.0,
        }

        # Convert action descriptions to numerical labels.
        # NOTE(review): labels outside the map become NaN — confirm the CSV
        # only contains agonist/antagonist/neutral, or filter NaNs downstream.
        self.data['numeric_label'] = self.data['label'].map(self.label_map)

        # Assign confidence based on action description
        # NOTE(review): _action_to_confidence calls .lower(), so a missing
        # (NaN) Action cell would raise AttributeError — verify column is clean.
        self.data['confidence'] = self.data['Action'].apply(self._action_to_confidence)

    def _action_to_confidence(self, action: str) -> float:
        """
        Convert action description to confidence score.

        Full agonist/antagonist: 1.0
        Partial/Weak: 0.7
        Others: 0.5
        """
        action_lower = action.lower()

        if 'full' in action_lower:
            return 1.0
        elif 'partial' in action_lower or 'weak' in action_lower:
            return 0.7
        elif 'slows' in action_lower or 'modulator' in action_lower:
            return 0.5
        else:
            return 0.8  # Default for unspecified agonist/antagonist

    def __len__(self):
        # Number of peptide-protein pairs after any target filtering.
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Get sequences
        peptide_seq = row['Ligand_Sequence']
        protein_seq = row['Target_Sequence']
        peptide_smiles = self._peptide_to_smiles(peptide_seq)
        peptide_smiles_length = smiles_token_length(peptide_smiles, self.peptide_tokenizer)

        # Tokenize (placeholder - actual tokenization depends on mode)
        if self.peptide_tokenizer is not None:
            peptide_tokens = self._tokenize_peptide(peptide_smiles)
        else:
            # No tokenizer: emit an all-zero dummy tensor of fixed length.
            peptide_tokens = torch.zeros(self.max_peptide_length, dtype=torch.long)

        if self.protein_tokenizer is not None:
            protein_tokens = self._tokenize_protein(protein_seq)
        else:
            protein_tokens = self._tokenize_protein_placeholder(protein_seq)

        # Get label and confidence
        label = torch.tensor(row['numeric_label'], dtype=torch.float32)
        confidence = torch.tensor(row['confidence'], dtype=torch.float32)

        return {
            'peptide_seq': peptide_seq,
            'peptide_smiles': peptide_smiles,
            'peptide_smiles_length': peptide_smiles_length,
            'protein_seq': protein_seq,
            'peptide_tokens': peptide_tokens,
            'protein_tokens': protein_tokens,
            'label': label,
            'confidence': confidence,
            'target_id': row['Target_UniProt_ID'],
            'ligand_id': row['Ligand_UniProt_ID'],
            'action': row['Action']
        }

    def _peptide_to_smiles(self, peptide_seq: str) -> str:
        # Pass-through when SMILES conversion is disabled; otherwise the
        # helper falls back to the raw sequence if conversion fails.
        if not self.convert_peptide_to_smiles:
            return peptide_seq
        return peptide_seq_to_smiles(peptide_seq)

    def _tokenize_peptide(self, peptide_seq: str) -> torch.Tensor:
        """Tokenize peptide sequence using provided tokenizer (padded/truncated
        to max_peptide_length; returns a 1-D LongTensor of input ids)."""
        tokens = self.peptide_tokenizer(
            peptide_seq,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_peptide_length,
            truncation=True
        )['input_ids'].squeeze(0)
        return tokens

    def _tokenize_protein_placeholder(self, protein_seq: str) -> torch.Tensor:
        """
        Placeholder protein tokenizer (character-level).

        NOTE: Replace with ESM-2 tokenizer in production:
            from esm import pretrained
            _, alphabet = pretrained.esm2_t33_650M_UR50D()
            batch_converter = alphabet.get_batch_converter()
            _, _, tokens = batch_converter([("protein", protein_seq)])
        """
        # Amino acid to index mapping (0 = PAD, 21 = UNK)
        aa_to_idx = {aa: i+1 for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
        aa_to_idx['<PAD>'] = 0
        aa_to_idx['<UNK>'] = 21

        # Convert to indices
        indices = [aa_to_idx.get(aa, aa_to_idx['<UNK>']) for aa in protein_seq]

        # Pad or truncate to max_protein_length
        if len(indices) > self.max_protein_length:
            indices = indices[:self.max_protein_length]
        else:
            indices += [0] * (self.max_protein_length - len(indices))

        return torch.tensor(indices, dtype=torch.long)

    def _tokenize_protein(self, protein_seq: str) -> torch.Tensor:
        """Tokenize protein using ESM-2 tokenizer if available.

        Currently always delegates to the character-level placeholder — the
        ESM-2 path is not implemented yet.
        """
        if self.protein_tokenizer is None:
            return self._tokenize_protein_placeholder(protein_seq)

        # Use ESM-2 tokenizer
        # TODO: Implement when ESM-2 is integrated
        return self._tokenize_protein_placeholder(protein_seq)

    def get_target_proteins(self) -> Dict[str, str]:
        """
        Get dictionary of unique target proteins.

        Returns:
            dict: {UniProt_ID: Sequence}
        """
        unique_targets = self.data.drop_duplicates(subset=['Target_UniProt_ID'])
        return dict(zip(unique_targets['Target_UniProt_ID'], unique_targets['Target_Sequence']))

    def get_ligands_for_target(self, target_id: str) -> List[Dict]:
        """
        Get all ligands (peptides) for a specific target protein.

        Args:
            target_id: Target protein UniProt ID

        Returns:
            List of dicts with ligand info
        """
        target_data = self.data[self.data['Target_UniProt_ID'] == target_id]

        ligands = []
        for _, row in target_data.iterrows():
            ligands.append({
                'sequence': row['Ligand_Sequence'],
                'uniprot_id': row['Ligand_UniProt_ID'],
                'label': row['numeric_label'],
                'confidence': row['confidence'],
                'action': row['Action']
            })

        return ligands
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def load_td3b_data(
    data_path: str,
    mode: str = 'oracle',
    target_protein_id: Optional[str] = None
) -> Tuple[pd.DataFrame, Dict]:
    """
    Load and summarize TD3B data.

    Args:
        data_path: Path to TD3B_data.csv
        mode: 'oracle' or 'finetune'
        target_protein_id: Filter by target protein (finetuning mode)

    Returns:
        data: Filtered DataFrame
        stats: Dictionary of statistics
    """
    frame = pd.read_csv(data_path)

    # Restrict to a single target only when finetuning with a target given.
    if mode == 'finetune' and target_protein_id is not None:
        frame = frame[frame['Target_UniProt_ID'] == target_protein_id]

    summary = {
        'total_pairs': len(frame),
        'unique_targets': frame['Target_UniProt_ID'].nunique(),
        'unique_ligands': frame['Ligand_UniProt_ID'].nunique(),
        'agonist_count': (frame['label'] == 'agonist').sum(),
        'antagonist_count': (frame['label'] == 'antagonist').sum(),
        'action_distribution': frame['Action'].value_counts().to_dict(),
    }

    return frame, summary
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def create_target_dataset_for_finetuning(
    data_path: str,
    target_protein_id: str,
    desired_direction: str = 'agonist'
) -> Dict:
    """
    Create a dataset for TD3B finetuning focused on a specific target.

    Args:
        data_path: Path to TD3B_data.csv
        target_protein_id: Target protein UniProt ID
        desired_direction: 'agonist' or 'antagonist'

    Returns:
        dict with target protein info and example ligands

    Raises:
        ValueError: if desired_direction is invalid or the target has no data.
    """
    # BUGFIX: previously an identity `direction_map` dict turned an invalid
    # direction into an opaque KeyError; validate explicitly instead.
    if desired_direction not in ('agonist', 'antagonist'):
        raise ValueError(
            f"desired_direction must be 'agonist' or 'antagonist', got {desired_direction!r}"
        )

    data = pd.read_csv(data_path)

    # Get target protein info
    target_data = data[data['Target_UniProt_ID'] == target_protein_id]

    if len(target_data) == 0:
        raise ValueError(f"No data found for target {target_protein_id}")

    # Get protein sequence (should be same for all rows)
    protein_seq = target_data.iloc[0]['Target_Sequence']

    # Get ligands with desired direction
    direction_ligands = target_data[target_data['label'] == desired_direction]

    # Also get opposite direction for contrastive learning
    opposite_direction = 'antagonist' if desired_direction == 'agonist' else 'agonist'
    opposite_ligands = target_data[target_data['label'] == opposite_direction]

    return {
        'target_protein_id': target_protein_id,
        'target_protein_seq': protein_seq,
        'desired_direction': desired_direction,
        'n_desired_examples': len(direction_ligands),
        'n_opposite_examples': len(opposite_ligands),
        'desired_ligands': direction_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records'),
        'opposite_ligands': opposite_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records')
    }
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
if __name__ == "__main__":
    # Example usage / smoke test: loads the CSV, prints summary statistics,
    # samples one dataset item, and builds a finetuning dataset for the
    # first target. Requires ../TD3B_data.csv to exist.
    data_path = "../TD3B_data.csv"

    print("=" * 80)
    print("TD3B Data Loading Example")
    print("=" * 80)

    # Load and summarize data
    data, stats = load_td3b_data(data_path, mode='oracle')

    print("\nDataset Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Create dataset for oracle training
    print("\n" + "=" * 80)
    print("Oracle Training Dataset")
    print("=" * 80)

    dataset = TD3BDataset(data_path, mode='oracle')
    print(f"Dataset size: {len(dataset)}")

    # Sample first item
    sample = dataset[0]
    print(f"\nSample item:")
    print(f"  Target: {sample['target_id']}")
    print(f"  Ligand: {sample['ligand_id']}")
    print(f"  Label: {sample['label'].item()}")
    print(f"  Confidence: {sample['confidence'].item()}")
    print(f"  Action: {sample['action']}")

    # Create finetuning dataset for a specific target
    print("\n" + "=" * 80)
    print("Finetuning Dataset Example")
    print("=" * 80)

    # Get first target
    targets = dataset.get_target_proteins()
    first_target_id = list(targets.keys())[0]

    finetune_info = create_target_dataset_for_finetuning(
        data_path,
        first_target_id,
        desired_direction='agonist'
    )

    print(f"\nTarget: {finetune_info['target_protein_id']}")
    print(f"Desired direction: {finetune_info['desired_direction']}")
    print(f"Number of agonist examples: {finetune_info['n_desired_examples']}")
    print(f"Number of antagonist examples: {finetune_info['n_opposite_examples']}")
|
td3b/direction_oracle.py
ADDED
|
@@ -0,0 +1,709 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GPCR Agonist Classifier - TR2-D2 Inference Script
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from types import SimpleNamespace
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from transformers import EsmModel, EsmTokenizer
|
| 17 |
+
|
| 18 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 19 |
+
if PROJECT_ROOT not in sys.path:
|
| 20 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 21 |
+
|
| 22 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 23 |
+
from roformer import Roformer
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def resolve_device(requested: Optional[str]) -> torch.device:
    """Resolve a user-supplied device spec to a usable torch.device.

    Args:
        requested: Device string ("cpu", "cuda", "cuda:1", ...), "auto", or
            None. "auto"/None select cuda:0 when CUDA is usable, else CPU.

    Returns:
        A torch.device guaranteed to be usable on this machine. Invalid or
        unavailable requests degrade gracefully (with a warning) instead of
        raising: bad strings and missing CUDA fall back to CPU, an
        out-of-range CUDA index falls back to cuda:0.
    """
    # Auto-selection: prefer the first CUDA device when one is present.
    if requested is None or str(requested).lower() == "auto":
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return torch.device("cuda:0")
        return torch.device("cpu")

    try:
        device = torch.device(requested)
    except Exception as exc:
        logger.warning("Invalid device '%s': %s. Falling back to CPU.", requested, exc)
        return torch.device("cpu")
    # Non-CUDA devices (cpu, mps, ...) need no availability validation here.
    if device.type != "cuda":
        return device

    if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
        logger.warning("CUDA requested but not available; falling back to CPU")
        return torch.device("cpu")

    # A bare "cuda" spec carries no index; treat it as device 0.
    index = device.index if device.index is not None else 0
    count = torch.cuda.device_count()
    # NOTE: index can no longer be None at this point, so only range is checked.
    if index < 0 or index >= count:
        logger.warning(
            "CUDA device %s requested but only %d visible; using cuda:0",
            index,
            count
        )
        return torch.device("cuda:0")

    return torch.device(f"cuda:{index}")
|
| 57 |
+
|
| 58 |
+
# -------------------------
|
| 59 |
+
# Peptide to SMILES
|
| 60 |
+
# -------------------------
|
| 61 |
+
def peptide_to_smiles(seq: str) -> str:
    """Convert a one-letter amino-acid sequence into a canonical SMILES string.

    Whitespace is stripped and the sequence upper-cased before conversion.

    Raises:
        ValueError: if RDKit cannot parse the sequence into a molecule.
    """
    # Local import keeps RDKit an optional dependency for callers that
    # never take the peptide-input path.
    from rdkit import Chem

    cleaned = seq.strip().upper()
    mol = Chem.MolFromSequence(cleaned)
    if mol is None:
        raise ValueError(f"RDKit failed to convert peptide '{cleaned}' to SMILES")
    return Chem.MolToSmiles(mol)
|
| 68 |
+
|
| 69 |
+
# -------------------------
|
| 70 |
+
# Self-Attention Block
|
| 71 |
+
# -------------------------
|
| 72 |
+
class SelfAttentionBlock(nn.Module):
|
| 73 |
+
def __init__(self, d_model, n_heads, dropout=0.1):
|
| 74 |
+
super().__init__()
|
| 75 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
|
| 76 |
+
self.norm = nn.LayerNorm(d_model)
|
| 77 |
+
self.dropout = nn.Dropout(dropout)
|
| 78 |
+
|
| 79 |
+
def forward(self, x, key_padding_mask=None):
|
| 80 |
+
attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
|
| 81 |
+
x = self.norm(x + self.dropout(attn_out))
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
# -------------------------
|
| 85 |
+
# Cross-Attention Module
|
| 86 |
+
# -------------------------
|
| 87 |
+
class BiMultiHeadCrossAttention(nn.Module):
|
| 88 |
+
def __init__(self, d_model, n_heads, dropout=0.1):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.prot_to_lig = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
|
| 91 |
+
self.lig_to_prot = nn.MultiheadAttention(d_model, n_heads, dropout, batch_first=True)
|
| 92 |
+
self.prot_ln = nn.LayerNorm(d_model)
|
| 93 |
+
self.lig_ln = nn.LayerNorm(d_model)
|
| 94 |
+
self.dropout = nn.Dropout(dropout)
|
| 95 |
+
|
| 96 |
+
def forward(self, prot_h, lig_h, prot_kpm=None, lig_kpm=None):
|
| 97 |
+
prot_ctx, _ = self.prot_to_lig(prot_h, lig_h, lig_h, key_padding_mask=lig_kpm)
|
| 98 |
+
prot_h_out = self.prot_ln(prot_h + self.dropout(prot_ctx))
|
| 99 |
+
|
| 100 |
+
lig_ctx, _ = self.lig_to_prot(lig_h, prot_h, prot_h, key_padding_mask=prot_kpm)
|
| 101 |
+
lig_h_out = self.lig_ln(lig_h + self.dropout(lig_ctx))
|
| 102 |
+
|
| 103 |
+
return prot_h_out, lig_h_out
|
| 104 |
+
|
| 105 |
+
# -------------------------
|
| 106 |
+
# TR2-D2 Encoder Wrapper
|
| 107 |
+
# -------------------------
|
| 108 |
+
class TR2D2RoFormerEncoder(nn.Module):
    """Frozen TR2-D2 RoFormer used as the ligand (SMILES) encoder.

    Wraps the project's ``Roformer``, optionally loads pretrained weights from
    a TR2-D2 checkpoint, then freezes every parameter and switches to eval
    mode. ``forward`` returns the last hidden states for either hard token
    ids or soft input embeddings.
    """
    def __init__(self, config, tokenizer, checkpoint_path=None, device="cpu"):
        # config: namespace exposing a .roformer sub-config (see create_tr2d2_config)
        # tokenizer: ligand SMILES tokenizer, passed through to Roformer
        # checkpoint_path: optional checkpoint; both {"state_dict": ...} wrappers
        #   and raw state dicts are accepted
        super().__init__()
        self.device = device
        self.encoder = Roformer(config, tokenizer, device=device)

        if checkpoint_path:
            print(f" Loading TR2-D2 checkpoint...")
            ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
            # Unwrap Lightning-style checkpoints; fall back to the raw dict.
            state_dict = ckpt.get("state_dict", ckpt)
            # Keep only encoder-related entries and strip common wrapper
            # prefixes ("model.", "backbone.") from the keys.
            roformer_state = {
                k.replace("model.", "").replace("backbone.", ""): v
                for k, v in state_dict.items()
                if "roformer" in k or "encoder" in k or "backbone" in k
            }
            # strict=False: the filtered subset may not cover every parameter.
            self.encoder.model.load_state_dict(roformer_state, strict=False)
            print(" TR2-D2 checkpoint loaded")

        # Freeze: the ligand encoder is used purely as a feature extractor.
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

    def forward(self, input_ids, attention_mask, inputs_embeds=None):
        """Encode the ligand and return ``last_hidden_state``.

        When *inputs_embeds* is given, *input_ids* is ignored and — unlike the
        token-id path — NO ``torch.no_grad()`` guard is applied; presumably so
        gradients can flow through soft (probability-weighted) embeddings as
        used by ``DirectionalOracle.predict_from_probs``. TODO confirm this
        asymmetry is intentional.
        """
        if attention_mask is None:
            raise ValueError("attention_mask is required for ligand encoding.")
        attention_mask = attention_mask.to(self.device)
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.to(self.device)
            out = self.encoder.model.roformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask
            )
        else:
            input_ids = input_ids.to(self.device)
            with torch.no_grad():
                out = self.encoder.model.roformer(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
        return out.last_hidden_state
|
| 148 |
+
|
| 149 |
+
# -------------------------
|
| 150 |
+
# Full GPCR Model
|
| 151 |
+
# -------------------------
|
| 152 |
+
class ESM_TR2D2_GPCRClassifier(nn.Module):
    """
    GPCR Agonist Classifier with TR2-D2

    Architecture:
        1. ESM2 (protein) + TR2-D2 RoFormer (ligand) — both frozen
        2. Projections to common dimension
        3. Self-Attention (1 layer each)
        4. BiDirectional Cross-Attention (2 stacked layers)
        5. Masked Average Pooling
        6. MLP Classifier (2 logits: non-agonist, agonist)
    """
    def __init__(
        self,
        esm_name,
        tr2d2_config,
        lig_tokenizer,
        tr2d2_checkpoint=None,
        d_model=256,
        n_heads=4,
        n_self_attn_layers=1,
        n_bmca_layers=2,
        dropout=0.3,
        device="cuda",
        esm_cache_dir=None,
        esm_local_files_only=False
    ):
        # esm_name: HF model id or local path for the ESM2 protein encoder
        # tr2d2_config: namespace with .roformer.hidden_size used for projection
        # lig_tokenizer: ligand tokenizer, forwarded to the TR2-D2 wrapper
        # tr2d2_checkpoint: optional pretrained TR2-D2 weights
        super().__init__()
        self.device = device

        # Frozen encoders: only projections, attention blocks and the
        # classifier head are trainable.
        print("Loading ESM2 protein encoder...")
        self.esm = EsmModel.from_pretrained(
            esm_name,
            cache_dir=esm_cache_dir,
            local_files_only=esm_local_files_only
        )
        for p in self.esm.parameters():
            p.requires_grad = False
        self.esm.eval()

        print("Loading TR2-D2 ligand encoder...")
        self.ligand_encoder = TR2D2RoFormerEncoder(
            tr2d2_config, lig_tokenizer, tr2d2_checkpoint, device
        )

        esm_dim = self.esm.config.hidden_size
        lig_dim = tr2d2_config.roformer.hidden_size

        # Map both streams into the shared d_model space.
        self.prot_proj = nn.Linear(esm_dim, d_model)
        self.lig_proj = nn.Linear(lig_dim, d_model)

        # Self-attention (per stream)
        self.prot_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])
        self.lig_self_attn_layers = nn.ModuleList([
            SelfAttentionBlock(d_model, n_heads, dropout)
            for _ in range(n_self_attn_layers)
        ])

        # Cross-attention (bidirectional, stacked)
        self.bmca_layers = nn.ModuleList([
            BiMultiHeadCrossAttention(d_model, n_heads, dropout)
            for _ in range(n_bmca_layers)
        ])

        # Classifier over the concatenated pooled representations.
        self.classifier = nn.Sequential(
            nn.Linear(2 * d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )

    def forward(self, prot_tokens, lig_tokens, lig_inputs_embeds=None):
        """Return raw classification logits of shape (batch, 2).

        prot_tokens / lig_tokens: dicts with "input_ids" and "attention_mask".
        lig_inputs_embeds: optional soft embeddings that bypass the ligand
        token-id lookup (see TR2D2RoFormerEncoder.forward).
        """
        # key_padding_mask convention: True marks PADDING positions.
        prot_kpm = prot_tokens["attention_mask"].eq(0)
        lig_kpm = lig_tokens["attention_mask"].eq(0)

        # Frozen protein encoder — no gradients needed.
        with torch.no_grad():
            prot_out = self.esm(**prot_tokens).last_hidden_state

        lig_out = self.ligand_encoder(
            lig_tokens["input_ids"],
            lig_tokens["attention_mask"],
            inputs_embeds=lig_inputs_embeds
        )

        prot_h = self.prot_proj(prot_out)
        lig_h = self.lig_proj(lig_out)

        # Self-attention
        for self_attn in self.prot_self_attn_layers:
            prot_h = self_attn(prot_h, key_padding_mask=prot_kpm)
        for self_attn in self.lig_self_attn_layers:
            lig_h = self_attn(lig_h, key_padding_mask=lig_kpm)

        # Cross-attention (2 stacked)
        for bmca in self.bmca_layers:
            prot_h, lig_h = bmca(prot_h, lig_h, prot_kpm=prot_kpm, lig_kpm=lig_kpm)

        # Masked average pooling: sum over valid positions / count of valid
        # positions (clamped to avoid division by zero on all-pad rows).
        prot_mask = prot_tokens["attention_mask"].unsqueeze(-1)
        lig_mask = lig_tokens["attention_mask"].unsqueeze(-1)

        prot_repr = (prot_h * prot_mask).sum(dim=1) / prot_mask.sum(dim=1).clamp(min=1)
        lig_repr = (lig_h * lig_mask).sum(dim=1) / lig_mask.sum(dim=1).clamp(min=1)

        return self.classifier(torch.cat([prot_repr, lig_repr], dim=-1))
|
| 262 |
+
|
| 263 |
+
# -------------------------
|
| 264 |
+
# Tokenization
|
| 265 |
+
# -------------------------
|
| 266 |
+
def create_tr2d2_config(vocab_size):
    """Build the minimal TR2-D2 RoFormer config namespace for *vocab_size*.

    Only the fields the encoder wrapper actually reads are populated; all
    architecture hyperparameters match the pretrained TR2-D2 backbone.
    """
    roformer_cfg = SimpleNamespace(
        vocab_size=vocab_size,
        hidden_size=768,
        n_layers=8,
        n_heads=8,
        max_position_embeddings=1035
    )
    return SimpleNamespace(roformer=roformer_cfg)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _load_state_dict_flexible(model: nn.Module, state_dict: Dict, strict: bool = True) -> None:
|
| 279 |
+
try:
|
| 280 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 281 |
+
return
|
| 282 |
+
except RuntimeError as exc:
|
| 283 |
+
model_keys = set(model.state_dict().keys())
|
| 284 |
+
filtered = {k: v for k, v in state_dict.items() if k in model_keys}
|
| 285 |
+
logger.warning("Strict load failed: %s", exc)
|
| 286 |
+
logger.warning(
|
| 287 |
+
"Retrying with filtered keys (%d/%d) and strict=False",
|
| 288 |
+
len(filtered),
|
| 289 |
+
len(state_dict)
|
| 290 |
+
)
|
| 291 |
+
incompatible = model.load_state_dict(filtered, strict=False)
|
| 292 |
+
if incompatible.missing_keys:
|
| 293 |
+
logger.warning("Missing keys (first 10): %s", incompatible.missing_keys[:10])
|
| 294 |
+
if incompatible.unexpected_keys:
|
| 295 |
+
logger.warning("Unexpected keys (first 10): %s", incompatible.unexpected_keys[:10])
|
| 296 |
+
|
| 297 |
+
def tokenize_protein(seq, tokenizer, device):
    """Tokenize a protein sequence and move every resulting tensor to *device*.

    Uses the tokenizer's standard settings: padding, truncation to 1024
    tokens, and special tokens added.
    """
    encoded = tokenizer(
        seq,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
        add_special_tokens=True
    )
    return {name: tensor.to(device) for name, tensor in encoded.items()}
|
| 307 |
+
|
| 308 |
+
def tokenize_ligand(smiles, tokenizer, max_len, device):
    """Tokenize a SMILES string and right-pad to exactly *max_len* tokens.

    Args:
        smiles: SMILES string for the ligand.
        tokenizer: ligand tokenizer exposing pad_token_id.
        max_len: fixed output length; input is truncated to it, shorter
            inputs are padded with pad_token_id / mask 0.
        device: device the returned tensors are moved to.

    Returns:
        dict with "input_ids" and "attention_mask", each of shape (1, max_len).
    """
    enc = tokenizer(
        smiles,
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
        add_special_tokens=True
    )
    ids = enc["input_ids"].squeeze(0)
    att = enc["attention_mask"].squeeze(0)

    pad = max_len - ids.numel()
    if pad > 0:
        # Pad with the ORIGINAL dtypes. torch.zeros(pad) defaults to float32,
        # and concatenating it with an int64 mask silently promoted the whole
        # attention mask to float32 whenever padding occurred.
        ids = torch.cat([ids, torch.full((pad,), tokenizer.pad_token_id, dtype=ids.dtype)])
        att = torch.cat([att, torch.zeros(pad, dtype=att.dtype)])

    return {
        "input_ids": ids.unsqueeze(0).to(device),
        "attention_mask": att.unsqueeze(0).to(device)
    }
|
| 328 |
+
|
| 329 |
+
# -------------------------
|
| 330 |
+
# Training-Compatible Oracle Wrapper
|
| 331 |
+
# -------------------------
|
| 332 |
+
class DirectionalOracle(nn.Module):
    """
    Batch-capable oracle wrapper with TD3B-compatible predict_with_confidence().

    This class is intended for training integration where peptide/protein tokens
    are provided directly (batched) and the oracle runs in inference-only mode.

    On construction it builds the full ESM_TR2D2_GPCRClassifier, loads the
    trained head weights from *model_ckpt*, and freezes every parameter.
    """
    def __init__(
        self,
        model_ckpt: str,
        tr2d2_checkpoint: str,
        tokenizer_vocab: str,
        tokenizer_splits: str,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        d_model: int = 256,
        n_heads: int = 4,
        n_self_attn_layers: int = 1,
        n_bmca_layers: int = 2,
        dropout: float = 0.3,
        max_ligand_length: int = 768,
        max_protein_length: int = 1024,
        device: Optional[str] = None,
        esm_cache_dir: Optional[str] = None,
        esm_local_files_only: bool = False
    ):
        # model_ckpt: trained classifier checkpoint (raw state dict or a dict
        #   wrapping one under "model_state_dict")
        # tr2d2_checkpoint: pretrained TR2-D2 ligand-encoder weights
        # tokenizer_vocab / tokenizer_splits: SMILES SPE tokenizer files
        # Architecture args (d_model, n_heads, ...) must match training.
        super().__init__()

        if isinstance(device, torch.device):
            device = str(device)
        self.device = resolve_device(device)

        self.max_ligand_length = max_ligand_length
        self.max_protein_length = max_protein_length
        # One-shot warning flags so truncation is logged once per input kind.
        self._warned_ligand_truncation = False
        self._warned_protein_truncation = False

        self.lig_tokenizer = SMILES_SPE_Tokenizer(tokenizer_vocab, tokenizer_splits)
        self.prot_tokenizer = EsmTokenizer.from_pretrained(
            esm_name,
            cache_dir=esm_cache_dir,
            local_files_only=esm_local_files_only
        )

        tr2d2_cfg = create_tr2d2_config(self.lig_tokenizer.vocab_size)
        self.model = ESM_TR2D2_GPCRClassifier(
            esm_name=esm_name,
            tr2d2_config=tr2d2_cfg,
            lig_tokenizer=self.lig_tokenizer,
            tr2d2_checkpoint=tr2d2_checkpoint,
            d_model=d_model,
            n_heads=n_heads,
            n_self_attn_layers=n_self_attn_layers,
            n_bmca_layers=n_bmca_layers,
            dropout=dropout,
            device=self.device,
            esm_cache_dir=esm_cache_dir,
            esm_local_files_only=esm_local_files_only
        )

        state_dict = torch.load(model_ckpt, map_location=self.device, weights_only=False)
        # Unwrap training checkpoints that store weights under "model_state_dict".
        if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        _load_state_dict_flexible(self.model, state_dict, strict=True)
        self.model.to(self.device).eval()

        # Inference-only oracle: freeze everything.
        for param in self.model.parameters():
            param.requires_grad = False

        # Cache pad ids, defaulting to 0 when a tokenizer defines none.
        self._lig_pad_token_id = self.lig_tokenizer.pad_token_id
        if self._lig_pad_token_id is None:
            self._lig_pad_token_id = 0
        self._prot_pad_token_id = self.prot_tokenizer.pad_token_id
        if self._prot_pad_token_id is None:
            self._prot_pad_token_id = 0

    def encode_protein(self, protein_seq: str) -> torch.Tensor:
        """Tokenize a protein sequence and return its input_ids on self.device."""
        tokens = self.prot_tokenizer(
            protein_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_protein_length,
            add_special_tokens=True
        )
        return tokens["input_ids"].to(self.device)

    def _normalize_token_dict(
        self,
        tokens: torch.Tensor,
        pad_token_id: int,
        max_length: int,
        warned_attr: str
    ) -> Dict[str, torch.Tensor]:
        """Normalize token input (tensor or dict) to a batched token dict.

        Accepts either a dict with "input_ids" (and optionally
        "attention_mask") or a bare id tensor of rank 1 or 2. A missing
        attention mask is derived from pad_token_id. Inputs longer than
        *max_length* are truncated, with a one-time warning gated by the
        instance flag named by *warned_attr*.
        """
        if isinstance(tokens, dict):
            input_ids = tokens.get("input_ids")
            if input_ids is None:
                raise ValueError("Token dict must include input_ids.")
            attention_mask = tokens.get("attention_mask")
            input_ids = input_ids.to(self.device)
            if attention_mask is None:
                # Derive the mask: non-pad positions are valid.
                attention_mask = (input_ids != pad_token_id).long()
            else:
                attention_mask = attention_mask.to(self.device)
        else:
            input_ids = tokens
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            input_ids = input_ids.to(self.device)
            attention_mask = (input_ids != pad_token_id).long()

        if max_length is not None and input_ids.size(1) > max_length:
            if not getattr(self, warned_attr):
                logger.warning(
                    "Truncating input from length %d to max_length=%d",
                    input_ids.size(1),
                    max_length
                )
                setattr(self, warned_attr, True)
            input_ids = input_ids[:, :max_length]
            attention_mask = attention_mask[:, :max_length]

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _normalize_prob_inputs(
        self,
        probs: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        max_length: int,
        warned_attr: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize soft token-probability input to (batched probs, mask).

        *probs* of rank 2 is treated as a single unbatched (seq, vocab)
        tensor. A missing mask defaults to all-ones. Over-long sequences are
        truncated with the same one-time warning mechanism as token input.
        """
        if probs.dim() == 2:
            probs = probs.unsqueeze(0)
        probs = probs.to(self.device)
        if attention_mask is None:
            attention_mask = torch.ones(
                probs.size(0), probs.size(1), device=self.device, dtype=torch.long
            )
        else:
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            attention_mask = attention_mask.to(self.device).long()

        if max_length is not None and probs.size(1) > max_length:
            if not getattr(self, warned_attr):
                logger.warning(
                    "Truncating input from length %d to max_length=%d",
                    probs.size(1),
                    max_length
                )
                setattr(self, warned_attr, True)
            probs = probs[:, :max_length]
            attention_mask = attention_mask[:, :max_length]

        return probs, attention_mask

    @torch.no_grad()
    def predict_with_confidence(
        self,
        peptide_tokens: torch.Tensor,
        protein_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (p_agonist, confidence) for each peptide/protein pair.

        A single protein (batch 1) is broadcast across a peptide batch;
        otherwise batch sizes must match. p_agonist is the softmax
        probability of class 1; confidence is the max class probability.
        """
        lig_tokens = self._normalize_token_dict(
            peptide_tokens,
            self._lig_pad_token_id,
            self.max_ligand_length,
            "_warned_ligand_truncation"
        )
        prot_tokens = self._normalize_token_dict(
            protein_tokens,
            self._prot_pad_token_id,
            self.max_protein_length,
            "_warned_protein_truncation"
        )

        lig_batch = lig_tokens["input_ids"].size(0)
        prot_batch = prot_tokens["input_ids"].size(0)
        if prot_batch == 1 and lig_batch > 1:
            # Broadcast the single protein across the peptide batch.
            prot_tokens = {k: v.expand(lig_batch, -1) for k, v in prot_tokens.items()}
        elif prot_batch != lig_batch:
            raise ValueError(
                f"Batch size mismatch: peptide_tokens={lig_batch}, protein_tokens={prot_batch}"
            )

        logits = self.model(prot_tokens, lig_tokens)
        probs = F.softmax(logits, dim=-1)
        p_agonist = probs[:, 1]
        confidence = torch.max(probs, dim=-1).values
        return p_agonist, confidence

    def predict_from_probs(
        self,
        ligand_probs: torch.Tensor,
        protein_tokens: torch.Tensor,
        ligand_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return p_agonist from SOFT ligand token probabilities.

        The probability tensor is mixed with the ligand embedding matrix
        (probs @ embeddings) so gradients can flow through the distribution;
        note there is deliberately no @torch.no_grad() here.
        """
        lig_probs, lig_attention = self._normalize_prob_inputs(
            ligand_probs,
            ligand_attention_mask,
            self.max_ligand_length,
            "_warned_ligand_truncation",
        )
        prot_tokens = self._normalize_token_dict(
            protein_tokens,
            self._prot_pad_token_id,
            self.max_protein_length,
            "_warned_protein_truncation"
        )

        lig_batch = lig_probs.size(0)
        prot_batch = prot_tokens["input_ids"].size(0)
        if prot_batch == 1 and lig_batch > 1:
            prot_tokens = {k: v.expand(lig_batch, -1) for k, v in prot_tokens.items()}
        elif prot_batch != lig_batch:
            raise ValueError(
                f"Batch size mismatch: ligand_probs={lig_batch}, protein_tokens={prot_batch}"
            )

        emb_weight = self.model.ligand_encoder.encoder.model.roformer.embeddings.word_embeddings.weight
        if lig_probs.size(-1) != emb_weight.size(0):
            raise ValueError(
                f"Ligand vocab mismatch: probs={lig_probs.size(-1)} vs oracle={emb_weight.size(0)}"
            )
        # Expected embeddings under the token distribution.
        lig_inputs_embeds = lig_probs @ emb_weight
        # Dummy ids: the embeds path ignores input_ids, but the model's dict
        # interface still requires the key.
        lig_input_ids = torch.zeros(
            lig_probs.size(0), lig_probs.size(1), device=lig_probs.device, dtype=torch.long
        )
        lig_tokens = {"input_ids": lig_input_ids, "attention_mask": lig_attention}
        logits = self.model(prot_tokens, lig_tokens, lig_inputs_embeds=lig_inputs_embeds)
        probs = F.softmax(logits, dim=-1)
        return probs[:, 1]
|
| 562 |
+
|
| 563 |
+
# -------------------------
|
| 564 |
+
# Prediction
|
| 565 |
+
# -------------------------
|
| 566 |
+
@torch.no_grad()
def predict(model, prot_tok, lig_tok, protein_seq, peptide_seq, device, threshold=0.5,
            max_ligand_length=768):
    """
    Predict agonist activity for one protein/peptide pair.

    Args:
        model: Trained ESM_TR2D2_GPCRClassifier.
        prot_tok: Protein (ESM) tokenizer.
        lig_tok: Ligand SMILES tokenizer.
        protein_seq: GPCR protein sequence.
        peptide_seq: Ligand peptide sequence (converted to SMILES via RDKit).
        device: Device the model/tensors live on.
        threshold: Agonist-probability cutoff for the "agonist" label.
        max_ligand_length: Ligand token length the model was trained with.
            Previously hard-coded to 768; the default preserves that behavior.

    Returns:
        dict with keys: smiles, non_agonist_prob, agonist_prob, prediction, confidence
    """
    # Convert peptide to SMILES
    smiles = peptide_to_smiles(peptide_seq)

    # Tokenize; ligand is padded/truncated to the training-time length.
    prot_tokens = tokenize_protein(protein_seq, prot_tok, device)
    lig_tokens = tokenize_ligand(smiles, lig_tok, max_ligand_length, device)

    # Predict
    logits = model(prot_tokens, lig_tokens)
    probs = F.softmax(logits, dim=-1).squeeze(0)

    p_non_agonist = probs[0].item()
    p_agonist = probs[1].item()
    prediction = "agonist" if p_agonist >= threshold else "non-agonist"

    return {
        "smiles": smiles,
        "non_agonist_prob": p_non_agonist,
        "agonist_prob": p_agonist,
        "prediction": prediction,
        "confidence": max(p_non_agonist, p_agonist)
    }
|
| 596 |
+
|
| 597 |
+
# -------------------------
|
| 598 |
+
# MAIN
|
| 599 |
+
# -------------------------
|
| 600 |
+
def main():
    """CLI entry point: load tokenizers and model, run one prediction, print results.

    All architecture flags (--d_model, --n_heads, ...) must match the values
    used when the checkpoint at --model_ckpt was trained.
    """
    parser = argparse.ArgumentParser(
        description="GPCR Agonist Classifier - TR2-D2 Inference"
    )
    parser.add_argument("--model_ckpt", required=True,
                        help="Path to trained model checkpoint")
    parser.add_argument("--tr2d2_checkpoint", required=True,
                        help="Path to TR2-D2 pretrained checkpoint")
    parser.add_argument("--tokenizer_vocab", required=True,
                        help="Path to tokenizer vocabulary")
    parser.add_argument("--tokenizer_splits", required=True,
                        help="Path to tokenizer splits")
    parser.add_argument("--protein_seq", required=True,
                        help="GPCR protein sequence")
    parser.add_argument("--ligand_peptide", required=True,
                        help="Ligand peptide sequence")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Classification threshold (default: 0.5)")
    parser.add_argument("--d_model", type=int, default=256,
                        help="Hidden dimension (must match training)")
    parser.add_argument("--n_heads", type=int, default=4,
                        help="Number of attention heads (must match training)")
    parser.add_argument("--n_self_attn_layers", type=int, default=1,
                        help="Number of self-attention layers (must match training)")
    parser.add_argument("--n_bmca_layers", type=int, default=2,
                        help="Number of cross-attention layers (must match training)")
    parser.add_argument("--dropout", type=float, default=0.3,
                        help="Dropout rate (must match training)")
    parser.add_argument("--device", default=None,
                        help="Device (cuda/cpu, default: auto)")
    parser.add_argument("--esm_name", default="facebook/esm2_t33_650M_UR50D",
                        help="ESM model name or local path")
    parser.add_argument("--esm_cache_dir", default=None,
                        help="Optional cache directory for ESM model")
    parser.add_argument("--esm_local_files_only", action="store_true",
                        help="Load ESM from local cache only (no network)")

    args = parser.parse_args()

    # Device (None -> auto-select CUDA if available)
    device = resolve_device(args.device)

    print(f"Device: {device}")
    print("")

    # Load tokenizers
    print("Loading tokenizers...")
    prot_tok = EsmTokenizer.from_pretrained(
        args.esm_name,
        cache_dir=args.esm_cache_dir,
        local_files_only=args.esm_local_files_only
    )
    lig_tok = SMILES_SPE_Tokenizer(args.tokenizer_vocab, args.tokenizer_splits)
    print(f" Vocab size: {lig_tok.vocab_size}")
    print("")

    # Create config (architecture fixed; only vocab size varies)
    tr2d2_cfg = create_tr2d2_config(lig_tok.vocab_size)

    # Load model (ESM + TR2-D2 encoders are frozen inside the classifier)
    print("Loading model...")
    model = ESM_TR2D2_GPCRClassifier(
        esm_name=args.esm_name,
        tr2d2_config=tr2d2_cfg,
        lig_tokenizer=lig_tok,
        tr2d2_checkpoint=args.tr2d2_checkpoint,
        d_model=args.d_model,
        n_heads=args.n_heads,
        n_self_attn_layers=args.n_self_attn_layers,
        n_bmca_layers=args.n_bmca_layers,
        dropout=args.dropout,
        device=device,
        esm_cache_dir=args.esm_cache_dir,
        esm_local_files_only=args.esm_local_files_only
    )

    # Load trained weights (falls back to a filtered non-strict load on mismatch)
    print(" Loading trained weights...")
    state_dict = torch.load(args.model_ckpt, map_location=device)
    _load_state_dict_flexible(model, state_dict, strict=True)
    model.to(device).eval()
    print(" Model ready.")
    print("")

    # Predict
    print("Running inference...")
    result = predict(
        model, prot_tok, lig_tok,
        args.protein_seq, args.ligand_peptide,
        device, args.threshold
    )

    # Display results
    print("")
    print("=" * 70)
    print("RESULTS")
    print("=" * 70)
    print(f"Protein: {args.protein_seq[:50]}{'...' if len(args.protein_seq) > 50 else ''}")
    print(f"Ligand: {args.ligand_peptide}")
    print(f"SMILES: {result['smiles']}")
    print("")
    print(f"Non-agonist probability: {result['non_agonist_prob']:.4f}")
    print(f"Agonist probability: {result['agonist_prob']:.4f}")
    print("")
    print(f"Prediction (threshold={args.threshold}): {result['prediction'].upper()}")
    print(f"Confidence: {result['confidence']:.4f}")
    print("=" * 70)

if __name__ == "__main__":
    main()
|
td3b/td3b_finetune.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B Finetuning Loop
|
| 3 |
+
Extends TR2-D2 training with contrastive loss and directional rewards.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import wandb
|
| 9 |
+
import os
|
| 10 |
+
from finetune_utils import loss_wdce
|
| 11 |
+
from .td3b_losses import TD3BTotalLoss, extract_embeddings_from_mdlm
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from plotting import plot_data_with_distribution_seaborn, plot_data
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def td3b_finetune(
    args,
    cfg,
    policy_model,
    reward_model,
    mcts=None,
    pretrained=None,
    filename=None,
    prot_name=None,
    eps=1e-5,
    # TD3B-specific arguments
    contrastive_weight=0.1,
    contrastive_margin=1.0,
    contrastive_type='margin',
    embedding_pool_method='mean',
    kl_beta=0.1
):
    """
    TD3B finetuning with combined WDCE + contrastive loss + KL regularization.

    Each epoch: (1) sample trajectories (via MCTS or direct sampling),
    (2) compute WDCE loss on the sampled trajectories, (3) compute a KL
    penalty against a frozen reference copy of the policy, (4) optionally add
    a contrastive loss over sequence embeddings when directional labels are
    diverse, then take one optimizer step and run an evaluation sampling pass.

    Args:
        args: Configuration arguments (expects attributes such as base_path,
            no_mcts, num_epochs, learning_rate, resample_every_n_step,
            reset_every_n_step, wdce_num_replicates, centering, grad_clip,
            gradnorm_clip, save_every_n_epochs, save_path, run_name, ...).
        cfg: Hydra config (unused in this function body; kept for interface
            compatibility).
        policy_model: Policy model (MDLM / Diffusion instance being trained).
        reward_model: Reward scoring functions (TD3BRewardFunction); its
            `target_direction` attribute is mutated in dual-direction mode.
        mcts: TD3B_MCTS instance (required unless args.no_mcts).
        pretrained: Pretrained model (required in no-MCTS mode).
        filename: Output filename (unused here; kept for interface
            compatibility).
        prot_name: Target protein name, used in plot labels and output paths.
        eps: Small epsilon for the noise schedule.
        contrastive_weight: λ weight for the contrastive loss term.
        contrastive_margin: Margin for margin-based contrastive loss.
        contrastive_type: 'margin' or 'infonce'.
        embedding_pool_method: 'mean', 'max', or 'cls' pooling for embeddings.
        kl_beta: β coefficient for the KL divergence regularization term.

    Returns:
        batch_losses: List of total training losses, one per epoch.
    """
    base_path = args.base_path
    # NOTE(review): dt is computed but never used in this function body.
    dt = (1 - eps) / args.total_num_steps

    if args.no_mcts:
        assert pretrained is not None, "pretrained model is required for no mcts"
    else:
        assert mcts is not None, "mcts is required for mcts"

    # Create reference model (frozen copy of policy model at start of training).
    # Cannot use copy.deepcopy() due to unpicklable objects (file handles, etc.);
    # instead, create a new model instance and load a CLONED state dict.
    print("[TD3B] Creating reference model for KL regularization...")

    # Local import to avoid a circular import at module load time (presumably;
    # confirm against the package layout).
    from diffusion import Diffusion

    # Create new instance with same config.
    reference_model = Diffusion(
        config=policy_model.config,
        tokenizer=policy_model.tokenizer,
        mode="eval",
        device=policy_model.device if hasattr(policy_model, 'device') else args.device
    )

    # Get the device from the policy model, falling back to its parameters.
    device = policy_model.device if hasattr(policy_model, 'device') else args.device
    if device is None:
        device = next(policy_model.parameters()).device

    # IMPORTANT: clone the state dict to create independent tensors, ensuring
    # no memory sharing between policy and reference model.
    state_dict_copy = {
        key: value.clone().detach()
        for key, value in policy_model.state_dict().items()
    }
    reference_model.load_state_dict(state_dict_copy)

    # Move reference model to the same device as the policy model.
    reference_model = reference_model.to(device)

    # Freeze and set to eval mode so the reference stays fixed during training.
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False

    print(f"[TD3B] Reference model frozen with {sum(p.numel() for p in reference_model.parameters())} parameters")
    print(f"[TD3B] Reference model on device: {device}")

    # Verify no parameter sharing (identity check on parameter objects).
    policy_params = {id(p) for p in policy_model.parameters()}
    ref_params = {id(p) for p in reference_model.parameters()}
    assert len(policy_params.intersection(ref_params)) == 0, \
        "ERROR: Reference model shares parameters with policy model!"
    print("[TD3B] ✓ Verified: No parameter sharing between policy and reference model")

    # Initialize TD3B total loss (WDCE + contrastive + KL vs. reference).
    td3b_loss_fn = TD3BTotalLoss(
        contrastive_weight=contrastive_weight,
        contrastive_margin=contrastive_margin,
        contrastive_type=contrastive_type,
        kl_beta=kl_beta,
        reference_model=reference_model
    )

    # Set model to train mode.
    policy_model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)

    # Record per-epoch loss components.
    batch_losses = []
    batch_wdce_losses = []
    batch_contrastive_losses = []
    batch_kl_losses = []

    # Cached trajectories, reused between resampling epochs.
    x_saved, log_rnd_saved, final_rewards_saved = None, None, None
    directional_labels_saved, confidences_saved = None, None

    # Evaluation logs (one entry appended per epoch).
    valid_fraction_log = []
    affinity_log = []
    gated_reward_log = []
    confidence_log = []
    direction_prediction_log = []  # Oracle predictions f_φ ∈ [0, 1]
    consistency_reward_log = []  # d* × (f_φ - 0.5)

    ### Fine-Tuning Loop ###
    pbar = tqdm(range(args.num_epochs))

    for epoch in pbar:
        # NOTE(review): rewards/losses are reset each epoch but never used.
        rewards = []
        losses = []

        policy_model.train()

        with torch.no_grad():
            # Resample trajectories on the first epoch and every
            # args.resample_every_n_step epochs; otherwise reuse the cache.
            if x_saved is None or epoch % args.resample_every_n_step == 0:
                # Generate trajectories.
                if args.no_mcts:
                    # Direct sampling (not typical for TD3B, kept for compatibility).
                    x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(
                        args, reward_model, pretrained
                    )
                    # No oracle available: neutral labels, full confidence.
                    directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32)
                    confidences = torch.ones(x_final.size(0), dtype=torch.float32)
                else:
                    # TD3B MCTS forward pass.
                    # For dual-direction mode, sample BOTH directions in the same batch.
                    if hasattr(args, 'target_direction') and args.target_direction == 'both':
                        print(f"[Dual-direction] Epoch {epoch}: Sampling BOTH agonist and antagonist binders")

                        # Sample agonist binders (d* = +1).
                        reward_model.target_direction = 1.0
                        if epoch % args.reset_every_n_step == 0:
                            results_agonist = mcts.forward(resetTree=True)
                        else:
                            results_agonist = mcts.forward(resetTree=False)

                        # Sample antagonist binders (d* = -1).
                        reward_model.target_direction = -1.0
                        # Don't reset tree for antagonist to save computation.
                        results_antagonist = mcts.forward(resetTree=False)

                        # Unpack both results (7-tuple contract of the TD3B MCTS).
                        if len(results_agonist) == 7 and len(results_antagonist) == 7:
                            x_agonist, log_rnd_agonist, rewards_agonist, _, _, labels_agonist, conf_agonist = results_agonist
                            x_antagonist, log_rnd_antagonist, rewards_antagonist, _, _, labels_antagonist, conf_antagonist = results_antagonist

                            # Force labels to be correct (overrides the oracle output,
                            # in case the oracle is wrong for this batch).
                            labels_agonist = torch.ones(x_agonist.size(0), dtype=torch.float32) * 1.0  # +1 for agonist
                            labels_antagonist = torch.ones(x_antagonist.size(0), dtype=torch.float32) * -1.0  # -1 for antagonist

                            # Combine both directions into a single batch.
                            x_final = torch.cat([x_agonist, x_antagonist], dim=0)
                            log_rnd = torch.cat([log_rnd_agonist, log_rnd_antagonist], dim=0)
                            final_rewards = np.concatenate([rewards_agonist, rewards_antagonist], axis=0)
                            directional_labels = torch.cat([labels_agonist, labels_antagonist], dim=0)
                            confidences = torch.cat([
                                conf_agonist if isinstance(conf_agonist, torch.Tensor) else torch.tensor(conf_agonist),
                                conf_antagonist if isinstance(conf_antagonist, torch.Tensor) else torch.tensor(conf_antagonist)
                            ], dim=0)

                            print(f" → Combined batch: {x_agonist.size(0)} agonists + {x_antagonist.size(0)} antagonists = {x_final.size(0)} total")
                            print(f" → Directional labels: {torch.unique(directional_labels).tolist()} (DIVERSITY CONFIRMED!)")
                        else:
                            raise ValueError("Dual-direction mode requires 7-value return from MCTS")
                    else:
                        # Single-direction mode.
                        if epoch % args.reset_every_n_step == 0:
                            results = mcts.forward(resetTree=True)
                        else:
                            results = mcts.forward(resetTree=False)

                        # Unpack results (TD3B version includes directional labels and confidences).
                        if len(results) == 7:
                            x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = results
                            # Convert numpy arrays to tensors immediately for consistency.
                            if not isinstance(directional_labels, torch.Tensor):
                                directional_labels = torch.tensor(directional_labels, dtype=torch.float32)
                            if not isinstance(confidences, torch.Tensor):
                                confidences = torch.tensor(confidences, dtype=torch.float32)
                        else:
                            # Fallback for compatibility with base MCTS (5-tuple).
                            x_final, log_rnd, final_rewards, score_vectors, sequences = results
                            directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32)
                            confidences = torch.ones(x_final.size(0), dtype=torch.float32)

                # Save for next iteration (reused until the next resample epoch).
                x_saved = x_final
                log_rnd_saved = log_rnd
                final_rewards_saved = final_rewards
                directional_labels_saved = directional_labels
                confidences_saved = confidences
            else:
                # Reuse cached trajectories.
                x_final = x_saved
                log_rnd = log_rnd_saved
                final_rewards = final_rewards_saved
                directional_labels = directional_labels_saved
                confidences = confidences_saved

        # Compute WDCE loss on the sampled trajectories.
        wdce_loss = loss_wdce(
            policy_model,
            log_rnd,
            x_final,
            num_replicates=args.wdce_num_replicates,
            centering=args.centering
        )

        # Compute KL divergence loss.
        # Use a random masking and forward pass for KL computation.
        mask_index = policy_model.mask_index
        device = x_final.device

        # Sample a random noise level per sequence.
        lamda = torch.rand(x_final.shape[0], device=device)  # (B,)
        sigma_kl = -torch.log1p(-(1 - eps) * lamda)

        # Apply random masking at rate lamda per sequence.
        masked_index = torch.rand(*x_final.shape, device=device) < lamda[..., None]  # (B, L)
        perturbed_batch = torch.where(masked_index, mask_index, x_final)
        attn_mask_kl = torch.ones_like(perturbed_batch).to(device)

        # Compute KL loss against the frozen reference model.
        kl_loss = td3b_loss_fn.compute_kl_loss(
            policy_model,
            perturbed_batch,
            attn_mask_kl,
            sigma_kl
        )

        # Extract embeddings for the contrastive loss.
        # Only computed when there is more than one distinct directional label,
        # since a contrastive objective needs at least two classes.
        if directional_labels is not None and len(torch.unique(directional_labels)) > 1:
            # Get device from backbone.
            device = policy_model.backbone.device if hasattr(policy_model.backbone, 'device') else x_final.device

            embeddings = extract_embeddings_from_mdlm(
                policy_model,
                x_final.to(device),
                pool_method=embedding_pool_method
            )

            # Move directional labels to the same device as the embeddings.
            directional_labels = directional_labels.to(embeddings.device)

            # Enable debug mode for the first 3 epochs or if the contrastive
            # loss was (near-)zero last epoch.
            debug_mode = (epoch < 3) or (epoch > 0 and batch_contrastive_losses and batch_contrastive_losses[-1] < 1e-6)

            # Compute total TD3B loss (WDCE + λ·contrastive + β·KL).
            total_loss, loss_dict = td3b_loss_fn.compute_loss(
                wdce_loss,
                embeddings,
                directional_labels,
                kl_loss=kl_loss,  # Pass KL loss
                debug=debug_mode  # Enable debugging when needed
            )
        else:
            # If no directional diversity, skip the contrastive loss term.
            print(f"[WARNING] Epoch {epoch}: No directional diversity! Skipping contrastive loss.")
            print(f" Labels: {directional_labels.cpu().tolist() if directional_labels is not None else 'None'}")
            total_loss = wdce_loss + td3b_loss_fn.kl_beta * kl_loss
            loss_dict = {
                'total_loss': total_loss.item(),
                'wdce_loss': wdce_loss.item(),
                'contrastive_loss': 0.0,
                'kl_loss': kl_loss.item()
            }

        # Gradient descent.
        total_loss.backward()

        # Gradient clipping (optional).
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)

        optim.step()
        optim.zero_grad()

        pbar.set_postfix(
            total_loss=loss_dict['total_loss'],
            wdce=loss_dict['wdce_loss'],
            ctr=loss_dict['contrastive_loss']
        )

        # Evaluation sampling (method attached by add_td3b_sampling_to_model).
        x_eval, eval_metrics = policy_model.sample_finetuned_td3b(
            args,
            reward_model,
            batch_size=50,
            dataframe=False
        )

        # Extract metrics (TD3B-specific).
        affinity = eval_metrics.get('affinity', [0])
        gated_reward = eval_metrics.get('gated_reward', [0])
        confidence = eval_metrics.get('confidence', [1])
        valid_fraction = eval_metrics.get('valid_fraction', 0)

        # Extract direction predictions (f_φ ∈ [0, 1]).
        direction_predictions = eval_metrics.get('direction_predictions', [0.5])

        # Compute consistency reward: d* × (f_φ - 0.5).
        # Get target direction d* from reward_model (+1 or -1).
        # NOTE(review): in dual-direction mode target_direction was last set to
        # -1.0 above, so this reflects the antagonist pass only — confirm intent.
        d_star = reward_model.target_direction
        consistency_rewards = [d_star * (f_phi - 0.5) for f_phi in direction_predictions]

        # Append to logs.
        affinity_log.append(affinity)
        gated_reward_log.append(gated_reward)
        confidence_log.append(confidence)
        valid_fraction_log.append(valid_fraction)
        direction_prediction_log.append(direction_predictions)
        consistency_reward_log.append(consistency_rewards)

        batch_losses.append(loss_dict['total_loss'])
        batch_wdce_losses.append(loss_dict['wdce_loss'])
        batch_contrastive_losses.append(loss_dict['contrastive_loss'])
        batch_kl_losses.append(loss_dict.get('kl_loss', 0.0))

        # Compute search statistics over the sampled trajectory rewards.
        # In no-MCTS mode final_rewards is a tensor; in MCTS mode it is a numpy array.
        if args.no_mcts:
            mean_reward_search = final_rewards.mean().item()
            min_reward_search = final_rewards.min().item()
            max_reward_search = final_rewards.max().item()
            median_reward_search = final_rewards.median().item()
        else:
            mean_reward_search = np.mean(final_rewards)
            min_reward_search = np.min(final_rewards)
            max_reward_search = np.max(final_rewards)
            median_reward_search = np.median(final_rewards)

        # Compute direction oracle and consistency reward statistics.
        mean_direction = np.mean(direction_predictions) if len(direction_predictions) > 0 else 0.5
        std_direction = np.std(direction_predictions) if len(direction_predictions) > 0 else 0.0
        mean_consistency = np.mean(consistency_rewards) if len(consistency_rewards) > 0 else 0.0
        std_consistency = np.std(consistency_rewards) if len(consistency_rewards) > 0 else 0.0

        print(
            f"epoch {epoch} | "
            f"affinity {np.mean(affinity):.4f} | "
            f"gated_reward {np.mean(gated_reward):.4f} | "
            f"confidence {np.mean(confidence):.4f} | "
            f"valid_frac {valid_fraction:.4f} | "
            f"direction_oracle {mean_direction:.4f}±{std_direction:.4f} | "
            f"consistency_reward {mean_consistency:.4f}±{std_consistency:.4f} | "
            f"total_loss {loss_dict['total_loss']:.4f} | "
            f"wdce_loss {loss_dict['wdce_loss']:.4f} | "
            f"contrastive_loss {loss_dict['contrastive_loss']:.4f} | "
            f"kl_loss {loss_dict.get('kl_loss', 0.0):.4f}"
        )

        # W&B logging.
        wandb.log({
            "epoch": epoch,
            "affinity": np.mean(affinity),
            "gated_reward": np.mean(gated_reward),
            "confidence": np.mean(confidence),
            "valid_fraction": valid_fraction,
            "direction_oracle/mean": mean_direction,
            "direction_oracle/std": std_direction,
            "consistency_reward/mean": mean_consistency,
            "consistency_reward/std": std_consistency,
            "total_loss": loss_dict['total_loss'],
            "wdce_loss": loss_dict['wdce_loss'],
            "contrastive_loss": loss_dict['contrastive_loss'],
            "kl_loss": loss_dict.get('kl_loss', 0.0),
            "mean_reward_search": mean_reward_search,
            "min_reward_search": min_reward_search,
            "max_reward_search": max_reward_search,
            "median_reward_search": median_reward_search
        })

        # Save checkpoint every args.save_every_n_epochs epochs.
        if (epoch + 1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
            torch.save(policy_model.state_dict(), model_path)
            print(f"model saved at epoch {epoch}")

    ### End of Fine-Tuning Loop ###

    wandb.finish()

    # Save logs and plots.
    # NOTE(review): the literal '(unknown)' below looks like an extraction
    # artifact — likely an interpolated variable (e.g. f'log_{filename}.csv')
    # in the upstream source; confirm against the original repository.
    plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
    os.makedirs(plot_path, exist_ok=True)
    output_log_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/log_(unknown).csv'
    save_td3b_logs_to_file(
        valid_fraction_log,
        affinity_log,
        gated_reward_log,
        confidence_log,
        direction_prediction_log,
        consistency_reward_log,
        output_log_path
    )

    plot_data(valid_fraction_log,
              save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/valid_(unknown).png')

    plot_data_with_distribution_seaborn(
        log1=affinity_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/affinity_(unknown).png',
        label1=f"Average Affinity to {prot_name}",
        title=f"Average Affinity to {prot_name} Over Iterations"
    )

    plot_data_with_distribution_seaborn(
        log1=gated_reward_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/gated_reward_(unknown).png',
        label1="Average Gated Reward",
        title="Average Gated Reward Over Iterations"
    )

    plot_data_with_distribution_seaborn(
        log1=confidence_log,
        save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/confidence_(unknown).png',
        label1="Average Confidence",
        title="Average Confidence Over Iterations"
    )

    # Final evaluation with a larger batch; also returns a results DataFrame.
    x_eval, eval_metrics, df = policy_model.sample_finetuned_td3b(
        args,
        reward_model,
        batch_size=200,
        dataframe=True
    )
    df.to_csv(f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/{prot_name}_generation_results.csv', index=False)

    return batch_losses
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def save_td3b_logs_to_file(valid_fraction_log, affinity_log, gated_reward_log, confidence_log,
                           direction_prediction_log, consistency_reward_log, output_path):
    """
    Saves TD3B-specific logs to a CSV file, one row per training iteration.

    Parameters:
        valid_fraction_log (list): Log of valid fractions over iterations.
        affinity_log (list): Log of binding affinity over iterations.
        gated_reward_log (list): Log of gated rewards over iterations.
        confidence_log (list): Log of confidence scores over iterations.
        direction_prediction_log (list): Log of direction oracle predictions over iterations.
        consistency_reward_log (list): Log of consistency rewards over iterations.
        output_path (str): Path to save the log CSV file.
    """
    # Fix: os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when output_path actually has one (a bare filename writes to
    # the current working directory).
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Combine logs into a DataFrame; "Iteration" is a 1-based index column
    # aligned with the length of valid_fraction_log (all logs are expected to
    # have one entry per iteration).
    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity_log,
        "Gated Reward": gated_reward_log,
        "Confidence": confidence_log,
        "Direction Oracle": direction_prediction_log,
        "Consistency Reward": consistency_reward_log
    }

    df = pd.DataFrame(log_data)

    # Save to CSV (no index column; "Iteration" already serves that role).
    df.to_csv(output_path, index=False)
    print(f"Logs saved to {output_path}")
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# Add sampling method to diffusion model (monkey patch or extend)
|
| 505 |
+
# Add sampling method to diffusion model (monkey patch or extend).
def add_td3b_sampling_to_model(model):
    """
    Adds a TD3B-specific sampling method (`sample_finetuned_td3b`) to the model.

    This is a helper function to extend an existing Diffusion/MDLM instance
    without subclassing; the method is bound to the instance via __get__.
    """
    def sample_finetuned_td3b(self, args, reward_model, batch_size=50, dataframe=False):
        """
        TD3B-specific sampling that returns directional metrics.

        Runs a full reverse-diffusion rollout, filters generated sequences for
        peptide validity, scores them with reward_model, then resamples
        batch_size items with probability softmax(reward / alpha).

        Returns:
            (x_resampled, eval_metrics) or (x_resampled, eval_metrics, df)
            when dataframe=True. eval_metrics contains 'affinity',
            'gated_reward', 'confidence', 'direction_predictions', and
            'valid_fraction'.
        """
        self.backbone.eval()
        self.noise.eval()

        if batch_size is None:
            batch_size = args.batch_size

        eps = getattr(args, "sampling_eps", 1e-5)
        num_steps = args.total_num_steps
        # Start from the fully-masked prior.
        x_rollout = self.sample_prior(
            batch_size,
            args.seq_length).to(self.device, dtype=torch.long)

        # Time grid from t=1 down to t=eps; dt is the per-step decrement.
        timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        dt = torch.tensor((1 - eps) / num_steps, device=self.device)

        for i in range(num_steps):
            t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
            log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
            x_rollout = x_next.to(self.device)

        # Final cleanup pass: if any mask tokens survived the rollout, force
        # one extra denoising step to remove them.
        # NOTE(review): `t` here is the loop variable from the last iteration;
        # if num_steps == 0 this raises NameError — confirm num_steps >= 1.
        mask_positions = (x_rollout == self.mask_index)
        if mask_positions.any().item():
            log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
            x_rollout = x_next.to(self.device)

        # Convert token ids to sequences and keep only valid peptides.
        from utils.app import PeptideAnalyzer
        analyzer = PeptideAnalyzer()
        sequences = self.tokenizer.batch_decode(x_rollout)
        valid_mask = torch.tensor([analyzer.is_peptide(seq) for seq in sequences], device=self.device)
        valid_sequences = [seq for seq, keep in zip(sequences, valid_mask.tolist()) if keep]
        valid_x_final = x_rollout[valid_mask] if valid_mask.any().item() else torch.empty(0, device=self.device)
        valid_fraction = len(valid_sequences) / batch_size

        if len(valid_sequences) > 0:
            # reward_model may return either (rewards, info-dict) or a bare
            # array of rewards; normalize both shapes into affinity /
            # confidence / direction_predictions numpy arrays.
            result = reward_model(valid_sequences)
            if isinstance(result, tuple):
                total_rewards, info = result
                affinity = np.asarray(info.get('affinities', total_rewards))
                confidence = np.asarray(info.get('confidences', np.ones_like(affinity)))
                direction_predictions = np.asarray(info.get('directions', np.zeros_like(affinity)))
            else:
                total_rewards = np.asarray(result)
                # If 2-D, the first column is taken as affinity — presumably a
                # multi-objective score matrix; confirm against reward_model.
                if total_rewards.ndim > 1:
                    affinity = total_rewards[:, 0]
                else:
                    affinity = total_rewards
                confidence = np.ones_like(affinity)
                direction_predictions = np.zeros_like(affinity)

            # Reward-weighted resampling back up to batch_size items
            # (softmax with temperature alpha, with replacement).
            rewards_t = torch.as_tensor(total_rewards, dtype=torch.float32, device=self.device)
            alpha = max(float(getattr(args, "alpha", 0.1)), 1e-6)
            weights = torch.softmax(rewards_t / alpha, dim=0)
            idx = torch.multinomial(weights, num_samples=batch_size, replacement=True)

            idx_np = idx.detach().cpu().numpy()
            x_resampled = valid_x_final[idx]
            sequences = [valid_sequences[i] for i in idx_np]
            total_rewards = total_rewards[idx_np]
            affinity = affinity[idx_np]
            confidence = confidence[idx_np]
            direction_predictions = direction_predictions[idx_np]
        else:
            # No valid peptides: return the raw rollout and empty metrics.
            x_resampled = x_rollout
            total_rewards = np.array([])
            affinity = np.array([])
            confidence = np.array([])
            direction_predictions = np.array([])

        eval_metrics = {
            'affinity': affinity,
            'gated_reward': total_rewards,
            'confidence': confidence,
            'direction_predictions': direction_predictions,
            'valid_fraction': valid_fraction
        }

        if dataframe:
            # NOTE(review): when total_rewards is empty, 'sequence' becomes []
            # while affinity/gated_reward/confidence are empty arrays — lengths
            # agree only in that empty case; otherwise sequences was resampled
            # to batch_size alongside the metric arrays.
            df = pd.DataFrame({
                'sequence': sequences if len(total_rewards) else [],
                'affinity': affinity,
                'gated_reward': total_rewards,
                'confidence': confidence
            })
            return x_resampled, eval_metrics, df
        else:
            return x_resampled, eval_metrics

    # Attach the method to the instance (bound method via the descriptor protocol).
    model.sample_finetuned_td3b = sample_finetuned_td3b.__get__(model, type(model))
    return model
|
td3b/td3b_losses.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B Loss Functions
|
| 3 |
+
Implements contrastive loss for separating agonist/antagonist embeddings.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Optional, Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ContrastiveLoss(nn.Module):
    """
    Margin-based contrastive loss for separating agonist and antagonist embeddings.

    For a pair of sequences (y_i, y_j):
    - If both are agonists OR both are antagonists (similar): minimize distance
    - If one is agonist and one is antagonist (dissimilar): maximize distance

    Loss formula (as implemented — note: no 0.5 factor is applied):
        L_ctr = mean over similar pairs of d²
              + mean over dissimilar pairs of max(0, margin - d)²

    where:
    - d = pairwise distance (Euclidean or cosine, see `distance_metric`)
    - margin = minimum distance enforced between dissimilar pairs
    """

    def __init__(self, margin: float = 1.0, distance_metric: str = 'euclidean', adaptive_margin: bool = False):
        """
        Args:
            margin: Minimum distance between dissimilar pairs (base margin)
            distance_metric: 'euclidean' or 'cosine'
            adaptive_margin: If True, adjust margin based on actual dissimilar distances
        """
        super().__init__()
        self.base_margin = margin
        self.distance_metric = distance_metric
        self.adaptive_margin = adaptive_margin

    def compute_distance(self, emb1: torch.Tensor, emb2: torch.Tensor) -> torch.Tensor:
        """
        Compute element-wise distance between two batches of embeddings.

        Args:
            emb1: (batch_size, embedding_dim)
            emb2: (batch_size, embedding_dim)
        Returns:
            distances: (batch_size,)
        Raises:
            ValueError: if `distance_metric` is neither 'euclidean' nor 'cosine'
        """
        if self.distance_metric == 'euclidean':
            # L2 distance
            distances = torch.norm(emb1 - emb2, p=2, dim=-1)  # (B,)
        elif self.distance_metric == 'cosine':
            # Cosine distance = 1 - cosine_similarity
            cos_sim = F.cosine_similarity(emb1, emb2, dim=-1)  # (B,)
            distances = 1.0 - cos_sim
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

        return distances

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> torch.Tensor:
        """
        Compute contrastive loss for a batch.

        Args:
            embeddings: (batch_size, embedding_dim) sequence embeddings
            labels: (batch_size,) directional labels in {-1, +1}
                +1 = agonist, -1 = antagonist
            confidences: (batch_size,) oracle confidence scores; pairs with product <= 0 are masked out
            debug: If True, print detailed debugging information
        Returns:
            loss: scalar contrastive loss (0.0 when batch < 2 or no valid pairs remain)
        """
        batch_size = embeddings.size(0)
        if batch_size < 2:
            if debug:
                print(f"[ContrastiveLoss DEBUG] Batch size {batch_size} < 2, returning 0 loss")
            return torch.tensor(0.0, device=embeddings.device)

        if confidences is not None:
            if not torch.is_tensor(confidences):
                confidences = torch.as_tensor(confidences, device=embeddings.device)
            else:
                confidences = confidences.to(embeddings.device)
            confidences = confidences.view(-1)
            if confidences.numel() != batch_size:
                raise ValueError(
                    f"Confidences size {confidences.numel()} does not match batch size {batch_size}"
                )

        # Compute pairwise distances (all ordered pairs, so each unordered pair
        # is counted twice — harmless because the loss terms below are means)
        if self.distance_metric == 'euclidean':
            distances = torch.cdist(embeddings, embeddings, p=2)  # (B, B)
        elif self.distance_metric == 'cosine':
            emb_norm = F.normalize(embeddings, p=2, dim=-1)
            distances = 1.0 - torch.matmul(emb_norm, emb_norm.T)  # (B, B)
        else:
            raise ValueError(f"Unknown distance metric: {self.distance_metric}")

        # Compute pairwise similarity labels
        # y_ij = 0 if same class (both agonist or both antagonist)
        # y_ij = 1 if different class
        # FIX: move labels onto the embeddings' device — previously CPU labels with
        # GPU embeddings crashed when combined with the device-local eye mask below.
        labels = labels.view(-1).to(embeddings.device)
        labels_expanded = labels.unsqueeze(1)  # (B, 1)
        label_product = labels_expanded * labels_expanded.T  # (B, B)
        # label_product > 0 means same class (both +1 or both -1)
        # label_product < 0 means different class
        dissimilar_mask = (label_product < 0)  # (B, B) bool

        # Exclude diagonal (self-pairs)
        eye_mask = torch.eye(batch_size, device=embeddings.device, dtype=torch.bool)
        pos_mask = (~dissimilar_mask) & ~eye_mask
        neg_mask = dissimilar_mask & ~eye_mask

        # Apply confidence mask: remove pairs with confidence product <= 0
        conf_mask = None
        if confidences is not None:
            conf_product = confidences.unsqueeze(0) * confidences.unsqueeze(1)
            conf_mask = conf_product > 0
            pos_mask = pos_mask & conf_mask
            neg_mask = neg_mask & conf_mask

        # Adaptive margin: set margin based on actual dissimilar distances
        if self.adaptive_margin and neg_mask.any():
            dissimilar_distances = distances[neg_mask]
            # 150% of the mean dissimilar distance — ensures there is always
            # room for optimization; never drops below the base margin.
            adaptive_margin = 1.5 * dissimilar_distances.mean().item()
            margin = max(self.base_margin, adaptive_margin)
        else:
            margin = self.base_margin

        pos_count = pos_mask.sum()
        neg_count = neg_mask.sum()
        total_pairs = pos_count + neg_count
        if total_pairs.item() == 0:
            if debug:
                print("[ContrastiveLoss DEBUG] No valid pairs after filtering, returning 0 loss")
            return torch.tensor(0.0, device=embeddings.device)

        # Contrastive loss
        # Similar pairs: mean squared distance; dissimilar pairs: mean squared hinge.
        pos_loss = distances[pos_mask].pow(2).sum() / (pos_count + 1e-8)
        neg_loss = torch.clamp(margin - distances[neg_mask], min=0.0).pow(2).sum() / (neg_count + 1e-8)
        loss = pos_loss + neg_loss

        if debug:
            print(f"\n[ContrastiveLoss DEBUG]")
            print(f"  Batch size: {batch_size}")
            print(f"  Labels: {labels.cpu().tolist()}")
            print(f"  Unique labels: {torch.unique(labels).cpu().tolist()}")
            print(f"  Embedding shape: {embeddings.shape}")
            print(f"  Embedding norm (mean): {embeddings.norm(dim=-1).mean().item():.4f}")
            print(f"  Embedding norm (std): {embeddings.norm(dim=-1).std().item():.4f}")
            valid_mask = pos_mask | neg_mask
            if valid_mask.any():
                valid_dists = distances[valid_mask]
                print(f"  Distance stats (valid pairs): mean={valid_dists.mean().item():.4f} "
                      f"min={valid_dists.min().item():.4f} max={valid_dists.max().item():.4f}")
            if self.adaptive_margin and neg_mask.any():
                print(f"  Margin: {margin:.4f} (adaptive, base={self.base_margin})")
            else:
                print(f"  Margin: {margin:.4f} (fixed)")
            print(f"  Num similar pairs: {pos_count.item():.0f}")
            print(f"  Num dissimilar pairs: {neg_count.item():.0f}")
            if conf_mask is not None:
                print(f"  Confidence-passing pairs: {conf_mask.sum().item():.0f}")
            print(f"  Similar loss (mean): {pos_loss.item():.4f}")
            print(f"  Dissimilar loss (mean): {neg_loss.item():.4f}")
            print(f"  Total loss: {loss.item():.4f}")

            # Show which dissimilar pairs have margin violations
            margin_violations = (distances < margin) & neg_mask
            if margin_violations.sum() > 0:
                print(f"  Margin violations: {margin_violations.sum().item():.0f} dissimilar pairs have distance < margin")
            else:
                print(f"  Margin violations: 0 (all dissimilar pairs are already separated)")

        return loss
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class InfoNCELoss(nn.Module):
    """
    Alternative: InfoNCE contrastive loss (used in SimCLR, CLIP).
    Treats agonists as positive class, antagonists as negative class.

    For each agonist, pull it close to other agonists and push away from antagonists.
    For each antagonist, pull it close to other antagonists and push away from agonists.
    """

    def __init__(self, temperature: float = 0.1):
        """
        Args:
            temperature: Temperature parameter for softmax
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss.

        Args:
            embeddings: (batch_size, embedding_dim)
            labels: (batch_size,) in {-1, +1}
            confidences: (batch_size,) oracle confidence scores; pairs with product <= 0 are masked out
            debug: Unused (kept for API compatibility)
        Returns:
            loss: scalar (0.0 when batch < 2 or no sample has both a positive and a negative)
        """
        batch_size = embeddings.size(0)
        if confidences is not None:
            if not torch.is_tensor(confidences):
                confidences = torch.as_tensor(confidences, device=embeddings.device)
            else:
                confidences = confidences.to(embeddings.device)
            confidences = confidences.view(-1)
            if confidences.numel() != batch_size:
                raise ValueError(
                    f"Confidences size {confidences.numel()} does not match batch size {batch_size}"
                )
        if batch_size < 2:
            return torch.tensor(0.0, device=embeddings.device)

        # Normalize embeddings so the dot product is cosine similarity
        embeddings = F.normalize(embeddings, p=2, dim=-1)  # (B, D)

        # Compute similarity matrix
        similarity = torch.matmul(embeddings, embeddings.T) / self.temperature  # (B, B)

        # Create positive/negative masks
        # FIX: labels moved to the embeddings' device to avoid device mismatch.
        labels = labels.view(-1).to(embeddings.device)
        labels_expanded = labels.unsqueeze(1)  # (B, 1)
        label_product = labels_expanded * labels_expanded.T  # (B, B)
        positive_mask = (label_product > 0)  # Same class
        negative_mask = (label_product < 0)  # Different class

        # Remove self-similarity
        positive_mask.fill_diagonal_(0)

        if confidences is not None:
            conf_product = confidences.unsqueeze(0) * confidences.unsqueeze(1)
            conf_mask = conf_product > 0
            positive_mask = positive_mask & conf_mask
            negative_mask = negative_mask & conf_mask

        # For each sample, compute InfoNCE loss:
        #   -log( sum(exp(sim_pos)) / (sum(exp(sim_pos)) + sum(exp(sim_neg))) )
        losses = []
        for i in range(batch_size):
            pos_sims = similarity[i][positive_mask[i]]  # (num_pos,)
            neg_sims = similarity[i][negative_mask[i]]  # (num_neg,)

            # Skip samples without at least one positive and one negative
            if pos_sims.numel() == 0 or neg_sims.numel() == 0:
                continue

            # FIX: use torch.logsumexp instead of raw exp sums — the old code
            # could overflow for small temperatures despite claiming "LogSumExp
            # for numerical stability".
            # -log(sum exp pos / sum exp all) == logsumexp(all) - logsumexp(pos)
            all_sims = torch.cat([pos_sims, neg_sims])
            loss_i = torch.logsumexp(all_sims, dim=0) - torch.logsumexp(pos_sims, dim=0)
            losses.append(loss_i)

        if len(losses) == 0:
            return torch.tensor(0.0, device=embeddings.device)

        return torch.stack(losses).mean()
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class TD3BTotalLoss:
    """
    Combined TD3B loss: L_total = L_WDCE + λ * L_ctr + β * L_KL

    Components:
    - L_WDCE: Weighted Denoising Cross-Entropy (from TR2-D2)
    - L_ctr: Contrastive loss for agonist/antagonist separation
    - L_KL: KL divergence regularization between policy and reference model

    Note: this is a plain object (not an nn.Module) — it owns no trainable
    parameters; the reference model it holds is explicitly frozen in __init__.
    """

    def __init__(
        self,
        contrastive_weight: float = 0.1,
        contrastive_margin: float = 1.0,
        contrastive_type: str = 'margin',  # 'margin' or 'infonce'
        kl_beta: float = 0.1,  # β coefficient for KL divergence
        reference_model: Optional[nn.Module] = None,
        adaptive_margin: bool = True  # Enable adaptive margin by default
    ):
        """
        Args:
            contrastive_weight: λ coefficient for contrastive loss
            contrastive_margin: Margin for margin-based contrastive loss (base margin if adaptive)
            contrastive_type: Type of contrastive loss ('margin' or 'infonce')
            kl_beta: β coefficient for KL divergence regularization
            reference_model: Frozen reference model for KL divergence (deepcopy of pretrained)
            adaptive_margin: If True, automatically adjust margin based on dissimilar distances

        Raises:
            ValueError: if `contrastive_type` is neither 'margin' nor 'infonce'
        """
        self.contrastive_weight = contrastive_weight
        self.kl_beta = kl_beta
        self.reference_model = reference_model

        # Freeze reference model if provided — it must never receive gradients,
        # otherwise the KL term would pull the "fixed" distribution along.
        if self.reference_model is not None:
            self.reference_model.eval()
            for param in self.reference_model.parameters():
                param.requires_grad = False

            # Verify all parameters are frozen
            assert all(not p.requires_grad for p in self.reference_model.parameters()), \
                "ERROR: Reference model has parameters with requires_grad=True!"

        # NOTE(review): for 'infonce', contrastive_margin and adaptive_margin are
        # silently ignored — confirm this is intended.
        if contrastive_type == 'margin':
            self.contrastive_loss = ContrastiveLoss(
                margin=contrastive_margin,
                distance_metric='euclidean',
                adaptive_margin=adaptive_margin
            )
        elif contrastive_type == 'infonce':
            self.contrastive_loss = InfoNCELoss(temperature=0.1)
        else:
            raise ValueError(f"Unknown contrastive type: {contrastive_type}")

    def compute_kl_categorical(
        self,
        log_p: torch.Tensor,
        log_ref_p: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute KL divergence between categorical distributions.

        KL(P || Q) = Σ P(x) * log(P(x) / Q(x))
                   = Σ P(x) * (log P(x) - log Q(x))

        Args:
            log_p: (B, L, Vocab) log-probabilities from policy model
            log_ref_p: (B, L, Vocab) log-probabilities from reference model
        Returns:
            kl: (B, L) KL divergence per position
        """
        # Convert log-probs to probabilities
        p = torch.exp(log_p)  # (B, L, Vocab)

        # KL divergence element-wise
        kl_elementwise = p * (log_p - log_ref_p)  # (B, L, Vocab)

        # Handle numerical issues: 0 * log(0) should be 0.
        # Replace NaNs or Infs that occur at -inf locations with 0.
        kl_elementwise = torch.where(
            torch.isfinite(kl_elementwise),
            kl_elementwise,
            torch.zeros_like(kl_elementwise)
        )

        # Sum over vocabulary dimension
        kl = kl_elementwise.sum(dim=-1)  # (B, L)

        return kl

    def compute_kl_loss(
        self,
        policy_model: nn.Module,
        sequences: torch.Tensor,
        attn_mask: torch.Tensor,
        sigma: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute KL divergence loss between policy model and reference model.

        Args:
            policy_model: Current policy model (being trained)
            sequences: (B, L) input sequences
            attn_mask: (B, L) attention mask (1 = real token, 0 = padding)
            sigma: (B,) noise schedule
        Returns:
            kl_loss: Scalar KL divergence loss (0.0 when no reference model is set)
        """
        if self.reference_model is None:
            return torch.tensor(0.0, device=sequences.device)

        # Ensure reference model is in eval mode
        assert not self.reference_model.training, \
            "ERROR: Reference model is in training mode! It should always be in eval mode."

        # Forward through policy model (already computed in WDCE, but need logits)
        # NOTE(review): assumes model(seqs, attn_mask=..., sigma=...) returns
        # (B, L, Vocab) logits — confirm against the diffusion model's forward.
        policy_logits = policy_model(sequences, attn_mask=attn_mask, sigma=sigma)  # (B, L, Vocab)

        # Forward through reference model (frozen, no gradients)
        with torch.no_grad():
            ref_logits = self.reference_model(sequences, attn_mask=attn_mask, sigma=sigma)  # (B, L, Vocab)

        # Convert to log-probabilities
        log_p = F.log_softmax(policy_logits, dim=-1)  # (B, L, Vocab)
        log_ref_p = F.log_softmax(ref_logits, dim=-1)  # (B, L, Vocab)

        # Compute KL divergence
        kl_per_position = self.compute_kl_categorical(log_p, log_ref_p)  # (B, L)

        # Mask out padding positions
        kl_masked = kl_per_position * attn_mask.float()  # (B, L)

        # Average over all non-padding positions (epsilon guards empty masks)
        num_valid = attn_mask.float().sum()
        kl_loss = kl_masked.sum() / (num_valid + 1e-8)

        return kl_loss

    def compute_loss(
        self,
        wdce_loss: torch.Tensor,
        embeddings: torch.Tensor,
        directional_labels: torch.Tensor,
        confidences: Optional[torch.Tensor] = None,
        kl_loss: Optional[torch.Tensor] = None,
        debug: bool = False
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute total TD3B loss.

        Args:
            wdce_loss: Precomputed WDCE loss (scalar)
            embeddings: (batch_size, embedding_dim) sequence embeddings from MDLM
            directional_labels: (batch_size,) labels in {-1, +1}
            confidences: (batch_size,) oracle confidence scores; pairs with product <= 0 are masked out
            kl_loss: Precomputed KL divergence loss (optional; treated as 0 when None)
            debug: If True, enable debugging output in contrastive loss
        Returns:
            total_loss: Combined loss (differentiable tensor)
            loss_dict: Dictionary with individual loss components as Python floats
        """
        # Contrastive loss (pass debug flag)
        contrastive_loss = self.contrastive_loss(
            embeddings,
            directional_labels,
            confidences=confidences,
            debug=debug
        )

        # KL divergence loss
        if kl_loss is None:
            kl_loss = torch.tensor(0.0, device=embeddings.device)

        # Total loss: L_total = L_WDCE + λ * L_ctr + β * L_KL
        total_loss = wdce_loss + self.contrastive_weight * contrastive_loss + self.kl_beta * kl_loss

        loss_dict = {
            'total_loss': total_loss.item(),
            'wdce_loss': wdce_loss.item(),
            'contrastive_loss': contrastive_loss.item(),
            'kl_loss': kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss
        }

        return total_loss, loss_dict
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def extract_embeddings_from_mdlm(
    model,
    sequences: torch.Tensor,
    pool_method: str = 'mean'
) -> torch.Tensor:
    """
    Pool per-token hidden states from the MDLM backbone into one vector per sequence.

    Args:
        model: MDLM model wrapping a Roformer backbone (accessed via model.backbone.model)
        sequences: (batch_size, seq_len) token ids; token id 0 is treated as padding
        pool_method: 'mean', 'max', or 'cls'
    Returns:
        embeddings: (batch_size, hidden_dim)
    Raises:
        ValueError: for an unrecognized `pool_method`
    """
    # Token id 0 marks padding — assumes the tokenizer's pad id is 0 (TODO confirm).
    token_mask = (sequences != 0).long()  # (B, L)

    # Run the backbone WITH gradients enabled: these embeddings feed the
    # contrastive loss, so they must stay on the autograd graph.
    backbone_out = model.backbone.model(
        input_ids=sequences,
        attention_mask=token_mask,
        output_hidden_states=True,
        return_dict=True
    )

    # hidden_states = (embedding_output, layer_1, ..., layer_N); take the final layer.
    final_hidden = backbone_out.hidden_states[-1]  # (B, L, D)

    if pool_method == 'cls':
        # First-token ("CLS"-style) pooling.
        return final_hidden[:, 0, :]  # (B, D)
    if pool_method == 'max':
        # NOTE(review): max pooling does not mask padding positions — confirm intended.
        return final_hidden.max(dim=1)[0]  # (B, D)
    if pool_method == 'mean':
        # Length-normalized mean over non-padding tokens only (epsilon guards all-pad rows).
        weights = token_mask.float().unsqueeze(-1)  # (B, L, 1)
        return (final_hidden * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-8)  # (B, D)
    raise ValueError(f"Unknown pool method: {pool_method}")
|
td3b/td3b_mcts.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B-specific MCTS modifications.
|
| 3 |
+
Extends the base MCTS to support directional rewards and confidence weighting.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from peptide_mcts import MCTS as BaseMCTS
|
| 9 |
+
from .td3b_scoring import TD3BRewardFunction, TD3BConfidenceWeighting
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TD3B_MCTS(BaseMCTS):
|
| 13 |
+
"""
|
| 14 |
+
TD3B version of MCTS that:
|
| 15 |
+
1. Uses gated directional rewards instead of multi-objective scalarization
|
| 16 |
+
2. Stores directional labels and confidence scores in the buffer
|
| 17 |
+
3. Applies confidence-weighted importance sampling
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
    def __init__(
        self,
        args,
        diffusion_model,
        td3b_reward_function: TD3BRewardFunction,
        confidence_weighting: TD3BConfidenceWeighting,
        mask_index: int,
        buffer_size: int = 100,
        noise=None,
        tokenizer=None
    ):
        """
        Args:
            args: Configuration arguments
            diffusion_model: MDLM model for sampling
            td3b_reward_function: TD3BRewardFunction instance
            confidence_weighting: TD3BConfidenceWeighting instance
            mask_index: Token ID for masked positions
            buffer_size: Maximum buffer size
            noise: Noise schedule
            tokenizer: Peptide tokenizer (falls back to diffusion_model.tokenizer)
        """
        # Initialize base MCTS (will set self.rewardFunc later)
        # Note: base MCTS expects 'policy_model' not 'diffusion_model'
        # Create a minimal config object for base MCTS.
        # NOTE(review): hard-codes a loglinear schedule with sigma_min=1e-4,
        # sigma_max=20 — presumably mirrors the project's default config; confirm.
        class MinimalConfig:
            def __init__(self):
                self.noise = type('obj', (object,), {
                    'type': 'loglinear',
                    'sigma_min': 1e-4,
                    'sigma_max': 20
                })()
        config = MinimalConfig()

        super().__init__(
            args=args,
            config=config,
            policy_model=diffusion_model,
            pretrained=diffusion_model,  # Use same model
            score_func_names=['affinity', 'gated_reward', 'placeholder1', 'placeholder2', 'placeholder3']  # 5 objectives
        )

        # Set TD3B-specific attributes
        self.td3b_reward_func = td3b_reward_function
        self.confidence_weighting = confidence_weighting
        self.mask_index = mask_index
        self.buffer_size = buffer_size
        self.noise = noise
        self.tokenizer = tokenizer if tokenizer is not None else diffusion_model.tokenizer

        # Override num_obj to ensure it's 5 (matching our padded rewards)
        self.num_obj = 5

        # Override rewardFunc for compatibility: base MCTS calls self.rewardFunc,
        # which we redirect to the TD3B gated-reward wrapper below.
        self.rewardFunc = self._td3b_reward_wrapper
|
| 75 |
+
|
| 76 |
+
def _td3b_reward_wrapper(self, input_seqs):
|
| 77 |
+
"""
|
| 78 |
+
Wrapper to make TD3BRewardFunction compatible with existing MCTS interface.
|
| 79 |
+
Returns (N, 5) array to match base MCTS expectations.
|
| 80 |
+
The 5 columns are: [affinity, gated_reward, 0, 0, 0] (padding last 3)
|
| 81 |
+
"""
|
| 82 |
+
import numpy as np
|
| 83 |
+
total_rewards, info = self.td3b_reward_func(input_seqs)
|
| 84 |
+
# info contains: 'affinities', 'confidences', 'score_vectors'
|
| 85 |
+
|
| 86 |
+
# Store confidences for later use (attach to self for access in updateBuffer)
|
| 87 |
+
self._last_confidences = info['confidences']
|
| 88 |
+
|
| 89 |
+
# Pad score_vectors from (N, 2) to (N, 5) to match base MCTS
|
| 90 |
+
# Original columns: [affinity, gated_reward]
|
| 91 |
+
# Padded to: [affinity, gated_reward, 0, 0, 0]
|
| 92 |
+
score_vectors = info['score_vectors'] # (N, 2)
|
| 93 |
+
padded = np.zeros((score_vectors.shape[0], 5))
|
| 94 |
+
padded[:, :2] = score_vectors # Copy affinity and gated_reward
|
| 95 |
+
|
| 96 |
+
return padded
|
| 97 |
+
|
| 98 |
+
def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
    """
    TD3B version of the replay-buffer update: stores directional labels and
    confidence scores alongside the usual Pareto-filtered buffer items.

    Args:
        x_final: (B, L) final sequence tokens
        log_rnd: (B,) log importance weights (trajectory-level)
        score_vectors: (B, K) score arrays; column 0 = affinity, column 1 = gated reward
        childSequences: List of B SMILES strings
    Returns:
        traj_log_rnds: updated log importance weights for the inserted items
        scalar_rewards: scalar rewards for the inserted items

    NOTE(review): candidates rejected by Pareto dominance hit ``continue``
    before their weight/reward are appended, so the returned arrays can be
    shorter than B even though they are documented as (B,) — confirm callers
    align these with buffer contents rather than by batch index.
    """
    B = x_final.shape[0]
    traj_log_rnds, scalar_rewards = [], []

    # Confidences cached by _td3b_reward_wrapper during the last reward call;
    # defaults to all-ones if rewards were computed through another path.
    confidences = getattr(self, '_last_confidences', np.ones(B))

    for i in range(B):
        sv = np.asarray(score_vectors[i], dtype=float)  # [affinity, gated_reward, ...]
        confidence = confidences[i]

        # For TD3B, the "scalar reward" is the gated reward (second element):
        # gated_reward = g_psi * sigmoid(d* * (f_phi - 0.5) / tau)
        scalar_reward = float(sv[1])

        # Confidence-weighted importance weight, kept in log space:
        # log w(y) = log_rnd + S_total / alpha + log kappa(y)
        # Confidence is floored at min_confidence so log never diverges.
        log_confidence = np.log(np.maximum(confidence, self.confidence_weighting.min_confidence))
        traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) + log_confidence

        # Directional label inferred from the sign of the gated reward.
        # This is approximate; querying f_phi directly would be more precise.
        directional_label = np.sign(scalar_reward) if scalar_reward != 0 else 0.0

        item = {
            "x_final": x_final[i].clone(),
            "log_rnd": traj_log_rnd.clone() if isinstance(traj_log_rnd, torch.Tensor) else torch.tensor(traj_log_rnd),
            "final_reward": scalar_reward,
            "score_vector": sv.copy(),
            "seq": childSequences[i],
            # TD3B-specific additions
            "directional_label": directional_label,
            "confidence": confidence,
        }

        # Pareto dominance filtering (same policy as the base MCTS class).
        from peptide_mcts import dominated_by, dominates

        # Reject the candidate if any buffered item already dominates it.
        if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
            self._debug_buffer_decision(sv, "rejected_dominated")
            continue

        # Evict buffered items that the candidate dominates.
        keep = []
        for bi in self.buffer:
            if not dominates(sv, bi["score_vector"]):
                keep.append(bi)
        self.buffer = keep

        # Insert under the capacity constraint; when full, replace the item
        # with the lowest summed score vector.
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(item)
        else:
            worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
            self.buffer[worst_i] = item

        self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

        traj_log_rnds.append(traj_log_rnd)
        scalar_rewards.append(scalar_reward)

    # Normalize to a tensor; entries may be Python floats or 0-d tensors.
    traj_log_rnds = torch.stack([torch.tensor(x) if not isinstance(x, torch.Tensor) else x for x in traj_log_rnds], dim=0) if traj_log_rnds else torch.empty(0)
    scalar_rewards = np.asarray(scalar_rewards, dtype=float)
    return traj_log_rnds, scalar_rewards
|
| 176 |
+
|
| 177 |
+
def forward(self, resetTree=False):
    """
    Run the TD3B MCTS loop and return the consolidated buffer contents.

    Args:
        resetTree: Whether to rebuild the search tree before iterating.
    Returns:
        x_final: (N, L) sequence tokens
        log_rnd: (N,) log importance weights
        final_rewards: (N,) scalar rewards
        score_vectors: (N, K) score arrays
        sequences: List of N SMILES strings
        directional_labels: (N,) directional labels
        confidences: (N,) confidence scores
    """
    self.reset(resetTree)

    # Each iteration walks the tree to a leaf, then expands that leaf into
    # num_children partially unmasked sequences at the next timestep.
    while self.iter_num < self.num_iter:
        self.iter_num += 1

        with self.timer.section("select"):
            leaf, _ = self.select(self.rootNode)

        with self.timer.section("expand"):
            self.expand(leaf)

    (final_x, log_rnd, final_rewards, score_vectors,
     sequences, directional_labels, confidences) = self.consolidateBuffer()

    # Dump per-section timing statistics collected during the search.
    print("\n=== Timing summary (by total time) ===")
    for name, cnt, total, mean, p50, p95 in self.timer.summary():
        print(
            f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
            f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms"
        )

    return final_x, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences
|
| 212 |
+
|
| 213 |
+
def consolidateBuffer(self):
    """
    TD3B version of buffer consolidation: includes directional labels and
    confidence scores in addition to the base outputs.

    Returns:
        x_final: (N, L) sequence tokens
        log_rnd: (N,) log importance weights
        final_rewards: (N,) scalar rewards
        score_vectors: (N, K) score arrays
        sequences: List of N SMILES strings
        directional_labels: (N,) directional labels
        confidences: (N,) confidence scores
    """
    # Empty buffer: warn and hand back correctly-typed empty containers.
    if not self.buffer:
        import logging
        logging.getLogger(__name__).warning(
            "MCTS buffer is empty - no valid sequences found. Returning empty results."
        )
        # policy_model is set by the base MCTS class; fall back to CPU.
        device = self.policy_model.device if hasattr(self.policy_model, 'device') else 'cpu'
        return (
            torch.empty(0, 0, dtype=torch.long, device=device),   # x_final
            torch.empty(0, dtype=torch.float32, device=device),   # log_rnd
            np.empty(0, dtype=np.float32),                        # final_rewards
            np.empty((0, 0), dtype=np.float32),                   # score_vectors
            [],                                                   # sequences
            np.empty(0, dtype=np.float32),                        # directional_labels
            np.empty(0, dtype=np.float32),                        # confidences
        )

    # Gather each field across the buffer in one pass per field.
    x_final = torch.stack([entry["x_final"] for entry in self.buffer], dim=0)  # (N, L)
    log_rnd = torch.stack([entry["log_rnd"] for entry in self.buffer], dim=0).to(dtype=torch.float32)  # (N,)
    final_rewards = np.asarray([entry["final_reward"] for entry in self.buffer], dtype=np.float32)
    score_vectors = np.stack([entry["score_vector"] for entry in self.buffer], axis=0).astype(np.float32)
    sequences = [entry["seq"] for entry in self.buffer]
    # Older buffer entries may predate the TD3B fields; default them.
    directional_labels = np.array(
        [entry.get("directional_label", 0.0) for entry in self.buffer], dtype=np.float32
    )
    confidences = np.array(
        [entry.get("confidence", 1.0) for entry in self.buffer], dtype=np.float32
    )

    return x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def create_td3b_mcts(
    args,
    diffusion_model,
    td3b_reward_function: TD3BRewardFunction,
    alpha: float = 0.1,
    **kwargs
) -> TD3B_MCTS:
    """
    Build a ready-to-use TD3B MCTS instance.

    Args:
        args: Configuration arguments.
        diffusion_model: MDLM model used as the search policy.
        td3b_reward_function: TD3BRewardFunction instance.
        alpha: Temperature for importance weighting.
        **kwargs: Forwarded to the TD3B_MCTS constructor.

    Returns:
        A configured TD3B_MCTS instance.
    """
    # Confidence weighting shares the importance-weight temperature alpha.
    weighting = TD3BConfidenceWeighting(alpha=alpha, min_confidence=0.1)

    return TD3B_MCTS(
        args=args,
        diffusion_model=diffusion_model,
        td3b_reward_function=td3b_reward_function,
        confidence_weighting=weighting,
        **kwargs,
    )
|
td3b/td3b_scoring.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TD3B Scoring Functions
|
| 3 |
+
Implements gated allosteric reward combining affinity prediction and directional oracle.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import List, Tuple, Optional
|
| 10 |
+
from .direction_oracle import DirectionalOracle
|
| 11 |
+
from scoring.functions.binding import BindingAffinity
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TD3BRewardFunction:
    """
    Gated TD3B reward combining binding affinity with a directional oracle.

    S_total(y; d*, x) = g_psi(y, x) * sigma(d* * (f_phi(y, x) - 0.5) / tau)

    where:
        - g_psi(y, x): affinity predictor (BindingAffinity)
        - sigma: logistic sigmoid, sigma(z) = 1 / (1 + exp(-z))
        - d* in {+1, -1}: target direction (agonist / antagonist)
        - f_phi(y, x): directional oracle (DirectionalOracle) outputting
          p(agonist) in [0, 1]
        - tau: temperature parameter (lower = sharper gating)
        - y: peptide sequence, x: target protein sequence

    A placeholder oracle that outputs 0.5 makes (f_phi - 0.5) = 0, i.e. a
    neutral gate of sigma(0) = 0.5, during initial training before a real
    oracle is available.

    Benefits of the sigmoid formulation:
        1. Gate value always in [0, 1] -> bounded gated rewards
        2. tau controls how sharply aligned/misaligned directions separate
        3. Differentiable gating for smooth optimization

    OLD FORMULA (replaced): S_total(y; d*, x) = g_psi(y, x) * (d* * f_phi(y, x))
    """

    def __init__(
        self,
        affinity_predictor: BindingAffinity,
        directional_oracle: DirectionalOracle,
        target_direction: float,  # +1 for agonist, -1 for antagonist
        target_protein_tokens: torch.Tensor,
        peptide_tokenizer,
        device: torch.device,
        min_affinity_threshold: float = 0.0,  # Minimum g_psi for allosteric control
        use_confidence_weighting: bool = True,
        temperature: float = 0.1  # Temperature for sigmoid sharpening
    ):
        """
        Args:
            affinity_predictor: Pretrained g_psi model (BindingAffinity).
            directional_oracle: Pretrained f_phi model (DirectionalOracle).
            target_direction: d* in {+1, -1} for agonist/antagonist.
            target_protein_tokens: Tokenized target protein sequence.
            peptide_tokenizer: Tokenizer converting SMILES to token tensors.
            device: Computation device.
            min_affinity_threshold: Only apply full directional gating when
                g_psi exceeds this; weaker binders are downweighted instead.
            use_confidence_weighting: Whether kappa(y) modulates importance
                weights. NOTE(review): stored but not read within this class;
                confirm it is consumed by the caller.
            temperature: tau for sigmoid sharpening (lower = sharper).
                Default 0.1 is much sharper than a standard sigmoid.
        """
        self.g_psi = affinity_predictor            # Affinity predictor g_psi(y, x)
        self.f_phi = directional_oracle            # Directional oracle f_phi(y, x)
        self.target_direction = target_direction   # d* in {+1, -1}
        self.protein_tokens = target_protein_tokens
        self.peptide_tokenizer = peptide_tokenizer
        self.device = device
        self.min_affinity_threshold = min_affinity_threshold
        self.use_confidence_weighting = use_confidence_weighting
        self.temperature = temperature             # tau for sigmoid temperature

    def compute_affinity(self, peptide_seqs: List[str]) -> np.ndarray:
        """
        Compute binding affinity g_psi(y, x).

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            (N,) numpy array of affinity scores.
        """
        affinities = self.g_psi(peptide_seqs)  # g_psi returns a sequence of per-peptide scores
        return np.array(affinities)

    def compute_direction(self, peptide_seqs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute directional bias f_phi(y, x) and confidence kappa(y).

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            directions: (N,) tensor; the DirectionalOracle outputs
                p(agonist) in [0, 1].
            confidences: (N,) tensor of confidence scores in [0, 1].
        """
        # Tokenize all peptides in a single batch for speed; fall back to
        # per-sequence tokenization if the tokenizer rejects the batch call.
        peptide_tokens = None
        peptide_token_dict = None
        try:
            peptide_token_dict = self.peptide_tokenizer(
                peptide_seqs,
                return_tensors='pt',
                padding=True
            )
            peptide_token_dict = {k: v.to(self.device) for k, v in peptide_token_dict.items()}
            peptide_tokens = peptide_token_dict.get('input_ids')
        except Exception:
            # Per-sequence fallback path.
            peptide_tokens_list = []
            for seq in peptide_seqs:
                tokens = self.peptide_tokenizer(seq, return_tensors='pt', padding=True)
                peptide_tokens_list.append(tokens['input_ids'].to(self.device))

            # Batch the per-sequence tensors (simple concatenation assumes
            # identical lengths after padding).
            try:
                peptide_tokens = torch.cat(peptide_tokens_list, dim=0)  # (N, L)
            except Exception:
                # Lengths differ: zero-pad each row to the max length.
                # NOTE(review): assumes pad token id 0 — confirm against the tokenizer.
                max_len = max(t.size(1) for t in peptide_tokens_list)
                peptide_tokens = torch.zeros(len(peptide_tokens_list), max_len, dtype=torch.long, device=self.device)
                for i, tokens in enumerate(peptide_tokens_list):
                    peptide_tokens[i, :tokens.size(1)] = tokens[0]

        # Broadcast the single protein sequence across the peptide batch.
        protein_tokens = self.protein_tokens.expand(len(peptide_seqs), -1)  # (N, L_prot)

        # Scoring only — no gradients needed.
        with torch.no_grad():
            # Oracles exposing _normalize_token_dict accept the full token
            # dict (ids + attention mask); otherwise pass raw ids.
            if peptide_token_dict is not None and hasattr(self.f_phi, "_normalize_token_dict"):
                directions, confidences = self.f_phi.predict_with_confidence(
                    peptide_token_dict, protein_tokens
                )
            else:
                directions, confidences = self.f_phi.predict_with_confidence(
                    peptide_tokens, protein_tokens
                )

        return directions, confidences

    def compute_gated_reward(
        self,
        peptide_seqs: List[str]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute the gated total reward with sigmoid temperature scaling.

        S_total = g_psi * sigma(d* * (f_phi - 0.5) / tau)

        0.5 is used as the threshold so the gate is balanced/symmetric around
        an uninformative oracle output.

        Args:
            peptide_seqs: List of peptide SMILES strings.
        Returns:
            total_rewards: (N,) gated total rewards.
            affinities: (N,) affinity scores g_psi.
            confidences: (N,) confidence scores kappa.
            directions: (N,) directional predictions f_phi in [0, 1].
        """
        # Affinity g_psi(y, x).
        affinities = self.compute_affinity(peptide_seqs)  # (N,)

        # Directional bias f_phi(y, x) and confidence kappa(y).
        directions, confidences = self.compute_direction(peptide_seqs)  # (N,), (N,)
        directions = directions.cpu().numpy()
        confidences = confidences.cpu().numpy()

        # Signed, centered directional score: positive when the oracle agrees
        # with the target direction d*.
        directional_score = self.target_direction * (directions - 0.5)  # (N,) in [-0.5, +0.5]

        # Temperature scaling (lower tau -> sharper sigmoid).
        scaled_score = directional_score / self.temperature  # (N,)

        # Logistic sigmoid maps the scaled score into [0, 1].
        sigmoid_weight = 1.0 / (1.0 + np.exp(-scaled_score))  # (N,) in [0, 1]

        # Gate affinity with the directional weight.
        gated_rewards = affinities * sigmoid_weight  # (N,)

        # "Allosteric control only for binders": peptides below the affinity
        # threshold get a strongly downweighted reward instead of the gate.
        low_affinity_mask = affinities < self.min_affinity_threshold
        gated_rewards[low_affinity_mask] = affinities[low_affinity_mask] * 0.1  # Downweight

        return gated_rewards, affinities, confidences, directions

    def __call__(
        self,
        input_seqs: List[str]
    ) -> Tuple[np.ndarray, dict]:
        """
        Main reward-computation interface.

        Args:
            input_seqs: List of peptide SMILES strings.
        Returns:
            rewards: (N,) array of total (gated) rewards.
            info: dict with keys 'affinities', 'confidences', 'directions',
                and 'score_vectors' ((N, 2) columns [affinity, gated_reward]).
        """
        total_rewards, affinities, confidences, directions = self.compute_gated_reward(input_seqs)

        info = {
            'affinities': affinities,
            'confidences': confidences,
            'directions': directions,  # Raw oracle predictions
            'score_vectors': np.stack([affinities, total_rewards], axis=1)  # (N, 2)
        }

        return total_rewards, info
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TD3BConfidenceWeighting:
    """
    Confidence-weighted importance sampling for TD3B.

    The importance weights w(y) are modulated by confidence kappa(y):
        w(y) = kappa(y) * exp(S_total(y) / alpha)

    This distinguishes between:
        - Full agonists/antagonists: high kappa
        - Partial agonists/antagonists: medium kappa
        - Non-selective binders: low kappa
    """

    def __init__(
        self,
        alpha: float = 0.1,        # Temperature for reward scaling
        min_confidence: float = 0.1  # Floor to avoid zero weights
    ):
        """
        Args:
            alpha: Temperature parameter for reward scaling.
            min_confidence: Minimum confidence threshold; confidences are
                clipped to this floor before weighting.
        """
        self.alpha = alpha
        self.min_confidence = min_confidence

    def compute_log_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """
        Compute log importance weights (numerically stable form).

        log w(y) = log kappa(y) + S_total / alpha

        Args:
            rewards: (N,) array of total rewards S_total.
            confidences: (N,) array of confidence scores kappa in [0, 1].
        Returns:
            (N,) array of log importance weights.
        """
        # Floor confidences so the log never diverges to -inf.
        confidences = np.maximum(confidences, self.min_confidence)
        return np.log(confidences) + (rewards / self.alpha)

    def compute_importance_weights(
        self,
        rewards: np.ndarray,
        confidences: np.ndarray
    ) -> np.ndarray:
        """
        Compute confidence-weighted importance weights.

        w(y) = kappa(y) * exp(S_total / alpha)

        Args:
            rewards: (N,) array of total rewards S_total.
            confidences: (N,) array of confidence scores kappa in [0, 1].
        Returns:
            (N,) array of importance weights.
        """
        # Single source of truth: exp(log kappa + r/alpha) == kappa * exp(r/alpha).
        # Delegating to the log-space method keeps the two APIs from drifting.
        return np.exp(self.compute_log_importance_weights(rewards, confidences))
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# Factory function for creating TD3B reward function
|
| 298 |
+
def create_td3b_reward_function(
    affinity_predictor: BindingAffinity,
    target_protein_seq: str,
    target_direction: str,  # 'agonist' or 'antagonist'
    peptide_tokenizer,
    device: torch.device,
    directional_oracle: Optional[DirectionalOracle] = None,
    directional_oracle_checkpoint: Optional[str] = None,
    base_path: Optional[str] = None,
    direction_oracle_tr2d2_checkpoint: Optional[str] = None,
    direction_oracle_tokenizer_vocab: Optional[str] = None,
    direction_oracle_tokenizer_splits: Optional[str] = None,
    direction_oracle_esm_name: str = "facebook/esm2_t33_650M_UR50D",
    direction_oracle_esm_cache_dir: Optional[str] = None,
    direction_oracle_esm_local_files_only: bool = False,
    direction_oracle_max_ligand_length: int = 768,
    direction_oracle_max_protein_length: int = 1024,
    direction_oracle_d_model: int = 256,
    direction_oracle_n_heads: int = 4,
    direction_oracle_n_self_attn_layers: int = 1,
    direction_oracle_n_bmca_layers: int = 2,
    direction_oracle_dropout: float = 0.3,
    **kwargs
) -> TD3BRewardFunction:
    """
    Factory for a TD3B reward function.

    Args:
        affinity_predictor: Pretrained binding affinity model (g_psi).
        target_protein_seq: Target protein amino acid sequence.
        target_direction: 'agonist' (+1) or 'antagonist' (-1); case-insensitive.
        peptide_tokenizer: Tokenizer for peptides.
        device: Computation device.
        directional_oracle: Preloaded DirectionalOracle instance (optional).
        directional_oracle_checkpoint: Oracle checkpoint path (optional when
            an instance or base_path is given).
        base_path: Base path under which default oracle assets live
            (``<base_path>/tr2d2-pep/...``).
        direction_oracle_tr2d2_checkpoint: TR2-D2 checkpoint for the ligand encoder.
        direction_oracle_tokenizer_vocab: SMILES tokenizer vocab path.
        direction_oracle_tokenizer_splits: SMILES tokenizer splits path.
        direction_oracle_*: Remaining oracle architecture/loading knobs,
            forwarded to DirectionalOracle.
        **kwargs: Additional arguments for TD3BRewardFunction.

    Returns:
        A configured TD3BRewardFunction instance.

    Raises:
        ValueError: If no oracle instance is given and the asset paths cannot
            be resolved, or if ``target_direction`` is not a known direction.
    """
    if directional_oracle is None:
        # Fail fast instead of the old behavior of substituting the literal
        # placeholder "To Be Added" as base_path, which silently produced
        # non-existent asset paths and a confusing load error later.
        if base_path is None and (
            directional_oracle_checkpoint is None
            or direction_oracle_tr2d2_checkpoint is None
            or direction_oracle_tokenizer_vocab is None
            or direction_oracle_tokenizer_splits is None
        ):
            raise ValueError(
                "create_td3b_reward_function: provide either a preloaded "
                "`directional_oracle`, all explicit oracle asset paths, or a "
                "`base_path` from which the default paths can be derived."
            )

        if base_path is not None:
            # Derive any missing asset paths from the conventional layout.
            tr2d2_root = os.path.join(base_path, "tr2d2-pep")
            if directional_oracle_checkpoint is None:
                directional_oracle_checkpoint = os.path.join(
                    tr2d2_root, "direction_oracle.pt"
                )
            if direction_oracle_tr2d2_checkpoint is None:
                direction_oracle_tr2d2_checkpoint = os.path.join(
                    tr2d2_root, "pretrained", "peptune-pretrained.ckpt"
                )
            if direction_oracle_tokenizer_vocab is None:
                direction_oracle_tokenizer_vocab = os.path.join(
                    tr2d2_root, "tokenizer", "new_vocab.txt"
                )
            if direction_oracle_tokenizer_splits is None:
                direction_oracle_tokenizer_splits = os.path.join(
                    tr2d2_root, "tokenizer", "new_splits.txt"
                )

        directional_oracle = DirectionalOracle(
            model_ckpt=directional_oracle_checkpoint,
            tr2d2_checkpoint=direction_oracle_tr2d2_checkpoint,
            tokenizer_vocab=direction_oracle_tokenizer_vocab,
            tokenizer_splits=direction_oracle_tokenizer_splits,
            esm_name=direction_oracle_esm_name,
            d_model=direction_oracle_d_model,
            n_heads=direction_oracle_n_heads,
            n_self_attn_layers=direction_oracle_n_self_attn_layers,
            n_bmca_layers=direction_oracle_n_bmca_layers,
            dropout=direction_oracle_dropout,
            max_ligand_length=direction_oracle_max_ligand_length,
            max_protein_length=direction_oracle_max_protein_length,
            device=device,
            esm_cache_dir=direction_oracle_esm_cache_dir,
            esm_local_files_only=direction_oracle_esm_local_files_only,
        )

    # Scoring only — put the oracle in eval mode (also for preloaded instances).
    directional_oracle.eval()

    protein_tokens = directional_oracle.encode_protein(target_protein_seq)

    # Convert the direction string to its numerical value; reject unknown
    # strings instead of silently defaulting to agonist (+1).
    direction_map = {'agonist': +1.0, 'antagonist': -1.0}
    direction_key = target_direction.lower()
    if direction_key not in direction_map:
        raise ValueError(
            f"target_direction must be 'agonist' or 'antagonist', got {target_direction!r}"
        )
    d_star = direction_map[direction_key]

    # Assemble the reward function.
    reward_function = TD3BRewardFunction(
        affinity_predictor=affinity_predictor,
        directional_oracle=directional_oracle,
        target_direction=d_star,
        target_protein_tokens=protein_tokens,
        peptide_tokenizer=peptide_tokenizer,
        device=device,
        **kwargs
    )

    return reward_function
|
tokenizer/my_tokenizers.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an OrderedDict.

    Args:
        vocab_file: path to a UTF-8 text file; line number == token id.

    Returns:
        collections.OrderedDict mapping token (str) -> index (int).
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            # Strip only the trailing newline; other whitespace is significant.
            vocab[line.rstrip("\n")] = index
    return vocab
|
| 18 |
+
|
| 19 |
+
class Atomwise_Tokenizer(object):
    """Atom-level SMILES tokenizer driven by a single regular expression."""

    def __init__(self):
        """Construct the tokenizer by compiling the atom-level pattern.

        Besides bracket atoms, two-letter halogens, ring digits and bond
        symbols, the pattern's first alternative folds short parenthesized
        groups (at most four characters between the parentheses) into a
        single token.
        """
        self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
        self.regex = re.compile(self.regex_pattern)

    def tokenize(self, text):
        """Split a SMILES string into its list of atom-level tokens."""
        return self.regex.findall(text)
|
| 35 |
+
|
| 36 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        spe_file (:obj:`string`):
            File containing the trained SMILES Pair Encoding vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(self, vocab_file, spe_file,
                 unk_token="[UNK]",
                 sep_token="[SEP]",
                 pad_token="[PAD]",
                 cls_token="[CLS]",
                 mask_token="[MASK]",
                 **kwargs):
        if not os.path.isfile(vocab_file):
            raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
        if not os.path.isfile(spe_file):
            raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))

        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        # FIX: the SPE codes file was previously opened and the handle kept
        # alive (and leaked) for the tokenizer's lifetime. SPE_Tokenizer reads
        # the codes eagerly in its constructor, so a context manager suffices.
        with open(spe_file, 'r', encoding='utf-8') as spe_vocab:
            self.spe_tokenizer = SPE_Tokenizer(spe_vocab)

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs)

    @property
    def vocab_size(self):
        """Size of the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token -> id mapping, including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # SPE_Tokenizer.tokenize returns a space-separated token string.
        return self.spe_tokenizer.tokenize(text).split(' ')

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    # changed encode and decode functions
    def encode(self, token_array):
        """Encode an iterable of token strings into a batched tensor.

        NOTE(review): ids 2 and 3 are hard-coded; they match [CLS] and [SEP]
        in the vocabulary bundled with this repo (new_vocab.txt) — confirm if
        a different vocab file is ever used.

        Args:
            token_array: iterable of token strings (without special tokens).
        Returns:
            dict with 'input_ids' and 'attention_mask' tensors of shape
            (1, len(token_array) + 2).
        """
        token_ids = [2]
        for token in token_array:
            token_ids.append(self._convert_token_to_id(token))
        token_ids.append(3)
        token_ids = torch.tensor([token_ids])
        attn_mask = torch.ones_like(token_ids)
        return {'input_ids': token_ids, 'attention_mask': attn_mask}

    def decode(self, token_ids, skip_special_tokens=True):
        """Decode a (1, L) or (L,) id tensor into a SMILES string.

        Decoding stops at the first occurrence of id 3 ([SEP] in the bundled
        vocab); when ``skip_special_tokens`` is True, all special-token ids
        are dropped from the output.
        """
        token_ids = token_ids.squeeze(0).cpu().tolist()
        token_array = []
        for idx in token_ids:
            if idx == 3:  # Stop decoding when token ID 3 is encountered
                break
            if skip_special_tokens and idx in self.all_special_ids:
                continue
            token_array.append(self._convert_id_to_token(idx))
        return "".join(token_array)

    def batch_decode(self, batch_token_ids, skip_special_tokens=True):
        """Decode a batch of id tensors into SMILES strings.

        FIX: ``skip_special_tokens`` is now forwarded to :meth:`decode`;
        previously the argument was accepted but silently ignored.
        """
        return [self.decode(token_ids, skip_special_tokens=skip_special_tokens)
                for token_ids in batch_token_ids]

    def get_token_split(self, token_ids):
        """Map a batch of id sequences to their token strings (no joining)."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()

        token_array = []
        for seq_ids in token_ids:
            token_array.append([self._convert_id_to_token(i) for i in seq_ids])
        return token_array

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # vocab indices are not consecutive; resync the counter
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
| 252 |
+
|
| 253 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    Args:
        vocab_file (:obj:`string`):
            File containing the vocabulary.
        unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering.
            It is also used as the last token of a sequence built with special tokens.
        pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        **kwargs
    ):
        # FIX: set up the vocabulary BEFORE calling super().__init__().
        # Recent versions of transformers' PreTrainedTokenizer.__init__ call
        # get_vocab() while registering special/added tokens, which would
        # crash if self.vocab did not exist yet. This also makes the
        # construction order consistent with SMILES_SPE_Tokenizer above.
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'.".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.tokenizer = Atomwise_Tokenizer()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Size of the base vocabulary (excluding added tokens)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return the full token -> id mapping, including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        # Delegate to the regex-based atom-level tokenizer.
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        return " ".join(tokens).replace(" ##", "").strip()

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """
        Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.
        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """
        index = 0
        vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # vocab indices are not consecutive; resync the counter
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
|
tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
tokenizer/new_vocab.txt
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
cccc
|
| 96 |
+
ccn
|
| 97 |
+
cco
|
| 98 |
+
ccs
|
| 99 |
+
cn
|
| 100 |
+
cnc
|
| 101 |
+
cnn
|
| 102 |
+
co
|
| 103 |
+
coc
|
| 104 |
+
cs
|
| 105 |
+
csc
|
| 106 |
+
csn
|
| 107 |
+
e
|
| 108 |
+
g
|
| 109 |
+
i
|
| 110 |
+
l
|
| 111 |
+
n
|
| 112 |
+
nc
|
| 113 |
+
ncc
|
| 114 |
+
ncn
|
| 115 |
+
nco
|
| 116 |
+
ncs
|
| 117 |
+
nn
|
| 118 |
+
nnc
|
| 119 |
+
nnn
|
| 120 |
+
no
|
| 121 |
+
noc
|
| 122 |
+
non
|
| 123 |
+
ns
|
| 124 |
+
nsc
|
| 125 |
+
nsn
|
| 126 |
+
o
|
| 127 |
+
oc
|
| 128 |
+
occ
|
| 129 |
+
on
|
| 130 |
+
p
|
| 131 |
+
r
|
| 132 |
+
s
|
| 133 |
+
sc
|
| 134 |
+
scc
|
| 135 |
+
scn
|
| 136 |
+
sn
|
| 137 |
+
t
|
| 138 |
+
c1
|
| 139 |
+
c2
|
| 140 |
+
c3
|
| 141 |
+
c4
|
| 142 |
+
c5
|
| 143 |
+
c6
|
| 144 |
+
c7
|
| 145 |
+
c8
|
| 146 |
+
c9
|
| 147 |
+
n1
|
| 148 |
+
n2
|
| 149 |
+
n3
|
| 150 |
+
n4
|
| 151 |
+
n5
|
| 152 |
+
n6
|
| 153 |
+
n7
|
| 154 |
+
n8
|
| 155 |
+
n9
|
| 156 |
+
O1
|
| 157 |
+
O2
|
| 158 |
+
O3
|
| 159 |
+
O4
|
| 160 |
+
O5
|
| 161 |
+
O6
|
| 162 |
+
O7
|
| 163 |
+
O8
|
| 164 |
+
O9
|
| 165 |
+
(c1
|
| 166 |
+
(c2
|
| 167 |
+
c1)
|
| 168 |
+
c2)
|
| 169 |
+
(n1
|
| 170 |
+
(n2
|
| 171 |
+
n1)
|
| 172 |
+
n2)
|
| 173 |
+
(O1
|
| 174 |
+
(O2
|
| 175 |
+
O2)
|
| 176 |
+
=O
|
| 177 |
+
=C
|
| 178 |
+
=c
|
| 179 |
+
=N
|
| 180 |
+
=n
|
| 181 |
+
=CC
|
| 182 |
+
=CN
|
| 183 |
+
=Cc
|
| 184 |
+
=cc
|
| 185 |
+
=NC
|
| 186 |
+
=Nc
|
| 187 |
+
=nC
|
| 188 |
+
=nc
|
| 189 |
+
#C
|
| 190 |
+
#CC
|
| 191 |
+
#CN
|
| 192 |
+
#N
|
| 193 |
+
#NC
|
| 194 |
+
#NN
|
| 195 |
+
(C
|
| 196 |
+
C)
|
| 197 |
+
(O
|
| 198 |
+
O)
|
| 199 |
+
(N
|
| 200 |
+
N)
|
| 201 |
+
NP
|
| 202 |
+
PN
|
| 203 |
+
CP
|
| 204 |
+
PC
|
| 205 |
+
NS
|
| 206 |
+
SN
|
| 207 |
+
SP
|
| 208 |
+
PS
|
| 209 |
+
C(=O)
|
| 210 |
+
(/Br)
|
| 211 |
+
(/C#N)
|
| 212 |
+
(/C)
|
| 213 |
+
(/C=N)
|
| 214 |
+
(/C=O)
|
| 215 |
+
(/CBr)
|
| 216 |
+
(/CC)
|
| 217 |
+
(/CCC)
|
| 218 |
+
(/CCF)
|
| 219 |
+
(/CCN)
|
| 220 |
+
(/CCO)
|
| 221 |
+
(/CCl)
|
| 222 |
+
(/CI)
|
| 223 |
+
(/CN)
|
| 224 |
+
(/CO)
|
| 225 |
+
(/CS)
|
| 226 |
+
(/Cl)
|
| 227 |
+
(/F)
|
| 228 |
+
(/I)
|
| 229 |
+
(/N)
|
| 230 |
+
(/NC)
|
| 231 |
+
(/NCC)
|
| 232 |
+
(/NO)
|
| 233 |
+
(/O)
|
| 234 |
+
(/OC)
|
| 235 |
+
(/OCC)
|
| 236 |
+
(/S)
|
| 237 |
+
(/SC)
|
| 238 |
+
(=C)
|
| 239 |
+
(=C/C)
|
| 240 |
+
(=C/F)
|
| 241 |
+
(=C/I)
|
| 242 |
+
(=C/N)
|
| 243 |
+
(=C/O)
|
| 244 |
+
(=CBr)
|
| 245 |
+
(=CC)
|
| 246 |
+
(=CCF)
|
| 247 |
+
(=CCN)
|
| 248 |
+
(=CCO)
|
| 249 |
+
(=CCl)
|
| 250 |
+
(=CF)
|
| 251 |
+
(=CI)
|
| 252 |
+
(=CN)
|
| 253 |
+
(=CO)
|
| 254 |
+
(=C\\C)
|
| 255 |
+
(=C\\F)
|
| 256 |
+
(=C\\I)
|
| 257 |
+
(=C\\N)
|
| 258 |
+
(=C\\O)
|
| 259 |
+
(=N)
|
| 260 |
+
(=N/C)
|
| 261 |
+
(=N/N)
|
| 262 |
+
(=N/O)
|
| 263 |
+
(=NBr)
|
| 264 |
+
(=NC)
|
| 265 |
+
(=NCC)
|
| 266 |
+
(=NCl)
|
| 267 |
+
(=NN)
|
| 268 |
+
(=NO)
|
| 269 |
+
(=NOC)
|
| 270 |
+
(=N\\C)
|
| 271 |
+
(=N\\N)
|
| 272 |
+
(=N\\O)
|
| 273 |
+
(=O)
|
| 274 |
+
(=S)
|
| 275 |
+
(B)
|
| 276 |
+
(Br)
|
| 277 |
+
(C#C)
|
| 278 |
+
(C#CC)
|
| 279 |
+
(C#CI)
|
| 280 |
+
(C#CO)
|
| 281 |
+
(C#N)
|
| 282 |
+
(C#SN)
|
| 283 |
+
(C)
|
| 284 |
+
(C=C)
|
| 285 |
+
(C=CF)
|
| 286 |
+
(C=CI)
|
| 287 |
+
(C=N)
|
| 288 |
+
(C=NN)
|
| 289 |
+
(C=NO)
|
| 290 |
+
(C=O)
|
| 291 |
+
(C=S)
|
| 292 |
+
(CBr)
|
| 293 |
+
(CC#C)
|
| 294 |
+
(CC#N)
|
| 295 |
+
(CC)
|
| 296 |
+
(CC=C)
|
| 297 |
+
(CC=O)
|
| 298 |
+
(CCBr)
|
| 299 |
+
(CCC)
|
| 300 |
+
(CCCC)
|
| 301 |
+
(CCCF)
|
| 302 |
+
(CCCI)
|
| 303 |
+
(CCCN)
|
| 304 |
+
(CCCO)
|
| 305 |
+
(CCCS)
|
| 306 |
+
(CCCl)
|
| 307 |
+
(CCF)
|
| 308 |
+
(CCI)
|
| 309 |
+
(CCN)
|
| 310 |
+
(CCNC)
|
| 311 |
+
(CCNN)
|
| 312 |
+
(CCNO)
|
| 313 |
+
(CCO)
|
| 314 |
+
(CCOC)
|
| 315 |
+
(CCON)
|
| 316 |
+
(CCS)
|
| 317 |
+
(CCSC)
|
| 318 |
+
(CCl)
|
| 319 |
+
(CF)
|
| 320 |
+
(CI)
|
| 321 |
+
(CN)
|
| 322 |
+
(CN=O)
|
| 323 |
+
(CNC)
|
| 324 |
+
(CNCC)
|
| 325 |
+
(CNCO)
|
| 326 |
+
(CNN)
|
| 327 |
+
(CNNC)
|
| 328 |
+
(CNO)
|
| 329 |
+
(CNOC)
|
| 330 |
+
(CO)
|
| 331 |
+
(COC)
|
| 332 |
+
(COCC)
|
| 333 |
+
(COCI)
|
| 334 |
+
(COCN)
|
| 335 |
+
(COCO)
|
| 336 |
+
(COF)
|
| 337 |
+
(CON)
|
| 338 |
+
(COO)
|
| 339 |
+
(CS)
|
| 340 |
+
(CSC)
|
| 341 |
+
(CSCC)
|
| 342 |
+
(CSCF)
|
| 343 |
+
(CSO)
|
| 344 |
+
(Cl)
|
| 345 |
+
(F)
|
| 346 |
+
(I)
|
| 347 |
+
(N)
|
| 348 |
+
(N=N)
|
| 349 |
+
(N=NO)
|
| 350 |
+
(N=O)
|
| 351 |
+
(N=S)
|
| 352 |
+
(NBr)
|
| 353 |
+
(NC#N)
|
| 354 |
+
(NC)
|
| 355 |
+
(NC=N)
|
| 356 |
+
(NC=O)
|
| 357 |
+
(NC=S)
|
| 358 |
+
(NCBr)
|
| 359 |
+
(NCC)
|
| 360 |
+
(NCCC)
|
| 361 |
+
(NCCF)
|
| 362 |
+
(NCCN)
|
| 363 |
+
(NCCO)
|
| 364 |
+
(NCCS)
|
| 365 |
+
(NCCl)
|
| 366 |
+
(NCNC)
|
| 367 |
+
(NCO)
|
| 368 |
+
(NCS)
|
| 369 |
+
(NCl)
|
| 370 |
+
(NN)
|
| 371 |
+
(NN=O)
|
| 372 |
+
(NNC)
|
| 373 |
+
(NO)
|
| 374 |
+
(NOC)
|
| 375 |
+
(O)
|
| 376 |
+
(OC#N)
|
| 377 |
+
(OC)
|
| 378 |
+
(OC=C)
|
| 379 |
+
(OC=O)
|
| 380 |
+
(OC=S)
|
| 381 |
+
(OCBr)
|
| 382 |
+
(OCC)
|
| 383 |
+
(OCCC)
|
| 384 |
+
(OCCF)
|
| 385 |
+
(OCCI)
|
| 386 |
+
(OCCN)
|
| 387 |
+
(OCCO)
|
| 388 |
+
(OCCS)
|
| 389 |
+
(OCCl)
|
| 390 |
+
(OCF)
|
| 391 |
+
(OCI)
|
| 392 |
+
(OCO)
|
| 393 |
+
(OCOC)
|
| 394 |
+
(OCON)
|
| 395 |
+
(OCSC)
|
| 396 |
+
(OCl)
|
| 397 |
+
(OI)
|
| 398 |
+
(ON)
|
| 399 |
+
(OO)
|
| 400 |
+
(OOC)
|
| 401 |
+
(OOCC)
|
| 402 |
+
(OOSN)
|
| 403 |
+
(OSC)
|
| 404 |
+
(P)
|
| 405 |
+
(S)
|
| 406 |
+
(SC#N)
|
| 407 |
+
(SC)
|
| 408 |
+
(SCC)
|
| 409 |
+
(SCCC)
|
| 410 |
+
(SCCF)
|
| 411 |
+
(SCCN)
|
| 412 |
+
(SCCO)
|
| 413 |
+
(SCCS)
|
| 414 |
+
(SCCl)
|
| 415 |
+
(SCF)
|
| 416 |
+
(SCN)
|
| 417 |
+
(SCOC)
|
| 418 |
+
(SCSC)
|
| 419 |
+
(SCl)
|
| 420 |
+
(SI)
|
| 421 |
+
(SN)
|
| 422 |
+
(SN=O)
|
| 423 |
+
(SO)
|
| 424 |
+
(SOC)
|
| 425 |
+
(SOOO)
|
| 426 |
+
(SS)
|
| 427 |
+
(SSC)
|
| 428 |
+
(SSCC)
|
| 429 |
+
([At])
|
| 430 |
+
([O-])
|
| 431 |
+
([O])
|
| 432 |
+
([S-])
|
| 433 |
+
(\\Br)
|
| 434 |
+
(\\C#N)
|
| 435 |
+
(\\C)
|
| 436 |
+
(\\C=N)
|
| 437 |
+
(\\C=O)
|
| 438 |
+
(\\CBr)
|
| 439 |
+
(\\CC)
|
| 440 |
+
(\\CCC)
|
| 441 |
+
(\\CCO)
|
| 442 |
+
(\\CCl)
|
| 443 |
+
(\\CF)
|
| 444 |
+
(\\CN)
|
| 445 |
+
(\\CNC)
|
| 446 |
+
(\\CO)
|
| 447 |
+
(\\COC)
|
| 448 |
+
(\\Cl)
|
| 449 |
+
(\\F)
|
| 450 |
+
(\\I)
|
| 451 |
+
(\\N)
|
| 452 |
+
(\\NC)
|
| 453 |
+
(\\NCC)
|
| 454 |
+
(\\NN)
|
| 455 |
+
(\\NO)
|
| 456 |
+
(\\NOC)
|
| 457 |
+
(\\O)
|
| 458 |
+
(\\OC)
|
| 459 |
+
(\\OCC)
|
| 460 |
+
(\\ON)
|
| 461 |
+
(\\S)
|
| 462 |
+
(\\SC)
|
| 463 |
+
(\\SCC)
|
| 464 |
+
[Ag+]
|
| 465 |
+
[Ag-4]
|
| 466 |
+
[Ag]
|
| 467 |
+
[Al-3]
|
| 468 |
+
[Al]
|
| 469 |
+
[As+]
|
| 470 |
+
[AsH3]
|
| 471 |
+
[AsH]
|
| 472 |
+
[As]
|
| 473 |
+
[At]
|
| 474 |
+
[B-]
|
| 475 |
+
[B@-]
|
| 476 |
+
[B@@-]
|
| 477 |
+
[BH-]
|
| 478 |
+
[BH2-]
|
| 479 |
+
[BH3-]
|
| 480 |
+
[B]
|
| 481 |
+
[Ba]
|
| 482 |
+
[Br+2]
|
| 483 |
+
[BrH]
|
| 484 |
+
[Br]
|
| 485 |
+
[C+]
|
| 486 |
+
[C-]
|
| 487 |
+
[C@@H]
|
| 488 |
+
[C@@]
|
| 489 |
+
[C@H]
|
| 490 |
+
[C@]
|
| 491 |
+
[CH-]
|
| 492 |
+
[CH2]
|
| 493 |
+
[CH3]
|
| 494 |
+
[CH]
|
| 495 |
+
[C]
|
| 496 |
+
[CaH2]
|
| 497 |
+
[Ca]
|
| 498 |
+
[Cl+2]
|
| 499 |
+
[Cl+3]
|
| 500 |
+
[Cl+]
|
| 501 |
+
[Cs]
|
| 502 |
+
[FH]
|
| 503 |
+
[F]
|
| 504 |
+
[H]
|
| 505 |
+
[He]
|
| 506 |
+
[I+2]
|
| 507 |
+
[I+3]
|
| 508 |
+
[I+]
|
| 509 |
+
[IH]
|
| 510 |
+
[I]
|
| 511 |
+
[K]
|
| 512 |
+
[Kr]
|
| 513 |
+
[Li+]
|
| 514 |
+
[LiH]
|
| 515 |
+
[MgH2]
|
| 516 |
+
[Mg]
|
| 517 |
+
[N+]
|
| 518 |
+
[N-]
|
| 519 |
+
[N@+]
|
| 520 |
+
[N@@+]
|
| 521 |
+
[N@@]
|
| 522 |
+
[N@]
|
| 523 |
+
[NH+]
|
| 524 |
+
[NH-]
|
| 525 |
+
[NH2+]
|
| 526 |
+
[NH3]
|
| 527 |
+
[NH]
|
| 528 |
+
[N]
|
| 529 |
+
[Na]
|
| 530 |
+
[O+]
|
| 531 |
+
[O-]
|
| 532 |
+
[OH+]
|
| 533 |
+
[OH2]
|
| 534 |
+
[OH]
|
| 535 |
+
[O]
|
| 536 |
+
[P+]
|
| 537 |
+
[P@+]
|
| 538 |
+
[P@@+]
|
| 539 |
+
[P@@]
|
| 540 |
+
[P@]
|
| 541 |
+
[PH2]
|
| 542 |
+
[PH]
|
| 543 |
+
[P]
|
| 544 |
+
[Ra]
|
| 545 |
+
[Rb]
|
| 546 |
+
[S+]
|
| 547 |
+
[S-]
|
| 548 |
+
[S@+]
|
| 549 |
+
[S@@+]
|
| 550 |
+
[S@@]
|
| 551 |
+
[S@]
|
| 552 |
+
[SH+]
|
| 553 |
+
[SH2]
|
| 554 |
+
[SH]
|
| 555 |
+
[S]
|
| 556 |
+
[Se+]
|
| 557 |
+
[Se-2]
|
| 558 |
+
[SeH2]
|
| 559 |
+
[SeH]
|
| 560 |
+
[Se]
|
| 561 |
+
[Si@]
|
| 562 |
+
[SiH2]
|
| 563 |
+
[SiH]
|
| 564 |
+
[Si]
|
| 565 |
+
[SrH2]
|
| 566 |
+
[TeH]
|
| 567 |
+
[Te]
|
| 568 |
+
[Xe]
|
| 569 |
+
[Zn+2]
|
| 570 |
+
[Zn-2]
|
| 571 |
+
[Zn]
|
| 572 |
+
[b-]
|
| 573 |
+
[c+]
|
| 574 |
+
[c-]
|
| 575 |
+
[cH-]
|
| 576 |
+
[cH]
|
| 577 |
+
[c]
|
| 578 |
+
[n+]
|
| 579 |
+
[n-]
|
| 580 |
+
[nH]
|
| 581 |
+
[n]
|
| 582 |
+
[o+]
|
| 583 |
+
[s+]
|
| 584 |
+
[se+]
|
| 585 |
+
[se]
|
| 586 |
+
[te+]
|
| 587 |
+
[te]
|
utils/app.py
ADDED
|
@@ -0,0 +1,1287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from io import StringIO
|
| 5 |
+
import rdkit
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem, Draw
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.patches as patches
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
import tempfile
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
|
| 16 |
+
class PeptideAnalyzer:
|
| 17 |
+
def __init__(self):
    """Set up the regex bond patterns and the 3-letter -> 1-letter amino-acid map."""
    # Ordered (regex, bond_type) pairs. Order matters: more specific patterns
    # (ester, N-methyl, proline) must be tried before the generic peptide
    # bond so split_on_bonds does not mis-classify them.
    self.bond_patterns = [
        (r'OC\(=O\)', 'ester'),  # Ester bond
        (r'N\(C\)C\(=O\)', 'n_methyl'),  # N-methylated peptide bond
        (r'N[0-9]C\(=O\)', 'proline'),  # Proline peptide bond
        (r'NC\(=O\)', 'peptide'),  # Standard peptide bond
        (r'C\(=O\)N\(C\)', 'n_methyl_reverse'),  # Reverse N-methylated
        (r'C\(=O\)N[12]?', 'peptide_reverse')  # Reverse peptide bond
    ]
    # Three to one letter code mapping for the 20 standard amino acids.
    self.three_to_one = {
        'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
        'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
        'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
        'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
        'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
    }
|
| 34 |
+
|
| 35 |
+
def is_amino_acid_sequence(self, seq):
    """
    Check if the input is a valid amino acid sequence.

    Args:
        seq: String to check

    Returns:
        bool: True if valid amino acid sequence, False otherwise
    """
    # Reject non-strings and empty input outright.
    if not isinstance(seq, str) or not seq:
        return False

    normalized = seq.strip().upper()

    # A peptide needs at least two residues, and every character must be
    # one of the 20 standard one-letter amino-acid codes.
    return len(normalized) >= 2 and set(normalized) <= set('ACDEFGHIKLMNPQRSTVWY')
|
| 61 |
+
|
| 62 |
+
def is_peptide(self, smiles):
    """Check if the SMILES represents a peptide structure"""
    # A plain one-letter amino-acid string counts as a peptide directly
    # (the input may be a sequence rather than SMILES).
    if self.is_amino_acid_sequence(smiles):
        return True

    # Otherwise parse as SMILES; unparseable input is not a peptide.
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return False

    # Probe first for a standard peptide bond (NC(=O)), then for an
    # N-methylated peptide bond (N(C)C(=O)).
    for smarts in ('[NH][C](=O)', '[N;H0;$(NC)](C)[C](=O)'):
        if molecule.HasSubstructMatch(Chem.MolFromSmarts(smarts)):
            return True

    return False
|
| 84 |
+
|
| 85 |
+
def is_cyclic(self, smiles):
    """Improved cyclic peptide detection

    Heuristic, string-level check (no molecule parsing).

    Returns:
        tuple: (is_cyclic: bool, peptide_cycles: list, aromatic_cycles: list)
        where the lists hold ring-closure markers found in the SMILES.
    """
    # Check for C-terminal carboxyl: a free acid terminus means linear.
    if smiles.endswith('C(=O)O'):
        return False, [], []

    # Find all numbers used in ring closures.
    # NOTE(review): the group is non-capturing, so findall returns the whole
    # match (preceding non-'c' character plus the digit, e.g. 'N1'), not the
    # bare digit. The membership test against aromatic_cycles (bare digits)
    # below therefore may never exclude anything — confirm intended behavior
    # before changing.
    ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)

    # Find aromatic ring numbers (benzene / indole-style closures).
    aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
    aromatic_cycles = []
    for match in aromatic_matches:
        numbers = re.findall(r'[0-9]', match)
        aromatic_cycles.extend(numbers)

    # Numbers that aren't part of aromatic rings are peptide cycles
    peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]

    # Cyclic when at least one non-aromatic ring closure remains. The
    # endswith re-check is redundant after the early return above.
    is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
    return is_cyclic, peptide_cycles, aromatic_cycles
|
| 106 |
+
|
| 107 |
+
def split_on_bonds(self, smiles):
    """Split SMILES into segments with simplified Pro handling

    Scans *smiles* for backbone bond patterns (glycine first, then
    ``self.bond_patterns`` in priority order) and cuts the string into
    residue segments between consecutive bonds.

    Returns:
        list[dict]: segments with 'content' plus, where known, the
        surrounding 'bond_before' / 'bond_after' pattern strings.
    """
    positions = []
    used = set()  # character indices already claimed by a matched bond

    # Find Gly pattern first: glycine has no side chain, so its whole
    # backbone ('NCC(=O)') is matched as a single unit before the generic
    # bond patterns can claim parts of it.
    gly_pattern = r'NCC\(=O\)'
    for match in re.finditer(gly_pattern, smiles):
        if not any(p in range(match.start(), match.end()) for p in used):
            positions.append({
                'start': match.start(),
                'end': match.end(),
                'type': 'gly',
                'pattern': match.group()
            })
            used.update(range(match.start(), match.end()))

    # Then the remaining bond patterns, most specific first; spans already
    # claimed above are skipped so each character belongs to one bond.
    for pattern, bond_type in self.bond_patterns:
        for match in re.finditer(pattern, smiles):
            if not any(p in range(match.start(), match.end()) for p in used):
                positions.append({
                    'start': match.start(),
                    'end': match.end(),
                    'type': bond_type,
                    'pattern': match.group()
                })
                used.update(range(match.start(), match.end()))

    # Sort by position so segments are emitted left-to-right.
    positions.sort(key=lambda x: x['start'])

    # Create segments
    segments = []

    if positions:
        # First segment: anything before the first matched bond.
        if positions[0]['start'] > 0:
            segments.append({
                'content': smiles[0:positions[0]['start']],
                'bond_after': positions[0]['pattern']
            })

        # Process segments between consecutive bonds.
        for i in range(len(positions)-1):
            current = positions[i]
            next_pos = positions[i+1]

            if current['type'] == 'gly':
                # Glycine's matched span IS the residue; emit it verbatim,
                # taking bond_before from the previous match when any.
                segments.append({
                    'content': 'NCC(=O)',
                    'bond_before': positions[i-1]['pattern'] if i > 0 else None,
                    'bond_after': next_pos['pattern']
                })
            else:
                content = smiles[current['end']:next_pos['start']]
                if content:
                    segments.append({
                        'content': content,
                        'bond_before': current['pattern'],
                        'bond_after': next_pos['pattern']
                    })

        # Last segment: anything after the final bond (no bond_after key,
        # which clean_terminal_carboxyl relies on to spot the C-terminus).
        if positions[-1]['end'] < len(smiles):
            segments.append({
                'content': smiles[positions[-1]['end']:],
                'bond_before': positions[-1]['pattern']
            })

    return segments
|
| 177 |
+
|
| 178 |
+
def clean_terminal_carboxyl(self, segment):
    """Remove C-terminal carboxyl only if it's the true terminus.

    The carboxyl is stripped only when this segment contains a C(=O)O
    group AND has no following bond (``segment.get('bond_after')`` is
    absent/None), i.e. it is the last segment produced by split_on_bonds.

    Args:
        segment: dict with at least 'content' (a SMILES fragment) and
            optionally 'bond_after'.

    Returns:
        str: the content with the terminal carboxyl removed, or the
        original content unchanged when this is not the terminus.

    Fixes vs. previous version: removed leftover debug print() calls, and
    the code now also strips a bare trailing 'C(=O)O' — previously only the
    parenthesised '(C(=O)O)' form was removed, contradicting the stated
    intent of removing the pattern regardless of position.
    """
    content = segment['content']

    # Only the last segment (no bond_after) may carry the free acid.
    if 'C(=O)O' in content and not segment.get('bond_after'):
        # Remove a parenthesised carboxyl branch, e.g. "(C(=O)O)".
        cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
        # Also strip a bare carboxyl at the very end of the fragment.
        cleaned = re.sub(r'C\(=O\)O$', '', cleaned)
        # Drop any empty parentheses left behind by the removal.
        cleaned = re.sub(r'\(\)', '', cleaned)
        return cleaned
    return content
|
| 195 |
+
|
| 196 |
+
def identify_residue(self, segment):
    """Identify the residue contained in one backbone segment.

    Runs a long, strictly ordered chain of SMILES substring checks:
    noncanonical/unusual residues first, then proline reconstruction,
    then the canonical residues.  Order is load-bearing — earlier,
    more specific patterns must win before broader ones (e.g. Leu
    before Val, Nle before Ala), so do not reorder these checks.

    Args:
        segment: dict with 'content' (SMILES fragment) and optional
            'bond_before'/'bond_after' linkage patterns, as produced
            by split_on_bonds().

    Returns:
        (code, mods): a residue code string (PDB-style three-letter or
        custom code), or None when no pattern matches, together with
        the modification list from get_modifications().
    """
    # Only clean terminal carboxyl if this is the last segment
    content = self.clean_terminal_carboxyl(segment)
    mods = self.get_modifications(segment)

    # UAA pattern matching section - before regular residues
    # Phenylglycine and derivatives
    if 'c1ccccc1' in content:
        if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
            return '4', mods  # Base phenylglycine

    # 4-substituted phenylalanines
    if 'Cc1ccc' in content:
        if 'OMe' in content or 'OCc1ccc' in content:
            return '0A1', mods  # 4-methoxy-Phenylalanine
        elif 'Clc1ccc' in content:
            return '200', mods  # 4-chloro-Phenylalanine
        elif 'Brc1ccc' in content:
            return '4BF', mods  # 4-Bromo-phenylalanine
        elif 'C#Nc1ccc' in content:
            return '4CF', mods  # 4-cyano-phenylalanine
        elif 'Ic1ccc' in content:
            return 'PHI', mods  # 4-Iodo-phenylalanine
        elif 'Fc1ccc' in content:
            return 'PFF', mods  # 4-Fluoro-phenylalanine

    # Modified tryptophans
    if 'c[nH]c2' in content:
        if 'Oc2cccc2' in content:
            return '0AF', mods  # 7-hydroxy-tryptophan
        elif 'Fc2cccc2' in content:
            return '4FW', mods  # 4-fluoro-tryptophan
        elif 'Clc2cccc2' in content:
            return '6CW', mods  # 6-chloro-tryptophan
        elif 'Brc2cccc2' in content:
            return 'BTR', mods  # 6-bromo-tryptophan
        elif 'COc2cccc2' in content:
            return 'MOT5', mods  # 5-Methoxy-tryptophan
        elif 'Cc2cccc2' in content:
            return 'MTR5', mods  # 5-Methyl-tryptophan

    # Special amino acids
    if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
        return 'BUG', mods  # Tertleucine

    if 'CCCNC(=N)N' in content:
        return 'CIR', mods  # Citrulline

    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine

    if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
        return 'DAB', mods  # Diaminobutyric acid

    if 'C1CCCCC1' in content:
        if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
            return 'CHG', mods  # Cyclohexylglycine
        elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine

    # Naphthalene derivatives
    if 'c1cccc2c1cccc2' in content:
        if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
            return 'NAL', mods  # 2-Naphthyl-alanine

    # Heteroaromatic derivatives
    if 'c1cncc' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1cscc' in content:
        return 'THA3', mods  # 3-(3-thienyl)-alanine
    if 'c1nnc' in content:
        return 'TRZ4', mods  # 3-(1,2,4-Triazol-1-yl)-alanine

    # Modified serines and threonines
    if 'OP(O)(O)O' in content:
        if '[C@@H](COP' in content or '[C@H](COP' in content:
            return 'SEP', mods  # phosphoserine
        elif '[C@@H](OP' in content or '[C@H](OP' in content:
            return 'TPO', mods  # phosphothreonine

    # Specialized ring systems
    if 'c1c2ccccc2cc2c1cccc2' in content:
        return 'ANTH', mods  # 3-(9-anthryl)-alanine
    if 'c1csc2c1cccc2' in content:
        return 'BTH3', mods  # 3-(3-benzothienyl)-alanine
    if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
        return 'ADAM', mods  # Adamanthane

    # Fluorinated derivatives
    if 'FC(F)(F)' in content:
        if 'CC(F)(F)F' in content:
            return 'FLA', mods  # Trifluoro-alanine
    if 'C(F)(F)F)c1' in content:
        if 'c1ccccc1C(F)(F)F' in content:
            return 'TFG2', mods  # 2-(Trifluoromethyl)-phenylglycine
        if 'c1cccc(c1)C(F)(F)F' in content:
            return 'TFG3', mods  # 3-(Trifluoromethyl)-phenylglycine
        if 'c1ccc(cc1)C(F)(F)F' in content:
            return 'TFG4', mods  # 4-(Trifluoromethyl)-phenylglycine

    # Multiple halogen patterns
    if 'F' in content and 'c1' in content:
        if 'c1ccc(c(c1)F)F' in content:
            return 'F2F', mods  # 3,4-Difluoro-phenylalanine
        if 'cc(F)cc(c1)F' in content:
            return 'WFP', mods  # 3,5-Difluoro-phenylalanine
    if 'Cl' in content and 'c1' in content:
        if 'c1ccc(cc1Cl)Cl' in content:
            return 'CP24', mods  # 2,4-dichloro-phenylalanine
        if 'c1ccc(c(c1)Cl)Cl' in content:
            return 'CP34', mods  # 3,4-dichloro-phenylalanine

    # Hydroxy and amino derivatives
    if 'O' in content and 'c1' in content:
        if 'c1cc(O)cc(c1)O' in content:
            return '3FG', mods  # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
        if 'c1ccc(c(c1)O)O' in content:
            return 'DAH', mods  # 3,4-Dihydroxy-phenylalanine

    # Cyclic amino acids
    if 'C1CCCC1' in content:
        return 'CPA3', mods  # 3-Cyclopentyl-alanine
    if 'C1CCCCC1' in content:
        if 'CC1CCCCC1' in content:
            return 'ALC', mods  # 3-cyclohexyl-alanine
        else:
            return 'CHG', mods  # Cyclohexylglycine

    # Chain-length variants
    if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
        return 'NLE', mods  # Norleucine
    if 'CC[C@@H]' in content or 'CC[C@H]' in content:
        if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
            return 'ABA', mods  # 2-Aminobutyric acid

    # Modified histidines
    if 'c1cnc' in content:
        if '[C@@H]1CN[C@@H](N1)F' in content:
            return '2HF', mods  # 2-fluoro-l-histidine
        if 'c1cnc([nH]1)F' in content:
            return '2HF1', mods  # 2-fluoro-l-histidine variant
        if 'c1c[nH]c(n1)F' in content:
            return '2HF2', mods  # 2-fluoro-l-histidine variant

    # Sulfur and selenium containing
    # NOTE(review): duplicate of the earlier '[SeH]' check above — this
    # branch can never fire.
    if '[SeH]' in content:
        return 'CSE', mods  # Selenocysteine
    if 'S' in content:
        if 'CSCc1ccccc1' in content:
            return 'BCS', mods  # benzylcysteine
        if 'CCSC' in content:
            return 'ESC', mods  # Ethionine
        if 'CCS' in content:
            return 'HCS', mods  # homocysteine

    # Additional modifications
    if 'CN=[N]=N' in content:
        return 'AZDA', mods  # azido-alanine
    if '[NH]=[C](=[NH2])=[NH2]' in content:
        if 'CCC[NH]=' in content:
            return 'AGM', mods  # 5-methyl-arginine
        if 'CC[NH]=' in content:
            return 'GDPR', mods  # 2-Amino-3-guanidinopropionic acid

    if 'CCON' in content:
        return 'CAN', mods  # canaline
    if '[C@@H]1C=C[C@@H](C=C1)' in content:
        return 'ACZ', mods  # cis-amiclenomycin
    if 'CCC(=O)[NH3]' in content:
        return 'ONL', mods  # 5-oxo-l-norleucine
    if 'c1ccncc1' in content:
        return 'PYR4', mods  # 3-(4-Pyridyl)-alanine
    if 'c1ccco1' in content:
        return 'FUA2', mods  # (2-furyl)-alanine

    if 'c1ccc' in content:
        if 'c1ccc(cc1)c1ccccc1' in content:
            return 'BIF', mods  # 4,4-biphenylalanine
        if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
            return 'PBF', mods  # 4-benzoyl-phenylalanine
        if 'c1ccc(cc1)C(C)(C)C' in content:
            return 'TBP4', mods  # 4-tert-butyl-phenylalanine
        if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
            return '0BN', mods  # 4-carbamimidoyl-l-phenylalanine
        if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
            return 'APM', mods  # m-amidinophenyl-3-alanine

    # Multiple hydroxy patterns
    if 'O' in content:
        if '[C@H]([C@H](C)O)O' in content:
            return 'ILX', mods  # 4,5-dihydroxy-isoleucine
        if '[C@H]([C@@H](C)O)O' in content:
            return 'ALO', mods  # Allo-threonine
        if '[C@H](COP(O)(O)O)' in content:
            return 'SEP', mods  # phosphoserine
        if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
            return 'TPO', mods  # phosphothreonine
        if '[C@H](c1ccc(O)cc1)O' in content:
            return 'OMX', mods  # (betar)-beta-hydroxy-l-tyrosine
        if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
            return 'OMY', mods  # (betar)-3-chloro-beta-hydroxy-l-tyrosine

    # Heterocyclic patterns
    if 'n1' in content:
        if 'n1cccn1' in content:
            return 'PYZ1', mods  # 3-(1-Pyrazolyl)-alanine
        if 'n1nncn1' in content:
            return 'TEZA', mods  # 3-(2-Tetrazolyl)-alanine
        if 'c2c(n1)cccc2' in content:
            return 'QU32', mods  # 3-(2-Quinolyl)-alanine
        if 'c1cnc2c(c1)cccc2' in content:
            return 'QU33', mods  # 3-(3-quinolyl)-alanine
        if 'c1ccnc2c1cccc2' in content:
            return 'QU34', mods  # 3-(4-quinolyl)-alanine
        if 'c1ccc2c(c1)nccc2' in content:
            return 'QU35', mods  # 3-(5-Quinolyl)-alanine
        if 'c1ccc2c(c1)cncc2' in content:
            return 'QU36', mods  # 3-(6-Quinolyl)-alanine
        if 'c1cnc2c(n1)cccc2' in content:
            return 'QX32', mods  # 3-(2-quinoxalyl)-alanine

    # Multiple nitrogen patterns
    if 'N' in content:
        if '[NH3]CC[C@@H]' in content:
            return 'DAB', mods  # Diaminobutyric acid
        if '[NH3]C[C@@H]' in content:
            return 'DPP', mods  # 2,3-Diaminopropanoic acid
        if '[NH3]CCCCCC[C@@H]' in content:
            return 'HHK', mods  # (2s)-2,8-diaminooctanoic acid
        if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
            return 'GBUT', mods  # 2-Amino-4-guanidinobutryric acid
        if '[NH]=[C](=S)=[NH2]' in content:
            return 'THIC', mods  # Thio-citrulline

    # Chain modified amino acids
    if 'CC' in content:
        if 'CCCC[C@@H]' in content:
            return 'AHP', mods  # 2-Aminoheptanoic acid
        if 'CCC([C@@H])(C)C' in content:
            return 'I2M', mods  # 3-methyl-l-alloisoleucine
        if 'CC[C@H]([C@@H])C' in content:
            return 'IIL', mods  # Allo-Isoleucine
        if '[C@H](CCC(C)C)' in content:
            return 'HLEU', mods  # Homoleucine
        if '[C@@H]([C@@H](C)O)C' in content:
            return 'HLU', mods  # beta-hydroxyleucine

    # Modified glutamate/aspartate patterns
    if '[C@@H]' in content:
        if '[C@@H](C[C@@H](F))' in content:
            return 'FGA4', mods  # 4-Fluoro-glutamic acid
        if '[C@@H](C[C@@H](O))' in content:
            return '3GL', mods  # 4-hydroxy-glutamic-acid
        if '[C@@H](C[C@H](C))' in content:
            return 'LME', mods  # (3r)-3-methyl-l-glutamic acid
        if '[C@@H](CC[C@H](C))' in content:
            return 'MEG', mods  # (3s)-3-methyl-l-glutamic acid

    # Sulfur and selenium modifications
    if 'S' in content:
        if 'SCC[C@@H]' in content:
            return 'HSER', mods  # homoserine
        if 'SCCN' in content:
            return 'SLZ', mods  # thialysine
        if 'SC(=O)' in content:
            return 'CSA', mods  # s-acetonylcysteine
        if '[S@@](=O)' in content:
            return 'SME', mods  # Methionine sulfoxide
        if 'S(=O)(=O)' in content:
            return 'OMT', mods  # Methionine sulfone

    # Double bond containing
    if 'C=' in content:
        if 'C=C[C@@H]' in content:
            return '2AG', mods  # 2-Allyl-glycine
        # NOTE(review): identical condition to the 2AG branch above —
        # this vinylglycine branch is unreachable as written.
        if 'C=C[C@@H]' in content:
            return 'LVG', mods  # vinylglycine
        if 'C=Cc1ccccc1' in content:
            return 'STYA', mods  # Styrylalanine

    # Special cases
    if '[C@@H]1Cc2c(C1)cccc2' in content:
        return 'IGL', mods  # alpha-amino-2-indanacetic acid
    if '[C](=[C](=O)=O)=O' in content:
        return '26P', mods  # 2-amino-6-oxopimelic acid
    if '[C](=[C](=O)=O)=C' in content:
        return '2NP', mods  # l-2-amino-6-methylene-pimelic acid
    if 'c2cnc[nH]2' in content:
        return 'HIS', mods  # histidine core
    if 'c1cccc2c1cc(O)cc2' in content:
        return 'NAO1', mods  # 5-hydroxy-1-naphthalene
    if 'c1ccc2c(c1)cc(O)cc2' in content:
        return 'NAO2', mods  # 6-hydroxy-2-naphthalene

    # Proline (P) - flexible ring numbers.
    # NOTE(review): split_on_bonds stores bond_before=None for the first
    # segment, so segment.get('bond_before', '') can still return None
    # here and .startswith would raise AttributeError; presumably first
    # segments never match these content equalities — TODO confirm.
    if any([
        # Check for any ring number in bond patterns
        (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
        any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Check ending patterns with any ring number
        (f'CCCN{n}' in content and content.endswith('=O') and
        any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
        for n in '123456789'
    ]) or any([
        # Handle CCC[C@H]n patterns
        (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
        # N-terminal Pro with any ring number
        (f'N{n}CCC[C@H]{n}' in content) or
        (f'N{n}CCC[C@@H]{n}' in content)
        for n in '123456789'
    ]):
        return 'Pro', mods

    # Tryptophan (W) - more specific indole pattern
    if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
       'c[nH]c' in content.replace(' ', ''):
        return 'Trp', mods

    # Lysine (K) - both patterns
    if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
        return 'Lys', mods

    # Arginine (R) - both patterns
    if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
        return 'Arg', mods

    if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
        return 'Nle', mods

    # Ornithine (Orn) - 3-carbon chain with NH2
    if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
        return 'Orn', mods

    # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
    if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '2Nal', mods

    # Cyclohexylalanine (Cha)
    # NOTE(review): the bare 'CCCCC2' alternative is very broad and may
    # match unrelated ring-2 chains — verify against known inputs.
    if 'N2CCCCC2' in content or 'CCCCC2' in content:
        return 'Cha', mods

    # Aminobutyric acid (Abu) - 2-carbon chain
    if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
        return 'Abu', mods

    # Pipecolic acid (Pip) - 6-membered ring like Pro
    if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Pip', mods

    # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
    if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
        return 'Chg', mods

    # 4-Fluorophenylalanine (4F-Phe)
    if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return '4F-Phe', mods

    # Regular residue identification
    if ('NCC(=O)' in content) or (content == 'C'):
        # Middle case - between bonds
        if segment.get('bond_before') and segment.get('bond_after'):
            if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
                return 'Gly', mods
        # Terminal case - at the end
        elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
            return 'Gly', mods

    if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
        return 'Leu', mods
    if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
        return 'Leu', mods

    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
        return 'Thr', mods

    if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
        return 'Phe', mods

    if ('[C@H](C(C)C)' in content or      # With outer parentheses
        '[C@@H](C(C)C)' in content or    # With outer parentheses
        '[C@H]C(C)C' in content or       # Without outer parentheses
        '[C@@H]C(C)C' in content):       # Without outer parentheses
        if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']):  # Still check not Leu
            return 'Val', mods

    if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
        return 'O-tBu', mods

    if any([
        'CC[C@H](C)' in content,
        'CC[C@@H](C)' in content,
        'C(C)C[C@H]' in content and 'CC(C)C' not in content,
        'C(C)C[C@@H]' in content and 'CC(C)C' not in content
    ]):
        return 'Ile', mods

    if ('[C@H](C)' in content or '[C@@H](C)' in content):
        if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
            return 'Ala', mods

    # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
    if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
        return 'Tyr', mods

    # Serine (Ser) - Hydroxymethyl side chain
    if '[C@H](CO)' in content or '[C@@H](CO)' in content:
        if not ('C(C)O' in content or 'COC' in content):
            return 'Ser', mods

    # Threonine (Thr) - 1-hydroxyethyl side chain
    if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
        return 'Thr', mods

    # Cysteine (Cys) - Thiol side chain
    if '[C@H](CS)' in content or '[C@@H](CS)' in content:
        return 'Cys', mods

    # Methionine (Met) - Methylthioethyl side chain
    if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
        return 'Met', mods

    # Asparagine (Asn) - Carbamoylmethyl side chain
    if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asn', mods

    # Glutamine (Gln) - Carbamoylethyl side chain
    if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Gln', mods

    # Aspartic acid (Asp) - Carboxymethyl side chain
    if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Asp', mods

    # Glutamic acid (Glu) - Carboxyethyl side chain
    if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Glu', mods

    # Arginine (Arg) - 3-guanidinopropyl side chain
    if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'Arg', mods

    # Histidine (His) - Imidazole side chain
    if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
        return 'His', mods

    return None, mods
|
| 647 |
+
|
| 648 |
+
def get_modifications(self, segment):
    """Derive backbone-modification labels from the bond following a segment.

    Returns a list that may contain 'N-Me' (N-methylated amide link) and/or
    'O-linked' (ester link); empty when the segment has no following bond or
    neither pattern appears.
    """
    bond = segment.get('bond_after')
    labels = []
    if bond:
        n_methylated = 'N(C)' in bond or bond.startswith('C(=O)N(C)')
        if n_methylated:
            labels.append('N-Me')
        if 'OC(=O)' in bond:
            labels.append('O-linked')
    return labels
|
| 657 |
+
|
| 658 |
+
def analyze_structure(self, smiles):
    """Analyze a peptide SMILES and print a per-segment debug report.

    Splits the SMILES on backbone bonds, identifies each segment's residue
    (with modification tags), and assembles the hyphen-joined three-letter
    sequence, wrapped in "cyclo(...)" when the structure is cyclic.

    Args:
        smiles: peptide SMILES string.

    Returns:
        (three_letter, n_segments): the assembled sequence string and the
        number of backbone segments found.
    """
    # Fix: removed unreachable dead code after the return (a triple-quoted
    # pseudo-return block) and commented-out print statements.
    print("\nAnalyzing structure:", smiles)

    # Split into segments
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for i, segment in enumerate(segments):
        print(f"\nSegment {i}:")
        print(f"Content: {segment['content']}")
        print(f"Bond before: {segment.get('bond_before', 'None')}")
        print(f"Bond after: {segment.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(segment)
        if residue:
            if mods:
                sequence.append(f"{residue}({','.join(mods)})")
            else:
                sequence.append(residue)
            print(f"Identified as: {residue}")
            print(f"Modifications: {mods}")
        else:
            print(f"Warning: Could not identify residue in segment: {segment['content']}")

    # Check if cyclic; the local name deliberately shadows the is_cyclic()
    # method for the remainder of this function.
    is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
    three_letter = '-'.join(sequence)
    one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)

    if is_cyclic:
        three_letter = f"cyclo({three_letter})"
        one_letter = f"cyclo({one_letter})"

    print(f"\nFinal sequence: {three_letter}")
    print(f"One-letter code: {one_letter}")
    print(f"Is cyclic: {is_cyclic}")

    return three_letter, len(segments)
|
| 705 |
+
|
| 706 |
+
def return_sequence(self, smiles):
    """Split a peptide SMILES into segments and return the residue list.

    Prints the same per-segment debug report as analyze_structure() but
    returns the raw list of residue labels (with modification tags) instead
    of a joined sequence string.
    """
    print("\nAnalyzing structure:", smiles)

    # Break the SMILES into backbone segments.
    segments = self.split_on_bonds(smiles)

    print("\nSegment Analysis:")
    sequence = []
    for idx, seg in enumerate(segments):
        print(f"\nSegment {idx}:")
        print(f"Content: {seg['content']}")
        print(f"Bond before: {seg.get('bond_before', 'None')}")
        print(f"Bond after: {seg.get('bond_after', 'None')}")

        residue, mods = self.identify_residue(seg)
        if residue:
            # Tag the residue with its modifications, if any.
            label = f"{residue}({','.join(mods)})" if mods else residue
            sequence.append(label)
            print(f"Identified as: {residue}")
            print(f"Modifications: {mods}")
        else:
            print(f"Warning: Could not identify residue in segment: {seg['content']}")

    return sequence
|
| 733 |
+
|
| 734 |
+
"""
|
| 735 |
+
def annotate_cyclic_structure(mol, sequence):
|
| 736 |
+
'''Create annotated 2D structure with clear, non-overlapping residue labels'''
|
| 737 |
+
# Generate 2D coordinates
|
| 738 |
+
# Generate 2D coordinates
|
| 739 |
+
AllChem.Compute2DCoords(mol)
|
| 740 |
+
|
| 741 |
+
# Create drawer with larger size for annotations
|
| 742 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
|
| 743 |
+
|
| 744 |
+
# Get residue list and reverse it to match structural representation
|
| 745 |
+
if sequence.startswith('cyclo('):
|
| 746 |
+
residues = sequence[6:-1].split('-')
|
| 747 |
+
else:
|
| 748 |
+
residues = sequence.split('-')
|
| 749 |
+
residues = list(reversed(residues)) # Reverse the sequence
|
| 750 |
+
|
| 751 |
+
# Draw molecule first to get its bounds
|
| 752 |
+
drawer.drawOptions().addAtomIndices = False
|
| 753 |
+
drawer.DrawMolecule(mol)
|
| 754 |
+
drawer.FinishDrawing()
|
| 755 |
+
|
| 756 |
+
# Convert to PIL Image
|
| 757 |
+
img = Image.open(BytesIO(drawer.GetDrawingText()))
|
| 758 |
+
draw = ImageDraw.Draw(img)
|
| 759 |
+
|
| 760 |
+
try:
|
| 761 |
+
# Try to use DejaVuSans as it's commonly available on Linux systems
|
| 762 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 763 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 764 |
+
except OSError:
|
| 765 |
+
try:
|
| 766 |
+
# Fallback to Arial if available (common on Windows)
|
| 767 |
+
font = ImageFont.truetype("arial.ttf", 60)
|
| 768 |
+
small_font = ImageFont.truetype("arial.ttf", 60)
|
| 769 |
+
except OSError:
|
| 770 |
+
# If no TrueType fonts are available, fall back to default
|
| 771 |
+
print("Warning: TrueType fonts not available, using default font")
|
| 772 |
+
font = ImageFont.load_default()
|
| 773 |
+
small_font = ImageFont.load_default()
|
| 774 |
+
# Get molecule bounds
|
| 775 |
+
conf = mol.GetConformer()
|
| 776 |
+
positions = []
|
| 777 |
+
for i in range(mol.GetNumAtoms()):
|
| 778 |
+
pos = conf.GetAtomPosition(i)
|
| 779 |
+
positions.append((pos.x, pos.y))
|
| 780 |
+
|
| 781 |
+
x_coords = [p[0] for p in positions]
|
| 782 |
+
y_coords = [p[1] for p in positions]
|
| 783 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
| 784 |
+
min_y, max_y = min(y_coords), max(y_coords)
|
| 785 |
+
|
| 786 |
+
# Calculate scaling factors
|
| 787 |
+
scale = 150 # Increased scale factor
|
| 788 |
+
center_x = 1000 # Image center
|
| 789 |
+
center_y = 1000
|
| 790 |
+
|
| 791 |
+
# Add residue labels in a circular arrangement around the structure
|
| 792 |
+
n_residues = len(residues)
|
| 793 |
+
radius = 700 # Distance of labels from center
|
| 794 |
+
|
| 795 |
+
# Start from the rightmost point (3 o'clock position) and go counterclockwise
|
| 796 |
+
# Offset by -3 positions to align with structure
|
| 797 |
+
offset = 0 # Adjust this value to match the structure alignment
|
| 798 |
+
for i, residue in enumerate(residues):
|
| 799 |
+
# Calculate position in a circle around the structure
|
| 800 |
+
# Start from 0 (3 o'clock) and go counterclockwise
|
| 801 |
+
angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
|
| 802 |
+
|
| 803 |
+
# Calculate label position
|
| 804 |
+
label_x = center_x + radius * np.cos(angle)
|
| 805 |
+
label_y = center_y + radius * np.sin(angle)
|
| 806 |
+
|
| 807 |
+
# Draw residue label
|
| 808 |
+
text = f"{i+1}. {residue}"
|
| 809 |
+
bbox = draw.textbbox((label_x, label_y), text, font=font)
|
| 810 |
+
padding = 10
|
| 811 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 812 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 813 |
+
fill='white', outline='white')
|
| 814 |
+
draw.text((label_x, label_y), text,
|
| 815 |
+
font=font, fill='black', anchor="mm")
|
| 816 |
+
|
| 817 |
+
# Add sequence at the top with white background
|
| 818 |
+
seq_text = f"Sequence: {sequence}"
|
| 819 |
+
bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
|
| 820 |
+
padding = 10
|
| 821 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 822 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 823 |
+
fill='white', outline='white')
|
| 824 |
+
draw.text((center_x, 100), seq_text,
|
| 825 |
+
font=small_font, fill='black', anchor="mm")
|
| 826 |
+
|
| 827 |
+
return img
|
| 828 |
+
|
| 829 |
+
"""
|
| 830 |
+
def annotate_cyclic_structure(mol, sequence):
    """Render the molecule in 2D and stamp the sequence string across the top."""
    # Generate 2D coordinates for drawing.
    AllChem.Compute2DCoords(mol)

    # Oversized canvas leaves room for the header annotation.
    canvas = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
    canvas.drawOptions().addAtomIndices = False
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()

    # Convert the rendered PNG bytes into an editable PIL image.
    img = Image.open(BytesIO(canvas.GetDrawingText()))
    overlay = ImageDraw.Draw(img)

    # Prefer DejaVu (Linux), then Arial (Windows), else PIL's bitmap fallback.
    header_font = None
    for font_path in ("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", "arial.ttf"):
        try:
            header_font = ImageFont.truetype(font_path, 60)
            break
        except OSError:
            continue
    if header_font is None:
        print("Warning: TrueType fonts not available, using default font")
        header_font = ImageFont.load_default()

    # White box behind the text keeps it legible over the structure.
    seq_text = f"Sequence: {sequence}"
    bbox = overlay.textbbox((1000, 100), seq_text, font=header_font)
    padding = 10
    overlay.rectangle([bbox[0] - padding, bbox[1] - padding,
                       bbox[2] + padding, bbox[3] + padding],
                      fill='white', outline='white')
    overlay.text((1000, 100), seq_text,
                 font=header_font, fill='black', anchor="mm")

    return img
|
| 866 |
+
|
| 867 |
+
def create_enhanced_linear_viz(sequence, smiles):
    """Create an enhanced linear representation using PeptideAnalyzer.

    Builds a two-panel matplotlib figure: the top panel draws each residue
    as a box connected by bond lines (styled by bond type), and the bottom
    panel lists a textual per-segment breakdown.

    Args:
        sequence: Three-letter sequence string, optionally wrapped in
            ``cyclo(...)``.
        smiles: SMILES string for the same peptide, segmented via
            ``PeptideAnalyzer.split_on_bonds``.

    Returns:
        The assembled matplotlib Figure (caller is responsible for closing).
    """
    analyzer = PeptideAnalyzer()  # Create analyzer instance

    # Create figure with two subplots (detail panel gets twice the height)
    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
    ax_struct = fig.add_subplot(gs[0])
    ax_detail = fig.add_subplot(gs[1])

    # Parse sequence and get residues; strip the "cyclo(" prefix and the
    # trailing ")" when present.
    if sequence.startswith('cyclo('):
        residues = sequence[6:-1].split('-')
    else:
        residues = sequence.split('-')

    # Get segments using analyzer
    # NOTE(review): segments are assumed to be dicts with 'content' and an
    # optional 'bond_after' key -- confirm against PeptideAnalyzer.
    segments = analyzer.split_on_bonds(smiles)

    # Debug print
    print(f"Number of residues: {len(residues)}")
    print(f"Number of segments: {len(segments)}")

    # Top subplot - Basic structure
    ax_struct.set_xlim(0, 10)
    ax_struct.set_ylim(0, 2)

    num_residues = len(residues)
    # Spread residues evenly across x in [0.5, 9.5]; a lone residue sits at 0.5.
    spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0

    # Draw basic structure
    y_pos = 1.5
    for i in range(num_residues):
        x_pos = 0.5 + i * spacing

        # Draw amino acid box
        rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
                                 facecolor='lightblue', edgecolor='black')
        ax_struct.add_patch(rect)

        # Draw connecting bonds if not the last residue
        if i < num_residues - 1:
            # Segments and residues may not align 1:1; guard the index.
            segment = segments[i] if i < len(segments) else None
            if segment:
                # Determine bond type from segment info: ester bonds are
                # drawn dashed red, peptide bonds solid black.
                bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
                is_n_methylated = 'N-Me' in segment.get('bond_after', '')

                bond_color = 'red' if bond_type == 'ester' else 'black'
                linestyle = '--' if bond_type == 'ester' else '-'

                # Draw bond line between the edges of adjacent boxes
                ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
                               color=bond_color, linestyle=linestyle, linewidth=2)

                # Add bond type label above the midpoint of the bond
                mid_x = x_pos + spacing/2
                bond_label = f"{bond_type}"
                if is_n_methylated:
                    bond_label += "\n(N-Me)"
                ax_struct.text(mid_x, y_pos+0.1, bond_label,
                               ha='center', va='bottom', fontsize=10,
                               color=bond_color)

        # Add residue label below its box
        ax_struct.text(x_pos, y_pos-0.5, residues[i],
                       ha='center', va='top', fontsize=14)

    # Bottom subplot - Detailed breakdown (one text row per segment)
    ax_detail.set_ylim(0, len(segments)+1)
    ax_detail.set_xlim(0, 1)

    # Create detailed breakdown
    segment_y = len(segments)  # Start from top
    for i, segment in enumerate(segments):
        y = segment_y - i

        # Check if this is a bond or residue: identify_residue returning a
        # falsy residue means the segment is treated as a bond.
        residue, mods = analyzer.identify_residue(segment)
        if residue:
            text = f"Residue {i+1}: {residue}"
            if mods:
                text += f" ({', '.join(mods)})"
            color = 'blue'
        else:
            # Must be a bond
            text = f"Bond {i}: "
            if 'O-linked' in segment.get('bond_after', ''):
                text += "ester"
            elif 'N-Me' in segment.get('bond_after', ''):
                text += "peptide (N-methylated)"
            else:
                text += "peptide"
            color = 'red'

        # Add segment analysis: label on the left, raw SMILES on the right
        ax_detail.text(0.05, y, text, fontsize=12, color=color)
        ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')

    # If cyclic, add connection indicator spanning the whole residue row
    if sequence.startswith('cyclo('):
        ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
                           arrowprops=dict(arrowstyle='<->', color='red', lw=2))
        ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
                       ha='center', color='red', fontsize=14)

    # Add titles and adjust layout
    ax_struct.set_title("Peptide Structure Overview", pad=20)
    ax_detail.set_title("Segment Analysis Breakdown", pad=20)

    # Remove axes
    for ax in [ax_struct, ax_detail]:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')

    plt.tight_layout()
    return fig
|
| 985 |
+
|
| 986 |
+
class PeptideStructureGenerator:
    """A class to generate 3D structures of peptides using different embedding methods."""

    @staticmethod
    def prepare_molecule(smiles):
        """Prepare molecule with proper hydrogen handling.

        Parses without immediate sanitization, refreshes valence caches, then
        sanitizes with a reduced op set (full valence checks often fail for
        macrocyclic / modified peptides) and adds explicit hydrogens.

        Raises:
            ValueError: if RDKit cannot parse the SMILES at all.
        """
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is None:
            raise ValueError("Failed to create molecule from SMILES")

        # Calculate valence for each atom before the partial sanitization.
        for atom in mol.GetAtoms():
            atom.UpdatePropertyCache(strict=False)

        # Sanitize with reduced requirements
        Chem.SanitizeMol(mol,
                         sanitizeOps=Chem.SANITIZE_FINDRADICALS |
                         Chem.SANITIZE_KEKULIZE |
                         Chem.SANITIZE_SETAROMATICITY |
                         Chem.SANITIZE_SETCONJUGATION |
                         Chem.SANITIZE_SETHYBRIDIZATION |
                         Chem.SANITIZE_CLEANUPCHIRALITY)

        return Chem.AddHs(mol)

    @staticmethod
    def get_etkdg_params(attempt=0):
        """Get ETKDG parameters with optional modifications based on attempt number."""
        params = AllChem.ETKDGv3()
        params.randomSeed = -1
        params.maxIterations = 200
        params.numThreads = 4  # Reduced for web interface
        params.useBasicKnowledge = True
        params.enforceChirality = True
        params.useExpTorsionAnglePrefs = True
        params.useSmallRingTorsions = True
        params.useMacrocycleTorsions = True
        params.ETversion = 2
        params.pruneRmsThresh = -1
        params.embedRmsThresh = 0.5

        if attempt > 10:
            # After repeated failures, progressively stretch bonds and drop
            # torsion preferences to give the embedder more freedom.
            params.bondLength = 1.5 + (attempt - 10) * 0.02
            params.useExpTorsionAnglePrefs = False

        return params

    def generate_structure_etkdg(self, smiles, max_attempts=20):
        """Generate 3D structure using ETKDG without UFF optimization.

        Retries up to ``max_attempts`` times with progressively relaxed
        parameters; returns the first successfully embedded molecule.

        Raises:
            ValueError: if no attempt yields a valid embedding.
        """
        # Return on the first successful embedding instead of tracking a
        # success flag (EmbedMolecule returns 0 on success).
        for attempt in range(max_attempts):
            try:
                mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)
                if AllChem.EmbedMolecule(mol, params) == 0:
                    return mol
            except Exception:
                continue

        raise ValueError("Failed to generate structure with ETKDG")

    def generate_structure_uff(self, smiles, max_attempts=20):
        """Generate 3D structure using ETKDG followed by UFF optimization.

        Keeps the lowest-UFF-energy conformer over all attempts.

        Raises:
            ValueError: if no attempt produces an optimized structure.
        """
        best_mol = None
        lowest_energy = float('inf')

        for attempt in range(max_attempts):
            try:
                test_mol = self.prepare_molecule(smiles)
                params = self.get_etkdg_params(attempt)

                if AllChem.EmbedMolecule(test_mol, params) == 0:
                    res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
                                                      vdwThresh=10.0, confId=0,
                                                      ignoreInterfragInteractions=True)

                    if res == 0:  # 0 means the optimization converged
                        ff = AllChem.UFFGetMoleculeForceField(test_mol)
                        if ff:
                            current_energy = ff.CalcEnergy()
                            if current_energy < lowest_energy:
                                lowest_energy = current_energy
                                # Copy: test_mol is rebuilt next iteration.
                                best_mol = Chem.Mol(test_mol)
            except Exception:
                continue

        if best_mol is None:
            raise ValueError("Failed to generate optimized structure")

        return best_mol

    @staticmethod
    def mol_to_sdf_bytes(mol):
        """Convert RDKit molecule to SDF file bytes."""
        # First write to StringIO in text mode
        sio = StringIO()
        writer = Chem.SDWriter(sio)
        writer.write(mol)
        writer.close()

        # Convert the string to bytes
        return sio.getvalue().encode('utf-8')
|
| 1096 |
+
|
| 1097 |
+
def process_input(smiles_input=None, file_obj=None, show_linear=False,
                  show_segment_details=False, generate_3d=False, use_uff=False):
    """Process input and create visualizations using PeptideAnalyzer.

    Args:
        smiles_input: Raw SMILES string for a single peptide.
        file_obj: File-like object (object with ``.name``, bytes, or str)
            whose lines are SMILES strings to analyze in batch.
        show_linear: Also build the enhanced linear visualization.
        show_segment_details: Include per-segment analysis in the text output.
        generate_3d: Generate 3D SDF structure files (ETKDG).
        use_uff: Additionally produce a UFF-optimized structure.

    Returns:
        tuple: ``(text, cyclic_image, linear_image, structure_files)``.
        BUGFIX: every exit now returns this uniform 4-tuple.  Previously the
        peptide-check / invalid-SMILES errors, the whole file branch, and the
        no-input fallthrough returned only 3 values while the success path and
        3D-error path returned 4, so callers unpacking four outputs crashed
        on those paths.
    """
    analyzer = PeptideAnalyzer()
    temp_dir = tempfile.mkdtemp() if generate_3d else None
    structure_files = []

    # Handle direct SMILES input
    if smiles_input:
        smiles = smiles_input.strip()

        # First check if it's a peptide using analyzer's method
        if not analyzer.is_peptide(smiles):
            return "Error: Input SMILES does not appear to be a peptide structure.", None, None, None

        try:
            # Create molecule
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return "Error: Invalid SMILES notation.", None, None, None

            # Generate 3D structures if requested
            if generate_3d:
                generator = PeptideStructureGenerator()

                try:
                    # Generate ETKDG structure
                    mol_etkdg = generator.generate_structure_etkdg(smiles)
                    etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
                    writer = Chem.SDWriter(etkdg_path)
                    writer.write(mol_etkdg)
                    writer.close()
                    structure_files.append(etkdg_path)

                    # Generate UFF structure if requested
                    if use_uff:
                        mol_uff = generator.generate_structure_uff(smiles)
                        uff_path = os.path.join(temp_dir, "structure_uff.sdf")
                        writer = Chem.SDWriter(uff_path)
                        writer.write(mol_uff)
                        writer.close()
                        structure_files.append(uff_path)

                except Exception as e:
                    return f"Error generating 3D structures: {str(e)}", None, None, None

            # Use analyzer to get sequence
            segments = analyzer.split_on_bonds(smiles)

            # Process segments and build sequence
            sequence_parts = []
            output_text = ""

            # Only include segment analysis in output if requested
            if show_segment_details:
                output_text += "Segment Analysis:\n"
                for i, segment in enumerate(segments):
                    output_text += f"\nSegment {i}:\n"
                    output_text += f"Content: {segment['content']}\n"
                    output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                    output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"

                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)
                        output_text += f"Identified as: {residue}\n"
                        output_text += f"Modifications: {mods}\n"
                    else:
                        output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
                    output_text += "\n"
            else:
                # Just build sequence without detailed analysis in output
                for segment in segments:
                    residue, mods = analyzer.identify_residue(segment)
                    if residue:
                        if mods:
                            sequence_parts.append(f"{residue}({','.join(mods)})")
                        else:
                            sequence_parts.append(residue)

            # Check if cyclic using analyzer's method
            is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
            three_letter = '-'.join(sequence_parts)
            one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)

            if is_cyclic:
                three_letter = f"cyclo({three_letter})"
                one_letter = f"cyclo({one_letter})"

            # Create cyclic structure visualization
            img_cyclic = annotate_cyclic_structure(mol, three_letter)

            # Create linear representation if requested
            img_linear = None
            if show_linear:
                fig_linear = create_enhanced_linear_viz(three_letter, smiles)
                buf = BytesIO()
                fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
                buf.seek(0)
                img_linear = Image.open(buf)
                plt.close(fig_linear)

            # Add summary to output
            summary = "Summary:\n"
            summary += f"Sequence: {three_letter}\n"
            summary += f"One-letter code: {one_letter}\n"
            summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"

            if structure_files:
                summary += "\n3D Structures Generated:\n"
                for filepath in structure_files:
                    summary += f"- {os.path.basename(filepath)}\n"

            return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None

        except Exception as e:
            return f"Error processing SMILES: {str(e)}", None, None, None

    # Handle file input
    if file_obj is not None:
        try:
            # Handle file content (path-bearing object, raw bytes, or str)
            if hasattr(file_obj, 'name'):
                with open(file_obj.name, 'r') as f:
                    content = f.read()
            else:
                content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)

            output_text = ""
            for line in content.splitlines():
                smiles = line.strip()
                if smiles:
                    # Check if it's a peptide
                    if not analyzer.is_peptide(smiles):
                        output_text += f"Skipping non-peptide SMILES: {smiles}\n"
                        continue

                    # Process this SMILES
                    segments = analyzer.split_on_bonds(smiles)
                    sequence_parts = []

                    # Add segment details if requested
                    if show_segment_details:
                        output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
                        for i, segment in enumerate(segments):
                            output_text += f"\nSegment {i}:\n"
                            output_text += f"Content: {segment['content']}\n"
                            output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
                            output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)
                                output_text += f"Identified as: {residue}\n"
                                output_text += f"Modifications: {mods}\n"
                    else:
                        for segment in segments:
                            residue, mods = analyzer.identify_residue(segment)
                            if residue:
                                if mods:
                                    sequence_parts.append(f"{residue}({','.join(mods)})")
                                else:
                                    sequence_parts.append(residue)

                    # Get cyclicity and create sequence
                    is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
                    sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)

                    output_text += f"\nSummary for SMILES: {smiles}\n"
                    output_text += f"Sequence: {sequence}\n"
                    output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
                    if is_cyclic:
                        output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
                    output_text += "-" * 50 + "\n"

            return output_text, None, None, None

        except Exception as e:
            return f"Error processing file: {str(e)}", None, None, None

    return "No input provided.", None, None, None
|
| 1287 |
+
|
utils/timer.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time, torch
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
class StepTimer:
    """Accumulate wall-clock timings for named code sections.

    When constructed with a CUDA device (torch.device or "cuda..." string),
    each section is bracketed by torch.cuda.synchronize() so GPU work is
    actually included in the measurement.
    """

    def __init__(self, device=None):
        self.times = defaultdict(list)
        self.device = device
        is_cuda_device = isinstance(device, torch.device) and device.type == "cuda"
        is_cuda_string = isinstance(device, str) and "cuda" in device
        self._use_cuda_sync = is_cuda_device or is_cuda_string

    @contextmanager
    def section(self, name):
        """Time the enclosed block and record the duration under *name*."""
        if self._use_cuda_sync:
            torch.cuda.synchronize()
        start = time.perf_counter()
        try:
            yield
        finally:
            if self._use_cuda_sync:
                torch.cuda.synchronize()
            self.times[name].append(time.perf_counter() - start)

    def summary(self, top_k=None):
        """Return rows of (name, count, total, mean, p50, p95), sorted by total time descending."""
        import numpy as np
        rows = []
        for name, samples in self.times.items():
            arr = np.array(samples, dtype=float)
            rows.append((name, len(arr), arr.sum(), arr.mean(),
                         np.median(arr), np.percentile(arr, 95)))
        rows.sort(key=lambda row: row[2], reverse=True)
        return rows[:top_k] if top_k else rows
|
utils/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Console logger utilities.
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
|
| 4 |
+
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import fsspec
|
| 9 |
+
import lightning
|
| 10 |
+
import torch
|
| 11 |
+
from timm.scheduler import CosineLRScheduler
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def sample_categorical_logits(logits, dtype=torch.float64):
    """Sample category indices from (possibly unnormalized) logits.

    Uses the Gumbel-max trick, so the logits do not need to be
    log-softmaxed first; the 1e-10 offsets guard against log(0).
    """
    uniform = torch.rand_like(logits, dtype=dtype)
    gumbel = -(1e-10 - (uniform + 1e-10).log()).log()
    perturbed = logits + gumbel
    return perturbed.argmax(dim=-1)
|
| 21 |
+
|
| 22 |
+
def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    filesystem, _ = fsspec.core.url_to_fs(filename)
    return filesystem.exists(filename)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def fsspec_listdir(dirname):
    """Listdir in manner compatible with fsspec."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    return filesystem.ls(dirname)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in manner compatible with fsspec."""
    filesystem, _ = fsspec.core.url_to_fs(dirname)
    filesystem.makedirs(dirname, exist_ok=exist_ok)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def print_nans(tensor, name):
    """Print *name* and *tensor* to stdout when the tensor contains any NaN."""
    nan_mask = torch.isnan(tensor)
    if nan_mask.any():
        print(name, tensor)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """Cosine LR schedule with warmup usable per-epoch or per-step."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance the schedule; an explicit *epoch* overrides the counter.

        timm's CosineLRScheduler distinguishes epoch-based step() from
        step-based step_update().  Lightning always calls step() (meant for
        epochs), so when the scheduler interval is "step" we must dispatch
        on t_in_epochs here or the learning-rate update would be wrong.
        """
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LoggingContext:
    """Context manager that temporarily adjusts a logger's level and/or handler.

    On exit the previous level is restored, the handler (if any) is removed,
    and optionally closed.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            # Remember the current level so __exit__ can restore it.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
            if self.close:
                self.handler.close()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Args:
        name: Logger name (defaults to this module's name).
        level: Logging level applied to the returned logger.

    Returns:
        logging.Logger whose level-methods only emit on rank zero.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # This ensures all logging levels get marked with the rank zero decorator,
    # otherwise logs would get multiplied for each GPU process in a multi-GPU
    # setup.  (Loop variable renamed from `level` — it shadowed the `level`
    # parameter above.)
    for method_name in ('debug', 'info', 'warning', 'error',
                        'exception', 'fatal', 'critical'):
        setattr(logger,
                method_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, method_name)))

    return logger
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def str2bool(v):
    """Parse an argparse value into a bool.

    Accepts actual bools unchanged, plus common truthy/falsy strings
    (case-insensitive); anything else raises ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def set_seed(seed, use_cuda):
    """Seed Python, NumPy, and torch RNGs (plus CUDA RNGs when requested)."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    print(f'=> Seed of the run set to {seed}')
|
| 135 |
+
|