Sophia Tang committed on
Commit 5e90249 · 1 Parent(s): 9aa9a1f

Initial commit

Files changed (40)
  1. .gitattributes +4 -0
  2. .gitignore +17 -0
  3. README.md +46 -0
  4. assets/anim-good.gif +3 -0
  5. assets/peptides.png +3 -0
  6. tr2d2-pep/README.md +41 -0
  7. tr2d2-pep/configs/peptune_config.yaml +159 -0
  8. tr2d2-pep/diffusion.py +1526 -0
  9. tr2d2-pep/finetune.py +133 -0
  10. tr2d2-pep/finetune.sh +37 -0
  11. tr2d2-pep/finetune_peptides.py +193 -0
  12. tr2d2-pep/finetune_utils.py +138 -0
  13. tr2d2-pep/generate_mcts.py +192 -0
  14. tr2d2-pep/metrics.py +71 -0
  15. tr2d2-pep/noise_schedule.py +150 -0
  16. tr2d2-pep/peptide_mcts.py +648 -0
  17. tr2d2-pep/plotting.py +148 -0
  18. tr2d2-pep/roformer.py +74 -0
  19. tr2d2-pep/run_mcts.sh +29 -0
  20. tr2d2-pep/scoring/functions/binding.py +178 -0
  21. tr2d2-pep/scoring/functions/binding_utils.py +290 -0
  22. tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json +0 -0
  23. tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json +0 -0
  24. tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json +3 -0
  25. tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json +0 -0
  26. tr2d2-pep/scoring/functions/hemolysis.py +63 -0
  27. tr2d2-pep/scoring/functions/nonfouling.py +66 -0
  28. tr2d2-pep/scoring/functions/permeability.py +171 -0
  29. tr2d2-pep/scoring/functions/scoring_utils.py +94 -0
  30. tr2d2-pep/scoring/functions/solubility.py +63 -0
  31. tr2d2-pep/scoring/scoring_functions.py +77 -0
  32. tr2d2-pep/scoring/tokenizer/my_tokenizers.py +424 -0
  33. tr2d2-pep/scoring/tokenizer/new_splits.txt +159 -0
  34. tr2d2-pep/scoring/tokenizer/new_vocab.txt +587 -0
  35. tr2d2-pep/tokenizer/my_tokenizers.py +424 -0
  36. tr2d2-pep/tokenizer/new_splits.txt +159 -0
  37. tr2d2-pep/tokenizer/new_vocab.txt +587 -0
  38. tr2d2-pep/utils/app.py +1255 -0
  39. tr2d2-pep/utils/timer.py +34 -0
  40. tr2d2-pep/utils/utils.py +135 -0
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
+ assets/peptides.png filter=lfs diff=lfs merge=lfs -text
+ assets/tr2d2-anim.gif filter=lfs diff=lfs merge=lfs -text
+ assets/anim-good.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
+ __pycache__/
+ *.pyc
+ .venv/
+ .env
+ .DS_Store
+ tr2d2-pep/wandb/
+ tr2d2-pep/pretrained/
+ tr2d2-pep/logs/
+ tr2d2-pep/__pycache__/
+ tr2d2-pep/scoring/__pycache__/
+ tr2d2-pep/tokenizer/__pycache__/
+ tr2d2-pep/utils/__pycache__/
+ *.pyc
+ *.log
+ *.ipynb
+ *.pt
+ tr2d2-pep/scoring/functions/classifiers/best_model.pt
README.md ADDED
@@ -0,0 +1,46 @@
+ # [TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion](https://arxiv.org/abs/2509.25171) 🤖🌳
+
+
+
+ [**Sophia Tang**](https://sophtang.github.io/)\*, [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/)\*, [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
+
+ ![TR2-D2](assets/anim-good.gif)
+
+ This is the repository for **[TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion](https://arxiv.org/abs/2509.25171)** 🤖🌳. It is partially built on the **[PepTune repo](https://github.com/programmablebio/peptune)** ([Tang et al. 2024](https://arxiv.org/abs/2412.17780)) and **MDNS** ([Zhu et al. 2025](https://arxiv.org/abs/2508.10684)).
+
+ Inspired by the incredible success of off-policy reinforcement learning (RL), **TR2-D2** introduces a general framework that enhances off-policy RL with tree search for discrete diffusion fine-tuning.
+
+ 🤖 Off-policy RL enables learning from diffusion trajectories generated by a non-gradient-tracking policy model, storing sampled trajectories in a replay buffer for repeated reuse.
+
+ 🌳 Tree search balances exploration and exploitation to generate optimal diffusion trajectories and stores the best sequences in the buffer.
+
+ We use this framework to develop an efficient discrete diffusion fine-tuning strategy that leverages **Monte-Carlo Tree Search (MCTS)** to curate a replay buffer of optimal trajectories, combined with an **off-policy control-based RL algorithm grounded in stochastic optimal control theory**, yielding theoretically guaranteed convergence to the optimal distribution. 🌟
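To make the framework concrete, the loop below is a minimal structural sketch (all names are illustrative, not the repo's API): tree search proposes complete denoising trajectories, the best ones are stored in a replay buffer, and the policy is then updated off-policy from buffered samples.

```python
# Minimal structural sketch of the TR2-D2 loop; every name here is hypothetical.
import random
from collections import deque

class ReplayBuffer:
    """Stores (trajectory, reward) pairs for repeated off-policy reuse."""
    def __init__(self, capacity=1024):
        self.data = deque(maxlen=capacity)

    def add(self, trajectory, reward):
        self.data.append((trajectory, reward))

    def sample(self, k):
        return random.sample(list(self.data), min(k, len(self.data)))

def mcts_search(policy):
    # stub: run tree search from the fully masked sequence and return
    # the highest-reward denoising trajectories found this round
    return [(["<mask> <mask>", "A <mask>", "A C"], 1.0)]

def off_policy_update(policy, batch):
    # stub: one gradient step of the control-based RL objective on buffered trajectories
    pass

policy, buffer = object(), ReplayBuffer()
for _ in range(10):                                  # fine-tuning rounds
    for traj, reward in mcts_search(policy):         # 🌳 exploration via tree search
        buffer.add(traj, reward)
    for _ in range(4):                               # 🤖 repeated off-policy reuse
        off_policy_update(policy, buffer.sample(8))
```

Because the buffer decouples trajectory generation from gradient updates, each expensive MCTS rollout can be reused across many policy updates.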
+
+ ### Regulatory DNA Sequence Design 🧬
+
+ ---
+
+ In this experiment, we fine-tune the **DNA enhancer MDM from DRAKES** (Wang et al. 2025), pre-trained on **~700k HepG2 sequences**, to optimize the measured enhancer activity using the reward oracles from DRAKES. Code and instructions to reproduce our results are provided in `/tr2d2-dna`.
+
+ ### Multi-Objective Therapeutic Peptide Design 🧫
+
+ ---
+
+ In this experiment, we fine-tune the pre-trained **unconditional peptide SMILES MDM from PepTune** ([Tang et al. 2024](https://arxiv.org/abs/2412.17780)) to optimize **multiple therapeutic properties**, including target protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability. We show that one-shot generation from the fine-tuned policy outperforms inference-time multi-objective guidance, marking a significant advance over prior fine-tuning methods. Code and instructions to reproduce our results are provided in `/tr2d2-pep`.
+
+ ![TR2-D2 for Multi-Objective Peptide Design](assets/peptides.png)
+
+ ## Citation
+
+ If you find this repository helpful for your publications, please consider citing our paper:
+
+ ```bibtex
+ @article{tang2025tr2d2,
+ title={TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion},
+ author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee},
+ journal={arXiv preprint arXiv:2509.25171},
+ year={2025}
+ }
+ ```
+
+ To use this repository, you agree to abide by the [PepTune License](https://drive.google.com/file/d/1Hsu91wTmxyoJLNJzfPDw5_nTbxVySP5x/view?usp=sharing).
assets/anim-good.gif ADDED

Git LFS Details

  • SHA256: 3e75f6abc926c1b5817d7b8c56485ea3492b355851ce0d70e44d57b57ce1aaf8
  • Pointer size: 132 Bytes
  • Size of remote file: 3.83 MB
assets/peptides.png ADDED

Git LFS Details

  • SHA256: 19c242f4b9842bcc37c2869b5f03bc56681c0afdc6312db07c22ba001bb89206
  • Pointer size: 132 Bytes
  • Size of remote file: 4.16 MB
tr2d2-pep/README.md ADDED
@@ -0,0 +1,41 @@
+ # TR2-D2 for Multi-Objective Therapeutic Peptide Design 🧫
+
+ This part of the codebase fine-tunes a peptide MDM with TR2-D2 to optimize multiple therapeutic properties, including binding affinity to a protein target, solubility, non-hemolysis, non-fouling, and cell membrane permeability.
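During fine-tuning, each candidate peptide receives one score per objective, and the score vector is summed into a single scalar reward (this is the `scalar_rewards = np.sum(score_vectors, axis=-1)` step in `diffusion.py`). A toy illustration with made-up scores:

```python
# Toy illustration of the multi-objective scalarization used during fine-tuning:
# each sequence gets one score per objective, summed into a single reward.
import numpy as np

# hypothetical scores for 2 peptides across the 5 objectives
# [binding, solubility, non-hemolysis, non-fouling, permeability]
score_vectors = np.array([
    [0.8, 0.6, 0.9, 0.7, 0.4],
    [0.3, 0.9, 0.8, 0.6, 0.7],
])
scalar_rewards = score_vectors.sum(axis=-1)   # shape (2,) -> [3.4, 3.3]
print(scalar_rewards)
```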
+
+ The codebase is built upon [PepTune (Tang et al., 2024)](https://arxiv.org/abs/2412.17780), [MDLM (Sahoo et al., 2023)](https://github.com/kuleshov-group/mdlm), [SEPO (Zekri et al., 2025)](https://github.com/ozekri/SEPO/tree/main), and [MDNS (Zhu et al., 2025)](https://arxiv.org/abs/2508.10684).
+
+ ## Environment Installation
+ ```
+ conda env create -f environment.yml
+
+ conda activate tr2d2-pep
+ ```
+
+ ## Pretrained Model Weights Download
+
+ Follow the steps below to download the model weights required for this experiment, which originally come from [PepTune](https://arxiv.org/abs/2412.17780).
+
+ 1. Download the PepTune pre-trained MDLM and place it in `/TR2-D2/peptides/pretrained/`: https://drive.google.com/file/d/1oXGDpKLNF0KX0ZdOcl1NZj5Czk2lSFUn/view?usp=sharing
+ 2. Download the pre-trained binding affinity Transformer model and place it in `/TR2-D2/tr2d2-pep/scoring/functions/classifiers/`: https://drive.google.com/file/d/128shlEP_-rYAxPgZRCk_n0HBWVbOYSva/view?usp=sharing
+
+
+
+ ## Finetune with TR2-D2
+ After downloading the pretrained checkpoints, follow the steps below to run fine-tuning:
+
+ 1. Fill in the `base_path` in `scoring/scoring_functions.py` and `diffusion.py`.
+ 2. In `finetune.sh`, set `HOME_LOC` to the base path where `TR2-D2` is located and `ENV_PATH` to the directory where your environment is installed.
+ 3. Create `tr2d2-pep/results` (fine-tuning curves and generation results), `tr2d2-pep/checkpoints` (checkpoint saving), and `tr2d2-pep/logs` (training logs).
+ 4. To specify a target protein, set `--prot_seq <insert amino acid sequence>` and `--prot_name <insert protein name>`. The default protein is the transferrin receptor (TfR).
+
+ Run fine-tuning using `nohup` with the following commands:
+ ```
+ chmod +x finetune.sh
+
+ nohup ./finetune.sh > finetune.log 2>&1 &
+ ```
+ Evaluation runs automatically after the specified number of fine-tuning epochs (`--num_epochs`) is complete. To summarize metrics, fill in `path` and `prot_name` in `metrics.py` and run:
+
+ ```
+ python metrics.py
+ ```
tr2d2-pep/configs/peptune_config.yaml ADDED
@@ -0,0 +1,159 @@
1
+ noise:
2
+ type: loglinear
3
+ sigma_min: 1e-4
4
+ sigma_max: 20
5
+ state_dependent: True
6
+
7
+ mode: ppl_eval # train / ppl_eval / sample_eval
8
+ diffusion: absorbing_state
9
+ vocab: old_smiles # old_smiles / new_smiles / selfies / helm
10
+ backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
11
+ parameterization: subs # subs
12
+ time_conditioning: False
13
+ T: 0 # 0 (continuous time) / 1000
14
+ subs_masking: False
15
+
16
+ seed: 42
17
+
18
+ mcts:
19
+ num_children: 50
20
+ num_objectives: 5
21
+ topk: 100
22
+ mask_token: 4
23
+ num_iter: 128
24
+ sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
25
+ invalid_penalty: 0.5
26
+ sample_prob: 1.0
27
+ perm: True
28
+ dual: False
29
+ single: False
30
+ time_dependent: True
31
+
32
+ lr_scheduler:
33
+ _target_: transformers.get_constant_schedule_with_warmup
34
+ num_warmup_steps: 2500
35
+
36
+ data:
37
+ train: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-train.csv
38
+ valid: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-val.csv
39
+ batching: wrapping # padding / wrapping
40
+
41
+ loader:
42
+ global_batch_size: 64
43
+ eval_global_batch_size: ${.global_batch_size}
44
+ # Note: batch_size and eval_batch_size are **per machine**
45
+ batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
46
+ eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
47
+ num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
48
+ pin_memory: True
49
+
50
+ sampling:
51
+ predictor: ddpm_cache # analytic, ddpm, ddpm_cache
52
+ num_sequences: 100
53
+ sampling_eps: 1e-3
54
+ steps: 128
55
+ seq_length: 100
56
+ noise_removal: True
57
+ num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
58
+ num_sample_log: 2
59
+ stride_length: 1
60
+ num_strides: 1
61
+
62
+ training:
63
+ antithetic_sampling: True
64
+ sampling_eps: 1e-3
65
+ focus_mask: False
66
+ #dynamic_batching: True
67
+ accumulator: False
68
+
69
+ eval:
70
+ checkpoint_path:
71
+ disable_ema: False
72
+ compute_generative_perplexity: False
73
+ perplexity_batch_size: 8
74
+ compute_perplexity_on_sanity: False
75
+ gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
76
+ generate_samples: True
77
+ generation_model:
78
+
79
+ optim:
80
+ weight_decay: 0.075
81
+ lr: 3e-4
82
+ beta1: 0.9
83
+ beta2: 0.999
84
+ eps: 1e-8
85
+
86
+ pepclm:
87
+ hidden_size: 768
88
+ cond_dim: 256
89
+ n_heads: 20
90
+ n_blocks: 4
91
+ dropout: 0.5
92
+ length: 512
93
+ #scale_by_sigma: True
94
+
95
+ model:
96
+ type: ddit
97
+ hidden_size: 768
98
+ cond_dim: 128
99
+ length: 512
100
+ n_blocks: 12
101
+ n_heads: 12
102
+ scale_by_sigma: True
103
+ dropout: 0.1
104
+
105
+ roformer:
106
+ hidden_size: 768
107
+ n_layers: 8
108
+ n_heads: 8
109
+ max_position_embeddings: 1035
110
+
111
+ helmgpt:
112
+ hidden_size: 256
113
+ embd_pdrop: 0.1
114
+ resid_pdrop: 0.1
115
+ attn_pdrop: 0.1
116
+ ff_dropout: 0.
117
+ block_size: 140
118
+ n_layer: 8
119
+ n_heads: 8
120
+
121
+
122
+ trainer:
123
+ _target_: lightning.Trainer
124
+ accelerator: cuda
125
+ num_nodes: 1
126
+ devices: ${device_count:}
127
+ accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
128
+ gradient_clip_val: 1.0
129
+ precision: 64-true
130
+ num_sanity_val_steps: 2
131
+ max_epochs: 100
132
+ max_steps: 1_000_000
133
+ log_every_n_steps: 10
134
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
135
+ limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
136
+ #val_check_interval: 40 #954
137
+ check_val_every_n_epoch: 1
138
+
139
+ hydra:
140
+ run:
141
+ dir: ./${now:%Y.%m.%d}/
142
+ job:
143
+ chdir: True
144
+
145
+ checkpointing:
146
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
147
+ save_dir: ${cwd:}
148
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
149
+ resume_from_ckpt: True
150
+ resume_ckpt_path:
151
+
152
+ callbacks:
153
+ model_checkpoint:
154
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
155
+ every_n_epochs: 1
156
+ monitor: "val/nll"
157
+ save_top_k: 10
158
+ mode: "min"
159
+ dirpath:
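As a side note, the `loader` and `trainer` entries above derive the per-device batch size and gradient accumulation from the global batch size; assuming `div_up` is ceiling division (as the name suggests), the arithmetic is roughly:

```python
# Rough equivalent of the config's div_up resolvers (assumption: div_up = ceiling division).
import math

global_batch_size = 64
devices, num_nodes = 4, 1                                                        # example hardware
batch_size = math.ceil(global_batch_size / (devices * num_nodes))               # loader.batch_size -> 16
accumulate = math.ceil(global_batch_size / (devices * batch_size * num_nodes))  # trainer.accumulate_grad_batches -> 1
print(batch_size, accumulate)
```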
tr2d2-pep/diffusion.py ADDED
@@ -0,0 +1,1526 @@
1
+ import numpy as np
2
+ import sys
3
+ import itertools
4
+ import time
5
+ import torch
6
+ from torch import Tensor
7
+ import math
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import random as rd
11
+ import lightning as L
12
+ import torchmetrics
13
+ from dataclasses import dataclass
14
+ import gc
15
+ import utils.utils as utils
16
+
17
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
18
+ import noise_schedule
19
+ from torch.optim.lr_scheduler import _LRScheduler
20
+ import roformer as roformer
21
+ from utils.app import PeptideAnalyzer
22
+ import pandas as pd
23
+
24
+ base_path = '/path/to/your/home'
25
+
26
+ def _sample_categorical(categorical_probs):
27
+ gumbel_norm = (
28
+ 1e-10
29
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
30
+ return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
31
+
32
+ def _sample_categorical_gradient(categorical_probs, temp = 1.0):
33
+ gumbel_norm = (
34
+ 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
35
+ output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
36
+ return output
37
+
38
+ def _unsqueeze(x, reference):
39
+ return x.view(
40
+ * x.shape,
41
+ * ((1,) * (len(reference.shape) - len(x.shape))))
42
+
43
+ def sample_batched_categorical(categorical_probs, batch_size):
44
+ """
45
+ Generates `m` distinct sequences sampled from categorical probabilities
46
+ using the Gumbel distribution to ensure randomness while following probabilities
47
+
48
+ Args:
49
+ categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
50
+ representing categorical probabilities
51
+ m (int): number of distinct sequences to sample
52
+
53
+ Returns:
54
+ torch.Tensor: tensor of shape (m, sequence_length), where each row is a
55
+ distinct sequence of sampled category indices.
56
+ """
57
+ _, sequence_length, vocab_size = categorical_probs.shape
58
+
59
+ # add Gumbel noise and sample m sequences
60
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
61
+ noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
62
+
63
+ # select the highest score (most likely category after Gumbel noise)
64
+ sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long) # shape: (m, sequence_length)
65
+
66
+ return sampled_sequences
67
+
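The Gumbel-based sampling in `_sample_categorical` and `sample_batched_categorical` is equivalent to drawing from the categorical distribution directly: adding Gumbel(0, 1) noise to the log-probabilities and taking the argmax reproduces the categorical probabilities. A quick illustrative check (not part of the module):

```python
# Illustrative check of the Gumbel-max trick: empirical frequencies of
# argmax(log p + Gumbel noise) approach the categorical probabilities p.
import torch

probs = torch.tensor([0.1, 0.2, 0.7])
gumbel = -torch.log(-torch.log(torch.rand(100_000, 3) + 1e-10) + 1e-10)
samples = (torch.log(probs) + gumbel).argmax(dim=-1)
print(torch.bincount(samples, minlength=3) / samples.numel())  # ~tensor([0.1, 0.2, 0.7])
```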
68
+ def sample_batched_top_k(categorical_probs, batch_size, k):
69
+ """
70
+ Generates `m` sequences sampled from the top-k probabilities of each token
71
+ using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
72
+
73
+ Args:
74
+ categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
75
+ representing categorical probabilities.
76
+ m (int): Number of sequences to sample.
77
+ k (int): Number of top probabilities to consider for sampling.
78
+
79
+ Returns:
80
+ torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
81
+ sampled sequence of category indices.
82
+ """
83
+ _, sequence_length, vocab_length = categorical_probs.shape
84
+
85
+ # Add Gumbel noise to the log probabilities
86
+ gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
87
+ noisy_scores = torch.log(categorical_probs[None, :, :]) + gumbel_noise # Shape: (m, sequence_length, vocab_length)
88
+
89
+ # Get the top-k categories based on noisy scores
90
+ top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k)
91
+
92
+ # Convert top-k scores back to probabilities and normalize
93
+ top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k)
94
+
95
+ # Sample randomly from the top-k probabilities
96
+ sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
97
+ sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
98
+
99
+ # Map sampled indices back to the original vocabulary indices
100
+ sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device).to(dtype=torch.long)
101
+
102
+ return sampled_sequences
103
+
104
+ @dataclass
105
+ class Loss:
106
+ loss: torch.FloatTensor
107
+ nlls: torch.FloatTensor
108
+ attn_mask: torch.FloatTensor
109
+
110
+
111
+ class NLL(torchmetrics.aggregation.MeanMetric):
112
+ pass
113
+
114
+
115
+ class BPD(NLL):
116
+ def compute(self) -> Tensor:
117
+ """Computes the bits per dimension.
118
+
119
+ Returns:
120
+ bpd
121
+ """
122
+ return self.mean_value / self.weight / math.log(2)
123
+
124
+
125
+ class Perplexity(NLL):
126
+ def compute(self) -> Tensor:
127
+ """Computes the Perplexity.
128
+
129
+ Returns:
130
+ Perplexity
131
+ """
132
+ return torch.exp(self.mean_value / self.weight)
133
+
134
+
135
+ class Diffusion(L.LightningModule):
136
+ def __init__(
137
+ self,
138
+ config,
139
+ tokenizer = None,
140
+ mode="finetune",
141
+ device=None,
142
+ ):
143
+
144
+ super().__init__()
145
+ self.config = config
146
+ #self.save_hyperparameters()
147
+
148
+ # PeptideCLM tokenizer
149
+ if tokenizer is None:
150
+ self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
151
+ f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')
152
+ else:
153
+ self.tokenizer = tokenizer
154
+
155
+ self.vocab_size = self.tokenizer.vocab_size
156
+ self.mask_index = self.tokenizer.mask_token_id
157
+ self.sampler = self.config.sampling.predictor
158
+ self.analyzer = PeptideAnalyzer()
159
+
160
+ # backbone LM PeptideCLM model
161
+ self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device)
162
+ if mode == "finetune":
163
+ self.backbone.freeze_model()
164
+ self.backbone.unfreeze_n_layers(n=8)
165
+ elif mode == "eval":
166
+ self.backbone.freeze_model()
167
+ self.backbone.requires_grad_(False)
168
+ self.backbone.eval()
169
+ elif mode == "train":
170
+ self.backbone.requires_grad_(True)
171
+ self.backbone.train()
172
+
173
+ self.neg_infinity = -1000000.0
174
+ self.T = config.T
175
+ # noise schedule for non-peptide bond tokens (default to log-linear)
176
+ self.noise = noise_schedule.get_noise(config)
177
+
178
+ # noise schedule for peptide bonds (log-polynomial)
179
+ self.bond_noise = noise_schedule.LogPolyNoise()
180
+ self.time_conditioning = self.config.time_conditioning
181
+ self.fast_forward_epochs = None
182
+ self.fast_forward_batches = None
183
+
184
+ self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
185
+ self.gen_ppl_metric = Perplexity()
186
+
187
+ self.lr = self.config.optim.lr
188
+ self.sampling_eps = self.config.training.sampling_eps
189
+
190
+ metrics = torchmetrics.MetricCollection({
191
+ 'nll': NLL(),
192
+ 'bpd': BPD(),
193
+ 'ppl': Perplexity(),
194
+ })
195
+ metrics.set_dtype(torch.float64)
196
+ self.train_metrics = metrics.clone(prefix='trainer/')
197
+ self.valid_metrics = metrics.clone(prefix='val/')
198
+ self.test_metrics = metrics.clone(prefix='test/')
199
+
200
+ ### FOR THE EXPANSION AND ROLLOUT STEP ###
201
+ def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
202
+ num_steps = args.total_num_steps
203
+ B = args.batch_size
204
+ x_rollout = self.sample_prior(
205
+ B, args.seq_length).to(self.device)
206
+
207
+ log_rnd = torch.zeros(args.batch_size, device=self.device)
208
+
209
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
210
+ dt = (1 - eps) / num_steps
211
+
212
+ for i in range(num_steps):
213
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
214
+
215
+ log_p, x_next, log_policy_step, log_pretrained_step = \
216
+ self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)
217
+
218
+ log_rnd += log_pretrained_step - log_policy_step
219
+
220
+ x_rollout = x_next
221
+
222
+ # if mask token remains, fully unmask
223
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
224
+
225
+ # does **any** mask remain in any sequence
226
+ any_mask_global = mask_positions.any().item() # true if mask remains
227
+ if any_mask_global:
228
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
229
+
230
+ x_rollout = x_next
231
+
232
+ childSequences = self.tokenizer.batch_decode(x_rollout)
233
+
234
+ # change rewards for peptides
235
+ valid_x_final = []
236
+ validSequences = []
237
+ valid_log_rnd = []
238
+
239
+ for i in range(B):
240
+ # string sequence
241
+ childSeq = childSequences[i]
242
+
243
+ # check if the peptide is valid
244
+ if self.analyzer.is_peptide(childSeq):
245
+ valid_x_final.append(x_rollout[i])
246
+ validSequences.append(childSeq)
247
+ valid_log_rnd.append(log_rnd[i])
248
+
249
+ # compute multi-objective rewards
250
+ score_vectors = reward_model(input_seqs=validSequences)
251
+ scalar_rewards = np.sum(score_vectors, axis=-1)
252
+ scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device)
253
+
254
+ print(f"scalar reward dim{len(scalar_rewards)}")
255
+ valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
256
+
257
+ log_rnd = valid_log_rnd + (scalar_rewards / args.alpha) # scale down by alpha
258
+ valid_x_final = torch.stack(valid_x_final, dim=0)
259
+
260
+ return valid_x_final, log_rnd, scalar_rewards
261
+
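The `log_rnd` returned by `sample_finetuned_with_rnd` accumulates, over the denoising steps, the log-ratio between the pretrained model and the current policy for the tokens unmasked at each step, and then adds the scalar reward scaled by `1 / args.alpha`. In equation form (notation ours):

$$
\log w(x_{0:T}) \;=\; \frac{r(x_0)}{\alpha} \;+\; \sum_{t}\Big(\log p_{\text{pre}}\big(x_{t-\Delta t}\mid x_{t}\big)\;-\;\log p_{\theta}\big(x_{t-\Delta t}\mid x_{t}\big)\Big),
$$

where each term only sums over the positions unmasked at that step, $p_\theta$ is the fine-tuned policy, and $p_{\text{pre}}$ is the frozen pretrained MDM.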
262
+ def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5):
263
+ torch.cuda.empty_cache()
264
+ self.backbone.eval()
265
+ self.noise.eval()
266
+ print(f"device:{self.device}")
267
+
268
+ if batch_size is None:
269
+ batch_size = args.batch_size
270
+
271
+ num_steps = args.total_num_steps
272
+ x_rollout = self.sample_prior(
273
+ batch_size,
274
+ args.seq_length).to(self.device, dtype=torch.long)
275
+
276
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
277
+ dt = torch.tensor((1 - eps) / num_steps, device=self.device)
278
+
279
+ for i in range(num_steps):
280
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
281
+
282
+ log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
283
+
284
+ x_rollout = x_next
285
+ x_rollout = x_rollout.to(self.device)
286
+
287
+ # if mask token remains, fully unmask
288
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
289
+
290
+ # does **any** mask remain in any sequence
291
+ any_mask_global = mask_positions.any().item() # true if mask remains
292
+ if any_mask_global:
293
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
294
+
295
+ x_rollout = x_next
296
+ x_rollout = x_rollout.to(self.device)
297
+
298
+ childSequences = self.tokenizer.batch_decode(x_rollout)
299
+ valid_x_final = []
300
+ validSequences = []
301
+
302
+ for idx, seq in enumerate(childSequences):
303
+ if self.analyzer.is_peptide(seq):
304
+ valid_x_final.append(x_rollout[idx])
305
+ validSequences.append(seq)
306
+
307
+ valid_fraction = len(validSequences) / batch_size
308
+
309
+ if (len(validSequences) != 0):
310
+ # add scores to log
311
+ score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives)
312
+ average_scores = score_vectors.T
313
+
314
+ affinity = average_scores[0]
315
+ sol = average_scores[1]
316
+ hemo = average_scores[2]
317
+ nf = average_scores[3]
318
+ permeability = average_scores[4]
319
+
320
+ else:
321
+ zeros = [0.0]
322
+
323
+ affinity = zeros
324
+ sol = zeros
325
+ hemo = zeros
326
+ nf = zeros
327
+ permeability = zeros
328
+
329
+ if dataframe:
330
+ df = pd.DataFrame({
331
+ "Peptide Sequence": validSequences,
332
+ "Binding Affinity": affinity if len(validSequences) else [0.0],
333
+ "Solubility": sol if len(validSequences) else [0.0],
334
+ "Hemolysis": hemo if len(validSequences) else [0.0],
335
+ "Nonfouling": nf if len(validSequences) else [0.0],
336
+ "Permeability": permeability if len(validSequences) else [0.0],
337
+ })
338
+ return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df
339
+
340
+ return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction
341
+
342
+ def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None):
343
+ torch.cuda.empty_cache()
344
+ self.backbone.eval()
345
+ self.noise.eval()
346
+
347
+ sigma_t, _ = self.noise(t)
348
+
349
+ if token_array.ndim == 1:
350
+ token_array = token_array.unsqueeze(0)
351
+
352
+ if x_next.ndim == 1:
353
+ x_next = x_next.unsqueeze(0)
354
+
355
+ if t.ndim > 1:
356
+ t = t.squeeze(-1)
357
+ assert t.ndim == 1
358
+
359
+ change_prob_t = t[:, None, None]
360
+ change_prob_s = (t - dt)[:, None, None]
361
+
362
+ assert change_prob_t.ndim == 3, change_prob_t.shape
363
+
364
+ if attn_mask is None:
365
+ attn_mask = torch.ones_like(token_array).to(self.device)
366
+
367
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
368
+ p_x0 = log_p.exp()
369
+
370
+ assert change_prob_t.ndim == p_x0.ndim
371
+
372
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
373
+
374
+ # zero-masking probability
375
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
376
+
377
+ copy_flag = (token_array != self.mask_index)
378
+
379
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
380
+ changed_mask = (~copy_flag)
381
+
382
+ # compute the per-sequence log-probability under the pretrained model
383
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
384
+
385
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
386
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
387
+
388
+ # returns:
389
+ # log_policy_step (B, ) log probability of the x_next tokens under the policy
390
+ if log_policy_step.ndim == 1:
391
+ log_policy_step = log_policy_step.squeeze(0)
392
+
393
+ return log_policy_step
394
+
395
+
396
+ def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None):
397
+ torch.cuda.empty_cache()
398
+ dev = self.device
399
+ self.backbone.to(dev).eval()
400
+ self.noise.eval()
401
+
402
+ t = t.to(dev)
403
+ dt = torch.as_tensor(dt, device=dev, dtype=t.dtype)
404
+ assert self.config.noise.type == 'loglinear'
405
+ sigma_t, _ = self.noise(t)
406
+ sigma_t = sigma_t.to(dev)
407
+
408
+ if t.ndim > 1:
409
+ t = t.squeeze(-1)
410
+ assert t.ndim == 1
411
+
412
+ change_prob_t = t[:, None, None]
413
+ change_prob_s = (t - dt)[:, None, None]
414
+
415
+ assert change_prob_t.ndim == 3, change_prob_t.shape
416
+
417
+ if attn_mask is None:
418
+ attn_mask = torch.ones_like(token_array, device=dev, dtype=torch.long)
419
+ else:
420
+ attn_mask = attn_mask.to(dev)
421
+
422
+ if p_x0 is None:
423
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
424
+ p_x0 = log_p.exp()
425
+ else:
426
+ # ensure provided p_x0 is on dev
427
+ log_p = None
428
+ p_x0 = p_x0.to(dev)
429
+
430
+ assert change_prob_t.ndim == p_x0.ndim
431
+
432
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
433
+
434
+ # zero-masking probability
435
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
436
+
437
+ x_changed = _sample_categorical(q_xs)
438
+ if x_changed.device != dev or x_changed.dtype != token_array.dtype:
439
+ x_changed = x_changed.to(dev, dtype=token_array.dtype)
440
+
441
+ copy_flag = (token_array != self.mask_index)
442
+
443
+ int_copy_flag = copy_flag.to(token_array.dtype)
444
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
445
+
446
+ # returns:
447
+ # log_p (B, L, D) log probabilities of each token under the policy model
448
+ # x_next (B, L) next sequences
449
+ return log_p, x_next
450
+
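Written out, `single_reverse_step` keeps already-unmasked tokens fixed and, for each masked position, samples the state at time $s = t - \Delta t$ from the standard absorbing-state reverse kernel. With the SUBS parameterization (the predicted probability of MASK is zeroed out), and with the normalization by $t$ implicit in the categorical sampling above, this is (notation ours):

$$
q\big(x_s = v \mid x_t = \text{MASK}\big) \;=\;
\begin{cases}
\dfrac{t-s}{t}\,x_\theta(x_t)_v, & v \neq \text{MASK},\\[6pt]
\dfrac{s}{t}, & v = \text{MASK}.
\end{cases}
$$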
451
+
452
+ def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None):
453
+ torch.cuda.empty_cache()
454
+ self.backbone.eval()
455
+ self.noise.eval()
456
+
457
+ assert self.config.noise.type == 'loglinear'
458
+ sigma_t, _ = self.noise(t)
459
+
460
+ if t.ndim > 1:
461
+ t = t.squeeze(-1)
462
+ assert t.ndim == 1
463
+
464
+ change_prob_t = t[:, None, None]
465
+ change_prob_s = (t - dt)[:, None, None]
466
+
467
+ assert change_prob_t.ndim == 3, change_prob_t.shape
468
+
469
+ if attn_mask is None:
470
+ attn_mask = torch.ones_like(token_array).to(self.device)
471
+
472
+ if p_x0 is None:
473
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
474
+ p_x0 = log_p.exp()
475
+
476
+ assert change_prob_t.ndim == p_x0.ndim
477
+
478
+ # changed for noise removal
479
+ p_x0 = p_x0.clone()
480
+ p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
481
+ p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
482
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
483
+
484
+ x_changed = _sample_categorical(q_xs)
485
+
486
+ copy_flag = (token_array != self.mask_index)
487
+
488
+ int_copy_flag = copy_flag.to(token_array.dtype)
489
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
490
+
491
+ # returns:
492
+ # log_p (B, L, D) log probabilities of each token under the policy model
493
+ # x_next (B, L) next sequences
494
+ return log_p, x_next
495
+
496
+ def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
497
+ torch.cuda.empty_cache()
498
+ self.backbone.eval()
499
+ self.noise.eval()
500
+ assert self.config.noise.type == 'loglinear'
501
+ sigma_t, _ = self.noise(t)
502
+
503
+ if t.ndim > 1:
504
+ t = t.squeeze(-1)
505
+ assert t.ndim == 1
506
+
507
+ change_prob_t = t[:, None, None]
508
+ change_prob_s = (t - dt)[:, None, None]
509
+
510
+ assert change_prob_t.ndim == 3, change_prob_t.shape
511
+
512
+ if attn_mask is None:
513
+ attn_mask = torch.ones_like(token_array).to(self.device)
514
+
515
+ if p_x0 is None:
516
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
517
+ p_x0 = log_p.exp()
518
+
519
+ assert change_prob_t.ndim == p_x0.ndim
520
+
521
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
522
+
523
+ # zero-masking probability
524
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
525
+
526
+ x_changed = _sample_categorical(q_xs)
527
+
528
+ copy_flag = (token_array != self.mask_index)
529
+
530
+ int_copy_flag = copy_flag.to(token_array.dtype)
531
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
532
+
533
+ # compute the log-probability under pretrained model at each step
534
+ with torch.no_grad():
535
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
536
+ log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
537
+
538
+ # log-prob of the *sampled token* at each position
539
+ log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
540
+
541
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
542
+
543
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
544
+ changed_mask = (~copy_flag)
545
+ # mask of tokens that were unmasked in this step
546
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
547
+
548
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
549
+
550
+ # compute the per-sequence log-probability under the pretrained model
551
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
552
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
553
+
554
+ # returns:
555
+ # log_p (B, L, D) log probabilities of each token under the policy model
556
+ # x_next (B, L) next sequences
557
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
558
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
559
+ return log_p, x_next, log_policy_step, log_pretrained_step
560
+
561
+ def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
562
+ torch.cuda.empty_cache()
563
+ self.backbone.eval()
564
+ self.noise.eval()
565
+
566
+ assert self.config.noise.type == 'loglinear'
567
+ sigma_t, _ = self.noise(t)
568
+
569
+ if t.ndim > 1:
570
+ t = t.squeeze(-1)
571
+ assert t.ndim == 1
572
+
573
+ change_prob_t = t[:, None, None]
574
+ change_prob_s = (t - dt)[:, None, None]
575
+
576
+ assert change_prob_t.ndim == 3, change_prob_t.shape
577
+
578
+ if attn_mask is None:
579
+ attn_mask = torch.ones_like(token_array).to(self.device)
580
+
581
+ if p_x0 is None:
582
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
583
+ p_x0 = log_p.exp()
584
+
585
+ assert change_prob_t.ndim == p_x0.ndim
586
+
587
+ # changed for noise removal
588
+ p_x0 = p_x0.clone()
589
+ p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
590
+ p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
591
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
592
+
593
+ x_changed = _sample_categorical(q_xs)
594
+
595
+ copy_flag = (token_array != self.mask_index)
596
+
597
+ int_copy_flag = copy_flag.to(token_array.dtype)
598
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
599
+
600
+ # compute the log-probability under pretrained model at each step
601
+ with torch.no_grad():
602
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
603
+ log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
604
+
605
+ # log-prob of the *sampled token* at each position
606
+ log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
607
+
608
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
609
+
610
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
611
+ changed_mask = (~copy_flag)
612
+ # mask of tokens that were unmasked in this step
613
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
614
+
615
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
616
+
617
+ # compute the per-sequence log-probability under the pretrained model
618
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
619
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
620
+
621
+ # returns:
622
+ # log_p (B, L, D) log probabilities of each token under the policy model
623
+ # x_next (B, L) next sequences
624
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
625
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
626
+ return log_p, x_next, log_policy_step, log_pretrained_step
627
+
628
+ # first step in expansion
629
+ def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None):
630
+ torch.cuda.empty_cache()
631
+ self.backbone.eval()
632
+ self.noise.eval()
633
+
634
+ assert self.config.noise.type == 'loglinear'
635
+ sigma_t, _ = self.noise(t)
636
+
637
+ if t.ndim > 1:
638
+ t = t.squeeze(-1)
639
+ assert t.ndim == 1
640
+
641
+ change_prob_t = t[:, None, None]
642
+ change_prob_s = (t - dt)[:, None, None]
643
+
644
+ assert change_prob_t.ndim == 3, change_prob_t.shape
645
+
646
+ if token_array.dim() == 1:
647
+ token_array = token_array.unsqueeze(0)
648
+
649
+ # expand to match (num_children, L)
650
+ if attn_mask is None:
651
+ attn_mask = torch.ones_like(token_array).to(self.device)
652
+
653
+ token_array = token_array.to(self.device)
654
+ sigma_t = sigma_t.to(self.device)
655
+
656
+ if p_x0 is None:
657
+ log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
658
+ p_x0 = log_p.exp()
659
+
660
+ assert change_prob_t.ndim == p_x0.ndim
661
+
662
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
663
+
664
+ # zero-masking probability
665
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
666
+
667
+ # repeat the parent token along the first dimension which will be unmasked into distinct sequences
668
+ token_array_expanded = token_array.repeat(batch_size, 1)
669
+
670
+ if self.config.mcts.sampling == 0:
671
+ x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
672
+ else:
673
+ x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
674
+
675
+ copy_flag = (token_array_expanded != self.mask_index)
676
+
677
+ int_copy_flag = copy_flag.to(token_array.dtype)
678
+ x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed
679
+
680
+
681
+ # compute the log-probability under pretrained model at each step
682
+ with torch.no_grad():
683
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
684
+ log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
685
+
686
+ # expand to match the shape of x_children
687
+ log_pre = log_pre.repeat(batch_size, 1, 1)
688
+
689
+ # log-prob of the *sampled token* at each position
690
+ log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
691
+
692
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
693
+
694
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
695
+ changed_mask = (~copy_flag)
696
+ # mask of tokens that were unmasked in this step
697
+ unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)
698
+
699
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
700
+
701
+ # compute the per-child log-probability under the pretrained model
702
+ log_p = log_p.repeat(batch_size, 1, 1)
703
+ log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # (B, L) probability of each chosen token
704
+ #print(log_policy_token)
705
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
706
+
707
+ # returns:
708
+ # log_p (B, L, D) log probabilities of each token under the policy model
709
+ # x_children (B, L) child sequences
710
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
711
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
712
+ return log_p, x_children, log_policy_step, log_pretrained_step
713
+
714
+
715
+ def compute_invalid_loss(self, logits, k=None, temp=None):
716
+ """
717
+ Penalizes logits that produce invalid sequences using the `is_peptide` function,
718
+ scaling penalties inversely with token probabilities.
719
+
720
+ Args:
721
+ logits: Tensor of shape [batch_size, seq_len, vocab_size].
722
+ k: Number of samples for Gumbel-Rao.
723
+ temp: Temperature for softmax.
724
+
725
+ Returns:
726
+ loss: A scalar tensor representing the total loss for invalid sequences.
727
+ """
728
+
729
+ #samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)
730
+
731
+ # Convert logits to sequences using the tokenizer
732
+ batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
733
+ sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
734
+
735
+ # Check validity of each sampled sequence (not differentiable)
736
+ penalties = torch.tensor(
737
+ [1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
738
+ dtype=torch.float32,
739
+ device=self.device
740
+ )
741
+ #print(penalties)
742
+
743
+ # Compute probabilities for each token (batch_size, seq_length)
744
+ sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)
745
+
746
+ # scale penalties by softmax probability of sampled tokens
747
+ scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)
748
+
749
+ return scaled_penalty.to(self.device)
750
+
751
+ ### DIFFUSION LOSS ###
752
+
753
+ def sample_t(self, n, device):
754
+ """
755
+ Sample random time steps for batch training
756
+ """
757
+ # sample values uniformly at random from [0, 1)
758
+ eps_t = torch.rand(n, device=device)
759
+ # antithetic sampling: reduce variance by pairing each sample with complementary sample
760
+ if self.config.training.antithetic_sampling:
761
+ # compute interval between sampled time steps
762
+ offset = torch.arange(n, device=device) / n
763
+ # ensure that each eps value is evenly spaced between [0, 1)
764
+ eps_t = ((eps_t / n) + offset) % 1
765
+
766
+ # ensures values are not exactly 0 or 1
767
+ t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
768
+
769
+ return t
770
+
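The antithetic option above stratifies the batch's timesteps: the `n` uniform draws are scaled into width-`1/n` bins and offset so that exactly one sample lands in each bin of `[0, 1)`, reducing the variance of the loss estimate across a batch. A small standalone illustration:

```python
# Illustration (not repo code): stratified "antithetic" time sampling as in sample_t.
# Each of the n draws lands in its own bin [i/n, (i+1)/n), covering [0, 1) evenly.
import torch

n = 4
eps_t = torch.rand(n)              # i.i.d. uniform draws
offset = torch.arange(n) / n
strat = (eps_t / n + offset) % 1   # one sample per 1/n-wide bin
print(strat)                       # e.g. tensor([0.05, 0.37, 0.61, 0.88])
```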
771
+ """def mask_samples(self, x0, mask_prob):
772
+
773
+ # generate array of values in range [0, 1] uniformly at random
774
+ # will be used to determine which tokens are masked
775
+ mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
776
+
777
+ # select tokens to mask if the random value in mask_indices is less than mask_prob
778
+ # this will mask approximately the fraction of tokens indicated by mask_prob
779
+ zt = torch.where(mask_indices < mask_prob, self.mask_index, x0)
780
+
781
+ return zt"""
782
+
783
+ def q_xt(self, x, mask_prob):
784
+ """Computes the noisy sample xt.
785
+
786
+ Args:
787
+ x: int torch.Tensor with shape (batch_size,
788
+ diffusion_model_input_length), input.
789
+ mask_prob: float torch.Tensor with shape (batch_size, 1).
790
+ """
791
+
792
+ actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
793
+ #print(actual_seq_length)
794
+
795
+ max_mask_length = (actual_seq_length * 0.75).long()
796
+
797
+ mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob
798
+
799
+ restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)
800
+
801
+ for i in range(x.shape[0]):
802
+ true_positions = torch.where(mask_indices[i])[0]
803
+ if len(true_positions) > max_mask_length[i]:
804
+ selected_positions = true_positions[:max_mask_length[i].item()]
805
+ restricted_move_indices[i, selected_positions] = True
806
+ else:
807
+ restricted_move_indices[i] = mask_indices[i]
808
+
809
+ xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)
810
+
811
+ return xt
812
+
813
+
814
+ def sample_prior(self, *batch_dims):
815
+ """
816
+ Returns an array of fully masked sequences with the given batch dimensions
817
+ """
818
+ return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
819
+
820
+
821
+ ### COMPUTING LOSS ###
822
+
823
+ def compute_diffusion_loss(self, model_output, xt, x0, t):
824
+ """
825
+ Computes diffusion loss term in ELBO
826
+ (evaluates how accurately the model predicts the token probabilities at each time step)
827
+
828
+ Inputs:
829
+ - model_output: (batch_size, sequence_length, vocab_size) array of token log-probabilities at each sequence position
830
+ - xt: corrupted version of original input x0 at timestep t
831
+ - x0: original input sequence
832
+ - t: timestep
833
+ """
834
+ # compute interval between each timestep
835
+ dt = 1 / self.T
836
+
837
+ # compute vectorized alpha scaling terms for the logits at timestep s and t
838
+ alpha_t = 1 - t + torch.zeros_like(x0)
839
+ # s = t - dt
840
+ alpha_s = 1 - (t - dt) + torch.zeros_like(x0)
841
+
842
+ # gather vector of log-probabilities for each token in x0
843
+ # log<x_theta, x>
844
+ log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]) # shape (B, L, vocab_size)
845
+ # gather log-probabilities for assigning a masked token at each position in the sequence at time t
846
+ # log<x_theta, m>
847
+ log_x_theta_at_m = model_output[:, :, self.mask_index]
848
+ # obtain non-log probability of assigning a masked token
849
+ # <xt, m>
850
+ x_theta_at_m = log_x_theta_at_m.exp()
851
+
852
+ # first term of diffusion loss
853
+ term_1_coef = dt / t
854
+ term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
855
+ term_1_log_denom = log_x_theta_at_x0
856
+
857
+ # second term of diffusion loss
858
+ term_2_coef = 1 - (dt / t)
859
+ term_2_log_numerator = term_1_log_numerator
860
+ term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)
861
+
862
+ L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
863
+ term_2_coef * (term_2_log_numerator - term_2_log_denom))
864
+
865
+ # multiply by <zt, m> term
866
+ L_vb = L_vb_masked * (xt == self.mask_index)
867
+
868
+ # scale by T and return
869
+ return self.T * L_vb
870
+
871
+ def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
872
+ """
873
+ Training reverse diffusion model x_theta to reconstruct samples x0
874
+
875
+ bond_mask: (batch, seq_length)
876
+ """
877
+ # randomly sample time steps to start the denoising process for each x0 in batch
878
+ t = self.sample_t(x0.shape[0], self.device)
879
+
880
+ # if we are training the intermediate transition blocks
881
+ if self.T > 0:
882
+ # scale by total timesteps T and cast to integer
883
+ t = (t * self.T).to(torch.int)
884
+ # scale down by T to get a multiple of 1/T
885
+ t = t / self.T
886
+ # add 1/T to ensure no 0 values
887
+ t += (1 / self.T)
888
+
889
+ # get noise and rate of noise at timestep t
890
+ # sigma = -log(1-t); dsigma = 1 / (1-t)
891
+ sigma, dsigma = self.noise(t)
892
+ time_conditioning = sigma[:, None]
893
+
894
+ # Get masking probabilities for all tokens for each batch
895
+ # log-linear: 1 - alpha = t
896
+ base_mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
897
+
898
+ if self.config.noise.state_dependent and (bond_mask is not None):
899
+ # log-polynomial masking schedule: alpha = 1 - t^w
900
+ # bond_sigma = -log(1-t^w) for w = 3 (default)
901
+ # bond_dsigma = -wt^(w-1) / (1-t^w)
902
+ bond_sigma, bond_dsigma = self.bond_noise(t) # scalar
903
+ # expand dimensions for broadcasting to (B, L)
904
+ bond_sigma = bond_sigma[:, None]
905
+ bond_dsigma = bond_dsigma[:, None]
906
+ sigma = sigma[:, None]
907
+ dsigma = dsigma[:, None]
908
+
909
+ # compute masking probability for peptide bonds 1 - bond_alpha = t^w
910
+ bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
911
+ # piece together (B, L) tensor with modified masking prob at peptide-bond locations
912
+ mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
913
+ #print(mask_prob)
914
+ dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
915
+ sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
916
+ else:
917
+ mask_prob = base_mask_prob.to(self.device)
918
+
919
+ # get masked samples at different timesteps
920
+ if mask is None:
921
+ zt = self.q_xt(x0, mask_prob).to(self.device)
922
+ else:
923
+ zt = x0.where(mask==1, torch.full_like(x0, self.mask_index)).to(self.device)
924
+
925
+ model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)
926
+
927
+ # debugging
928
+ assert not torch.isnan(model_output).any()
929
+ assert model_output.is_cuda
930
+ utils.print_nans(model_output, 'model_output')
931
+
932
+ # compute invalid loss
933
+ invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device) # (B, L)
934
+ #print(invalid_loss)
935
+
936
+ if self.T > 0:
937
+ # compute diffusion loss
938
+ diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
939
+ return diffusion_loss
940
+
941
+ # compute loss for the final that converts from z0 to x0
942
+ # -log(p_theta)
943
+ # get (batch_size, L) array of log-probabilities
944
+ log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device) # (B, L)
945
+
946
+ if self.config.noise.state_dependent and (bond_mask is not None):
947
+ return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
948
+ else:
949
+ return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
950
+
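For the continuous-time case (`T == 0`), the value returned above is the masked-diffusion NELBO weighting applied to the cross-entropy of the clean tokens, plus the invalid-sequence penalty; with the log-linear schedule $\sigma(t) = -\log(1-t)$ the weight $\dot\sigma(t)/(e^{\sigma(t)}-1)$ reduces to $1/t$ (notation ours):

$$
\mathcal{L}(x_0) \;=\; \mathbb{E}_{t,\,z_t}\!\left[\frac{\dot\sigma(t)}{e^{\sigma(t)}-1}\,\big(-\log x_\theta(z_t)_{x_0}\big)\right] + \lambda_{\text{invalid}},
\qquad
\frac{\dot\sigma(t)}{e^{\sigma(t)}-1}\bigg|_{\sigma(t)=-\log(1-t)} = \frac{1}{t}.
$$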
951
+ def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
952
+ loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
953
+
954
+ # negative log loss
955
+ nlls = loss * attn_mask
956
+
957
+ # count number of tokens
958
+ num_tokens = attn_mask.sum()
959
+
960
+ # compute batch loss
961
+ batch_nll = nlls.sum()
962
+ # compute per token loss
963
+ token_nll = batch_nll / num_tokens
964
+ # return losses
965
+ return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
966
+
967
+ def _compute_loss(self, batch, prefix, bond_mask=None):
968
+
969
+ attn_mask = batch['attention_mask'].to(self.device)
970
+
971
+ if 'mask' in batch:
972
+ mask = batch['mask'].to(self.device)
973
+ else:
974
+ mask = None
975
+
976
+ if 'bond_mask' in batch:
977
+ bond_mask = batch['bond_mask'].to(self.device)
978
+ else:
979
+ bond_mask = None
980
+
981
+ losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
982
+ loss = losses.loss
983
+
984
+ if prefix == 'train':
985
+ self.train_metrics.update(
986
+ losses.nlls.to(self.device),
987
+ losses.attn_mask.to(self.device)
988
+ )
989
+ metrics = self.train_metrics
990
+ elif prefix == 'val':
991
+ self.valid_metrics.update(
992
+ losses.nlls.to(self.device),
993
+ losses.attn_mask.to(self.device)
994
+ )
995
+ metrics = self.valid_metrics
996
+ elif prefix == 'test':
997
+ self.test_metrics.update(losses.nlls, losses.attn_mask)
998
+ metrics = self.test_metrics
999
+ else:
1000
+ raise ValueError(f'Invalid prefix: {prefix}')
1001
+
1002
+ self.log_dict(metrics,
1003
+ on_step=False,
1004
+ on_epoch=True,
1005
+ sync_dist=True)
1006
+
1007
+ return loss
1008
+
1009
+
1010
+ ### SAMPLING ###
1011
+
1012
+ def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
1013
+ # get number of timesteps
1014
+ if sample_steps is None:
1015
+ sample_steps = self.config.sampling.steps
1016
+
1017
+ if seq_length is None:
1018
+ seq_length = self.config.sampling.seq_length
1019
+
1020
+ # sample fully masked sequences
1021
+ z = self.sample_prior(num_samples, seq_length).to(self.device)
1022
+
1023
+ # create vector of sample_steps timesteps
1024
+ timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
1025
+
1026
+ # compute interval between timesteps
1027
+ dt = (1 - eps) / sample_steps
1028
+
1029
+ for i in range(sample_steps):
1030
+ t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
1031
+
1032
+ z = self.single_reverse_step(z, t, dt)
1033
+
1034
+ return z
1035
+
1036
+
1037
+ ### SAMPLING STEP ###
1038
+ """
1039
+ def single_reverse_step(self, zt, t, dt, attn_mask=None):
1040
+ # get sigma values that determine masking prob
1041
+ sigma_t, _ = self.noise(t)
1042
+ sigma_s, _ = self.noise(t - dt)
1043
+
1044
+ # reshape sigmas
1045
+ if sigma_t.ndim > 1:
1046
+ sigma_t = sigma_t.squeeze(-1)
1047
+ if sigma_s.ndim > 1:
1048
+ sigma_s = sigma_s.squeeze(-1)
1049
+ assert sigma_t.ndim == 1, sigma_t.shape
1050
+ assert sigma_s.ndim == 1, sigma_s.shape
1051
+
1052
+ # compute masking probabilities for each timestep
1053
+ change_prob_t = 1 - torch.exp(-sigma_t)
1054
+ change_prob_s = 1 - torch.exp(-sigma_s)
1055
+
1056
+ # expand dimensions
1057
+ change_prob_t = change_prob_t[:, None, None]
1058
+ change_prob_s = change_prob_s[:, None, None]
1059
+
1060
+ # get prediction model that outputs token probabilities
1061
+ log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
1062
+
1063
+ # check dimensions match
1064
+ assert change_prob_t.ndim == log_p_x0.ndim
1065
+
1066
+ # compute reverse diffusion probability of being unmasked at timestep s
1067
+ # (sigma_s - sigma_t)*x_theta
1068
+ q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
1069
+
1070
+ # compute reverse diffusion probability of remaining masked at timestep s
1071
+ # (1 - sigma_s)*m
1072
+ q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0]
1073
+
1074
+ # sample sequence at timestep s from categorical distribution of q_zs
1075
+ z_changed = _sample_categorical(q_zs)
1076
+
1077
+ copy_flag = (zt != self.mask_index).to(zt.dtype)
1078
+ return (copy_flag * zt) + ((1 - copy_flag) * z_changed)"""
1079
+
1080
+ def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
1081
+ assert self.config.noise.type == 'loglinear'
1082
+ sigma_t, _ = self.noise(t)
1083
+
1084
+ if t.ndim > 1:
1085
+ t = t.squeeze(-1)
1086
+ assert t.ndim == 1
1087
+
1088
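+ # NOTE (illustrative): under the log-linear schedule used elsewhere in this repo,
+ # sigma(t) = -log1p(-(1 - eps) * t), so the masking probability 1 - exp(-sigma(t))
+ # equals (1 - eps) * t; the code below therefore uses t and t - dt directly as the
+ # unmasking weights (up to the small eps offset).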
+ change_prob_t = t[:, None, None]
1089
+ change_prob_s = (t - dt)[:, None, None]
1090
+
1091
+ assert change_prob_t.ndim == 3, change_prob_t.shape
1092
+
1093
+ if p_x0 is None:
1094
+ p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()
1095
+
1096
+ assert change_prob_t.ndim == p_x0.ndim
1097
+
1098
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
1099
+
1100
+ # zero-masking probability
1101
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
1102
+
1103
+ x_changed = _sample_categorical(q_xs)
1104
+
1105
+ copy_flag = (x != self.mask_index).to(x.dtype)
1106
+
1107
+ return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
1108
+
1109
+ # first step in expansion
1110
+ def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
1111
+ """
1112
+ Generates batch_size different samples from the same starting point for the
1113
+ first expansion step of MCTS
1114
+ """
1115
+
1116
+ assert self.config.noise.type == 'loglinear'
1117
+ sigma_t, _ = self.noise(t)
1118
+
1119
+ if t.ndim > 1:
1120
+ t = t.squeeze(-1)
1121
+ assert t.ndim == 1
1122
+
1123
+ change_prob_t = t[:, None, None]
1124
+ change_prob_s = (t - dt)[:, None, None]
1125
+
1126
+ assert change_prob_t.ndim == 3, change_prob_t.shape
1127
+
1128
+ if token_array.dim() == 1:
1129
+ token_array = token_array.unsqueeze(0)
1130
+ #token_array = token_array.repeat(batch_size, 1)
1131
+
1132
+ attn_mask = torch.ones_like(token_array).to(self.device)
1133
+
1134
+ if p_x0 is None:
1135
+ p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
1136
+
1137
+ assert change_prob_t.ndim == p_x0.ndim
1138
+
1139
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
1140
+
1141
+ # zero-masking probability
1142
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
1143
+
1144
+ # repeat the parent token along the first dimension which will be unmasked into distinct sequences
1145
+ token_array = token_array.repeat(batch_size, 1)
1146
+
1147
+ if self.config.mcts.sampling == 0:
1148
+ x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
1149
+ else:
1150
+ x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
1151
+
1152
+ copy_flag = (token_array != self.mask_index).to(token_array.dtype)
1153
+
1154
+ return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
1155
+
1156
+ def _process_sigma(self, sigma):
1157
+ if sigma.ndim > 1:
1158
+ sigma = sigma.squeeze(-1)
1159
+ if not self.time_conditioning:
1160
+ sigma = torch.zeros_like(sigma)
1161
+ assert sigma.ndim == 1, sigma.shape
1162
+ return sigma
1163
+
1164
+ def forward(self, zt, attn_mask, sigma):
1165
+ """
1166
+ Predicts the token log-probabilities from zt at time t with noise schedule sigma
1167
+ """
1168
+ sigma = self._process_sigma(sigma)
1169
+
1170
+ with torch.cuda.amp.autocast(dtype=torch.float32):
1171
+ logits = self.backbone(zt, attn_mask).to(self.device)
1172
+
1173
+ return self.subs_parameterization(logits, zt)
1174
+
1175
+ def subs_parameterization(self, logits, zt):
1176
+ """
1177
+ Updates reverse diffusion logits based on SUBS parameterization:
1178
+ - zero masking probabilities: -infinity probability of being masked during reverse diffusion
1179
+ - carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion
1180
+
1181
+ Args:
1182
+ logits: vector of token probabilities for unmasking masked tokens
1183
+ zt: partially unmasked sequence at current timestep
1184
+ """
1185
+ logits[:, :, self.mask_index] += self.neg_infinity # [sequence index, current token, next token]
1186
+
1187
+
1188
+ logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)
1189
+
1190
+
1191
+ unmasked_indices = (zt != self.mask_index).to(self.device) # shape: [200, seq_length]
1192
+ batch_idx, seq_idx = torch.where(unmasked_indices) # Get explicit indices
1193
+ batch_idx = batch_idx.to(self.device)
1194
+ seq_idx = seq_idx.to(self.device)
1195
+ tokens = zt[batch_idx, seq_idx].to(self.device) # Get the tokens at those positions
1196
+
1197
+ #assert logits.is_contiguous(), "logits tensor is not contiguous"
1198
+ #assert unmasked_indices.shape == zt.shape, "same shape"
1199
+ #assert not torch.isnan(logits).any(), "NaN values found in logits"
1200
+ #assert tokens.max() < logits.shape[-1], "token indices out of bounds"
1201
+ #assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
1202
+ #assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
1203
+ #assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"
1204
+
1205
+ logits[unmasked_indices] = self.neg_infinity # Set everything to -inf first
1206
+ logits[unmasked_indices, zt[unmasked_indices]] = 0 # Set only the specific token positions to 0
1207
+ # return logits with SUBS parameterization
1208
+ return logits.to(self.device)
1209
+
1210
+ """SAMPLING"""
1211
+ @torch.no_grad()
1212
+ def _sample(self, num_steps=None, eps=1e-5, x_input=None):
1213
+ """
1214
+ Generate samples
1215
+ """
1216
+ batch_size_per_gpu = self.config.eval.perplexity_batch_size
1217
+
1218
+ if num_steps is None:
1219
+ num_steps = self.config.sampling.steps
1220
+
1221
+ if x_input is not None:
1222
+ x = x_input['input_ids'].to(self.device)
1223
+ attn_mask = x_input['attention_mask'].to(self.device)
1224
+ else:
1225
+ x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
1226
+ attn_mask = torch.ones_like(x).to(self.device)
1227
+
1228
+
1229
+ timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
1230
+ dt = (1 - eps) / num_steps
1231
+ p_x0_cache = None
1232
+ generation_history = [] # used to track which tokens are unmasked
1233
+
1234
+ for i in range(num_steps):
1235
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
1236
+ if self.sampler == 'ddpm':
1237
+ x = self.single_reverse_step(x, t, dt).to(self.device)
1238
+
1239
+ elif self.sampler == 'ddpm_cache':
1240
+ p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
1241
+ if (not torch.allclose(x_next, x) or self.time_conditioning):
1242
+ # Disable caching
1243
+ p_x0_cache = None
1244
+ x = x_next.to(self.device)
1245
+ #print(self.tokenizer.decode(x.squeeze()))
1246
+ else:
1247
+ x = self._analytic_update(x, t, dt, attn_mask).to(self.device)
1248
+
1249
+ if self.config.sampling.noise_removal:
1250
+ t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
1251
+ if self.sampler == 'analytic':
1252
+ x = self._denoiser_update(x, t).to(self.device)
1253
+ else:
1254
+ time_conditioning = self.noise(t)[0].to(self.device)
1255
+ x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
1256
+ #print(self.tokenizer.decode(x.squeeze()))
1257
+ return x.to(self.device)
1258
+
1259
+
1260
+ def restore_model_and_sample(self, num_steps, eps=1e-5):
1261
+ """Generate samples from the model."""
1262
+ self.backbone.eval()
1263
+ self.noise.eval()
1264
+ samples = self._sample(num_steps=num_steps, eps=eps)
1265
+ self.backbone.train()
1266
+ self.noise.train()
1267
+ return samples
1268
+
1269
+ def get_score(self, zt, sigma, attn_mask=None):
1270
+
1271
+ # score(x, t) = p_t(y) / p_t(x)
1272
+ # => log score(x, t) = log p_t(y) - log p_t(x)
1273
+
1274
+ # case 1: x = masked
1275
+ # (i) y = unmasked
1276
+ # log score(x, t) = log p_\theta(x)|_y + log k
1277
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1278
+ # (ii) y = masked
1279
+ # log score(x, t) = 0
1280
+
1281
+ # case 2: x = unmasked
1282
+ # (i) y != masked, y != x
1283
+ # log score(x_i, t) = - inf
1284
+ # (ii) y = x
1285
+ # log score(x_i, t) = 0
1286
+ # (iii) y = masked token
1287
+ # log score(x_i, t) = - log k
1288
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1289
+
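+ # quick numeric check (illustrative): for sigma = 0.5,
+ # k = exp(-0.5) / (1 - exp(-0.5)) ≈ 1.54, so log k = -log(expm1(0.5)) ≈ 0.43,
+ # matching the log_k computed below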
1290
+ model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)
1291
+
1292
+ log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
1293
+ assert log_k.ndim == 1
1294
+
1295
+ masked_score = model_output + log_k[:, None, None]
1296
+ masked_score[:, :, self.mask_index] = 0
1297
+
1298
+ unmasked_score = self.neg_infinity * torch.ones_like(model_output)
1299
+ unmasked_score = torch.scatter(
1300
+ unmasked_score, -1,
1301
+ zt[..., None],
1302
+ torch.zeros_like(unmasked_score[..., :1]))
1303
+
1304
+ unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt))
1305
+
1306
+ masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None]
1307
+
1308
+ model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))
1309
+
1310
+ return model_output.exp()
1311
+
1312
+ def _staggered_score(self, score, dsigma):
1313
+ score = score.clone()
1314
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
1315
+ score *= dsigma.exp()[:, None]
1316
+ score[..., self.mask_index] += extra_const
1317
+ return score
1318
+
1319
+ def _analytic_update(self, x, t, step_size, attn_mask=None):
1320
+ curr_sigma, _ = self.noise(t)
1321
+ next_sigma, _ = self.noise(t - step_size)
1322
+ dsigma = curr_sigma - next_sigma
1323
+ score = self.get_score(x, curr_sigma, attn_mask=attn_mask)  # argument order matches get_score(zt, sigma, attn_mask)
1324
+ stag_score = self._staggered_score(score, dsigma)
1325
+ probs = stag_score * self._transp_transition(x, dsigma)
1326
+ return _sample_categorical(probs)
1327
+
1328
+ def _denoiser_update(self, x, t):
1329
+ sigma, _ = self.noise(t)
1330
+ score = self.get_score(x, sigma)
1331
+ stag_score = self._staggered_score(score, sigma)
1332
+ probs = stag_score * self._transp_transition(x, sigma)
1333
+ probs[..., self.mask_index] = 0
1334
+ samples = _sample_categorical(probs)
1335
+ return samples
1336
+
1337
+ def _transp_transition(self, i, sigma):
1338
+ sigma = unsqueeze(sigma, reference=i[..., None])
1339
+ edge = torch.exp(-sigma) * F.one_hot(
1340
+ i, num_classes=self.vocab_size)
1341
+ edge += torch.where(i == self.mask_index,
1342
+ 1 - torch.exp(-sigma).squeeze(-1),
1343
+ 0)[..., None]
1344
+ return edge
1345
+
1346
+
1347
+ """TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
1348
+
1349
+ def on_train_epoch_start(self):
1350
+ torch.cuda.empty_cache()
1351
+ self.backbone.train()
1352
+ self.noise.train()
1353
+
1354
+
1355
+ def training_step(self, batch, batch_idx):
1356
+ # Initialize throughput calculation
1357
+ start_time = time.time()
1358
+
1359
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
1360
+ loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
1361
+ else:
1362
+ loss = self._compute_loss(batch, prefix='train')
1363
+
1364
+ self.log(name='trainer/loss',
1365
+ value=loss.item(),
1366
+ on_step=True,
1367
+ on_epoch=False,
1368
+ sync_dist=True)
1369
+
1370
+ # Calculate throughput
1371
+ elapsed_time = time.time() - start_time
1372
+ total_tokens = batch['input_ids'].numel()
1373
+ throughput = total_tokens / elapsed_time
1374
+
1375
+ self.log(name='trainer/throughput',
1376
+ value=throughput,
1377
+ on_step=True,
1378
+ on_epoch=False,
1379
+ sync_dist=True)
1380
+
1381
+ return loss
1382
+
1383
+
1384
+ def on_load_checkpoint(self, checkpoint):
1385
+ self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
1386
+ self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
1387
+
1388
+ ### VALIDATION ###
1389
+ def on_validation_epoch_start(self):
1390
+ gc.collect()
1391
+ torch.cuda.empty_cache()
1392
+ self.backbone.eval()
1393
+ self.noise.eval()
1394
+ assert self.valid_metrics.nll.mean_value == 0
1395
+ assert self.valid_metrics.nll.weight == 0
1396
+
1397
+ def validation_step(self, batch, batch_idx):
1398
+ if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
1399
+ loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
1400
+ else:
1401
+ loss = self._compute_loss(batch, prefix='val')
1402
+
1403
+ self.log(name='trainer/val_loss',
1404
+ value=loss.item(),
1405
+ on_step=True,
1406
+ on_epoch=False,
1407
+ prog_bar=True,
1408
+ sync_dist=True)
1409
+ return loss
1410
+
1411
+ def on_validation_epoch_end(self):
1412
+ gc.collect()
1413
+ torch.cuda.empty_cache()
1414
+
1415
+ ### OPTIMIZATION ###
1416
+
1417
+ def optimizer_step(self, *args, **kwargs):
1418
+ super().optimizer_step(*args, **kwargs)
1419
+
1420
+ gc.collect()
1421
+ torch.cuda.empty_cache()
1422
+
1423
+ def configure_optimizers(self):
1424
+ optimizer = torch.optim.AdamW(
1425
+ itertools.chain(self.backbone.parameters(),self.noise.parameters()),
1426
+ lr=self.config.optim.lr,
1427
+ betas=(self.config.optim.beta1, self.config.optim.beta2),
1428
+ eps=self.config.optim.eps,
1429
+ weight_decay=self.config.optim.weight_decay
1430
+ )
1431
+
1432
+ self.total_steps = self.config.trainer.max_steps
1433
+ scheduler = CosineWarmup(optimizer,
1434
+ warmup_steps=self.config.lr_scheduler.num_warmup_steps,
1435
+ total_steps=self.total_steps)
1436
+
1437
+ scheduler_dict = {
1438
+ 'scheduler': scheduler,
1439
+ 'interval': 'step',
1440
+ 'frequency': 1,
1441
+ 'monitor': 'val/loss',
1442
+ 'name': 'trainer/lr'
1443
+ }
1444
+
1445
+ return [optimizer], [scheduler_dict]
1446
+
1447
+ @torch.no_grad()
1448
+ def compute_masked_perplexity(self, generated_ids, input_ids):
1449
+ """
1450
+ Computes the pseudo-perplexity of the generated token ids, evaluated at the positions that are masked in input_ids
1451
+ """
1452
+
1453
+ total_nll = 0
1454
+ total_tokens = 0
1455
+
1456
+ input_ids = torch.tensor(input_ids).to(self.device)
1457
+ #print(input_ids)
1458
+
1459
+ for sequence in generated_ids:
1460
+ # tokenize the sequence
1461
+
1462
+ gt_ids = torch.tensor(sequence).to(self.device)
1463
+ #print(gt_ids)
1464
+
1465
+ sys.stdout.flush()
1466
+
1467
+ # forward pass through backbone peptideclm model
1468
+ attn_mask = torch.ones_like(input_ids).to(self.device)
1469
+
1470
+ # compute logits using backbone
1471
+
1472
+ if self.config.mode in ['train', 'ppl_eval']:
1473
+ outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)
1474
+ elif self.config.mode == 'sample_eval':
1475
+ outputs = self.backbone.forward(input_ids=input_ids)
1476
+
1477
+
1478
+ # get logits for each position in sequence across all tokens in vocab
1479
+ #logits = outputs[-1] # (batch_size, seq_length, vocab_size)
1480
+
1481
+ logits = outputs.view(-1, outputs.size(-1))
1482
+ gt_ids = gt_ids.view(-1)
1483
+
1484
+ #print(logits.shape)
1485
+ #print(gt_ids.shape)
1486
+
1487
+ # compute loss
1488
+ # shift_logits = logits[:, :-1, :].contiguous() # remove eos
1489
+ # shift_labels = input_ids[:, 1:].contiguous()
1490
+ # print(masked)
1491
+
1492
+ loss = F.cross_entropy(logits,
1493
+ gt_ids.where(input_ids==self.mask_index, torch.full_like(gt_ids, -100)).view(-1),
1494
+ reduction='sum')
1495
+
1496
+ total_nll += loss.item()
1497
+ # count all non-padding tokens
1498
+ total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # includes bos and eos tokens
1499
+
1500
+ # compute pseudo-perplexity
1501
+ # print(total_nll, ",;,", total_tokens)
1502
+ pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
1503
+ self.gen_ppl_metric.update(pseudo_perplexity)
1504
+
1505
+ return pseudo_perplexity.item()
1506
+
1507
+
1508
+ def unsqueeze(x, reference):
1509
+ return x.view(* x.shape, * ((1,) * (len(reference.shape) - len(x.shape))))
1510
+
1511
+ class CosineWarmup(_LRScheduler):
1512
+ def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
1513
+ self.warmup_steps = warmup_steps
1514
+ self.total_steps = total_steps
1515
+ self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
1516
+ super(CosineWarmup, self).__init__(optimizer, last_epoch)
1517
+
1518
+ def get_lr(self):
1519
+ if self.last_epoch < self.warmup_steps:
1520
+ return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
1521
+
1522
+ progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
1523
+ cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
1524
+ decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
1525
+
1526
+ return [decayed_lr * base_lr for base_lr in self.base_lrs]
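For reference, here is a minimal usage sketch of the `CosineWarmup` scheduler defined above, stepped once per optimizer update as in `configure_optimizers` (interval='step'). The toy model, learning rate, and step counts are illustrative placeholders, not values from this repo, and the snippet assumes the `CosineWarmup` class (and its numpy import) is in scope:

```python
import torch

model = torch.nn.Linear(8, 8)                        # placeholder module
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = CosineWarmup(optimizer, warmup_steps=100, total_steps=1000)

for step in range(1000):
    optimizer.step()        # parameter update (loss/backward omitted in this sketch)
    scheduler.step()        # one scheduler step per optimizer step
    if step in (0, 99, 499, 999):
        print(step, scheduler.get_last_lr())
# lr warms up linearly to 3e-4 over the first 100 steps,
# then follows a cosine decay toward eta_ratio * 3e-4.
```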
tr2d2-pep/finetune.py ADDED
@@ -0,0 +1,133 @@
1
+ # direct reward backpropagation
2
+ from diffusion import Diffusion
3
+ from hydra import initialize, compose
4
+ from hydra.core.global_hydra import GlobalHydra
5
+ import numpy as np
6
+ from scipy.stats import pearsonr
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import argparse
10
+ import wandb
11
+ import os
12
+ import datetime
13
+ from finetune_peptides import finetune
14
+ from peptide_mcts import MCTS
15
+ from utils.utils import str2bool, set_seed
16
+ from scoring.scoring_functions import ScoringFunctions
17
+
18
+
19
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
+ argparser.add_argument('--base_path', type=str, default='')
21
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
22
+ argparser.add_argument('--num_epochs', type=int, default=100)
23
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
24
+ argparser.add_argument('--truncate_steps', type=int, default=50)
25
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
26
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
27
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
28
+ argparser.add_argument('--batch_size', type=int, default=32)
29
+ argparser.add_argument('--name', type=str, default='debug')
30
+ argparser.add_argument('--total_num_steps', type=int, default=128)
31
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
32
+ argparser.add_argument('--save_every_n_epochs', type=int, default=10)
33
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
34
+ argparser.add_argument("--seed", type=int, default=0)
35
+ # new
36
+ argparser.add_argument('--run_name', type=str, default='peptides')
37
+ argparser.add_argument("--device", default="cuda:0", type=str)
38
+ argparser.add_argument("--save_path_dir", default="/scratch/pranamlab/sophtang/home/tr2d2/peptides/checkpoints/", type=str)
39
+ # mcts
40
+ argparser.add_argument('--num_sequences', type=int, default=10)
41
+ argparser.add_argument('--num_children', type=int, default=50)
42
+ argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
43
+ argparser.add_argument('--seq_length', type=int, default=200)
44
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
45
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
46
+ argparser.add_argument('--buffer_size', type=int, default=100)
47
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
48
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
49
+ argparser.add_argument('--grad_clip', action='store_true', default=False)
50
+ argparser.add_argument('--resample_every_n_step', type=int, default=10)
51
+ argparser.add_argument('--exploration', type=float, default=0.1)
52
+ argparser.add_argument('--reset_every_n_step', type=int, default=100)
53
+ argparser.add_argument('--alpha', type=float, default=0.01)
54
+ argparser.add_argument('--scalarization', type=str, default='sum')
55
+ argparser.add_argument('--no_mcts', action='store_true', default=False)
56
+ argparser.add_argument("--centering", action='store_true', default=False)
57
+
58
+ # objectives
59
+ argparser.add_argument('--num_obj', type=int, default=5)
60
+ argparser.add_argument('--prot_seq', type=str, default=None)
61
+ argparser.add_argument('--prot_name', type=str, default=None)
62
+
63
+ args = argparser.parse_args()
64
+ print(args)
65
+
66
+ # pretrained model path
67
+ ckpt_path = f'{args.base_path}/TR2-D2/tr2d2-pep/pretrained/peptune-pretrained.ckpt'
68
+
69
+ # reinitialize Hydra
70
+ GlobalHydra.instance().clear()
71
+
72
+ # Initialize Hydra and compose the configuration
73
+ initialize(config_path="configs", job_name="load_model")
74
+ cfg = compose(config_name="peptune_config.yaml")
75
+ curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
76
+
77
+ # proteins
78
+ amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
79
+ tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
80
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
81
+ glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
82
+ glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
83
+ ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
84
+ cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
85
+ ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
86
+ skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
87
+
88
+ if args.prot_seq is not None:
89
+ prot = args.prot_seq
90
+ prot_name = args.prot_name
91
+ filename = args.prot_name
92
+ else:
93
+ prot = tfr
94
+ prot_name = "tfr"
95
+ filename = "tfr"
96
+
97
+ if args.no_mcts:
98
+ args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_no-mcts'
99
+ else:
100
+ args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'
101
+
102
+ args.save_path = os.path.join(args.save_path_dir, args.run_name)
103
+ os.makedirs(args.save_path, exist_ok=True)
104
+ # wandb init
105
+ wandb.init(project='tree-multi', name=args.run_name, config=args, dir=args.save_path)
106
+
107
+ log_path = os.path.join(args.save_path, 'log.txt')
108
+
109
+ set_seed(args.seed, use_cuda=True)
110
+
111
+ # Initialize the model
112
+ policy_model = Diffusion.load_from_checkpoint(ckpt_path,
113
+ config=cfg,
114
+ mode="train",
115
+ device=args.device,
116
+ map_location=args.device)
117
+ pretrained = Diffusion.load_from_checkpoint(ckpt_path,
118
+ config=cfg,
119
+ mode="eval",
120
+ device=args.device,
121
+ map_location=args.device)
122
+
123
+ # define mcts
124
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
125
+
126
+ mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
127
+
128
+ if args.no_mcts:
129
+ reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=args.device)
130
+ finetune(args, cfg, policy_model, reward_model=reward_model, mcts=None, pretrained=pretrained, filename=filename, prot_name=prot_name)
131
+ else:
132
+ mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
133
+ finetune(args, cfg, policy_model, reward_model=mcts.rewardFunc, mcts=mcts, pretrained=None, filename=filename, prot_name=prot_name)
tr2d2-pep/finetune.sh ADDED
@@ -0,0 +1,37 @@
1
+ #!/bin/bash
2
+
3
+ HOME_LOC=/path/to/your/home
4
+ ENV_PATH=/path/to/your/env
5
+ SCRIPT_LOC=$HOME_LOC/TR2-D2/tr2d2-pep
6
+ LOG_LOC=$HOME_LOC/TR2-D2/tr2d2-pep/logs
7
+ DATE=$(date +%m_%d)
8
+ SPECIAL_PREFIX='tr2d2-finetune-tfr'
9
+ # set 3 have skip connection
10
+ PYTHON_EXECUTABLE=$ENV_PATH/bin/python
11
+
12
+ # ===================================================================
13
+
14
+ source "$(conda info --base)/etc/profile.d/conda.sh"
15
+ conda activate $ENV_PATH
16
+
17
+ # ===================================================================
18
+
19
+ $PYTHON_EXECUTABLE $SCRIPT_LOC/finetune.py \
20
+ --base_path $HOME_LOC \
21
+ --device "cuda:6" \
22
+ --noise_removal \
23
+ --wdce_num_replicates 16 \
24
+ --buffer_size 20 \
25
+ --seq_length 200 \
26
+ --num_children 50 \
27
+ --total_num_steps 128 \
28
+ --num_iter 10 \
29
+ --resample_every_n_step 10 \
30
+ --num_epochs 1000 \
31
+ --exploration 0.1 \
32
+ --save_every_n_epochs 50 \
33
+ --reset_every_n_step 1 \
34
+ --alpha 0.1 \
35
+ --grad_clip > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
36
+
37
+ conda deactivate
tr2d2-pep/finetune_peptides.py ADDED
@@ -0,0 +1,193 @@
1
+ # direct reward backpropagation
2
+ import numpy as np
3
+ import torch
4
+ import wandb
5
+ import os
6
+ from finetune_utils import loss_wdce
7
+ from tqdm import tqdm
8
+ import pandas as pd
9
+ from plotting import plot_data_with_distribution_seaborn, plot_data
10
+
11
+ def finetune(args, cfg, policy_model, reward_model, mcts=None, pretrained=None, filename=None, prot_name=None, eps=1e-5):
12
+ """
13
+ Finetuning with WDCE loss
14
+ """
15
+ base_path = args.base_path
16
+ dt = (1 - eps) / args.total_num_steps
17
+
18
+ if args.no_mcts:
19
+ assert pretrained is not None, "pretrained model is required for no mcts"
20
+ else:
21
+ assert mcts is not None, "mcts is required for mcts"
22
+
23
+ # set model to train mode
24
+ policy_model.train()
25
+ torch.set_grad_enabled(True)
26
+ optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)
27
+
28
+ # record metrics
29
+ batch_losses = []
30
+ #batch_rewards = []
31
+
32
+ # initialize the final seqs and log_rnd of the trajectories that generated those seqs
33
+ x_saved, log_rnd_saved, final_rewards_saved = None, None, None
34
+
35
+ valid_fraction_log = []
36
+ affinity_log = []
37
+ sol_log = []
38
+ hemo_log = []
39
+ nf_log = []
40
+ permeability_log = []
41
+
42
+ ### End of Fine-Tuning Loop ###
43
+ pbar = tqdm(range(args.num_epochs))
44
+
45
+ for epoch in pbar:
46
+ # store metrics
47
+ rewards = []
48
+ losses = []
49
+
50
+ policy_model.train()
51
+
52
+ with torch.no_grad():
53
+ if x_saved is None or epoch % args.resample_every_n_step == 0:
54
+ # compute final sequences and trajectory log_rnd
55
+ if args.no_mcts:
56
+ x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(args, reward_model, pretrained)
57
+ else:
58
+ # decides whether to reset tree
59
+ if (epoch) % args.reset_every_n_step == 0:
60
+ x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=True)
61
+ else:
62
+ x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=False)
63
+
64
+ # save for next iteration
65
+ x_saved, log_rnd_saved, final_rewards_saved = x_final, log_rnd, final_rewards
66
+ else:
67
+ x_final, log_rnd, final_rewards = x_saved, log_rnd_saved, final_rewards_saved
68
+
69
+ # compute wdce loss
70
+ loss = loss_wdce(policy_model, log_rnd, x_final, num_replicates=args.wdce_num_replicates, centering=args.centering)
71
+
72
+ # gradient descent
73
+ loss.backward()
74
+
75
+ # optimizer
76
+ if args.grad_clip:
77
+ torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
78
+
79
+ optim.step()
80
+ optim.zero_grad()
81
+
82
+ pbar.set_postfix(loss=loss.item())
83
+
84
+ # sample a eval batch with updated policy to evaluate rewards
85
+ x_eval, affinity, sol, hemo, nf, permeability, valid_fraction = policy_model.sample_finetuned(args, reward_model, batch_size=50, dataframe=False)
86
+
87
+ # append to log
88
+ affinity_log.append(affinity)
89
+ sol_log.append(sol)
90
+ hemo_log.append(hemo)
91
+ nf_log.append(nf)
92
+ permeability_log.append(permeability)
93
+ valid_fraction_log.append(valid_fraction)
94
+
95
+ batch_losses.append(loss.cpu().detach().numpy())
96
+
97
+ losses.append(loss.cpu().detach().numpy())
98
+ losses = np.array(losses)
99
+
100
+ if args.no_mcts:
101
+ mean_reward_search = final_rewards.mean().item()
102
+ min_reward_search = final_rewards.min().item()
103
+ max_reward_search = final_rewards.max().item()
104
+ median_reward_search = final_rewards.median().item()
105
+ else:
106
+ mean_reward_search = np.mean(final_rewards)
107
+ min_reward_search = np.min(final_rewards)
108
+ max_reward_search = np.max(final_rewards)
109
+ median_reward_search = np.median(final_rewards)
110
+
111
+ print("epoch %d"%epoch, "affinity %f"%np.mean(affinity), "sol %f"%np.mean(sol), "hemo %f"%np.mean(hemo), "nf %f"%np.mean(nf), "permeability %f"%np.mean(permeability), "mean loss %f"%np.mean(losses))
112
+
113
+ wandb.log({"epoch": epoch, "affinity": np.mean(affinity), "sol": np.mean(sol), "hemo": np.mean(hemo), "nf": np.mean(nf), "permeability": np.mean(permeability),
114
+ "mean_loss": np.mean(losses),
115
+ "mean_reward_search": mean_reward_search, "min_reward_search": min_reward_search,
116
+ "max_reward_search": max_reward_search, "median_reward_search": median_reward_search})
117
+
118
+ if (epoch+1) % args.save_every_n_epochs == 0:
119
+ model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
120
+ torch.save(policy_model.state_dict(), model_path)
121
+ print(f"model saved at epoch {epoch}")
122
+
123
+ ### End of Fine-Tuning Loop ###
124
+
125
+ wandb.finish()
126
+
127
+ # save logs and plot
128
+ plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
129
+ os.makedirs(plot_path, exist_ok=True)
130
+ output_log_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/log_{filename}.csv'
131
+ save_logs_to_file(valid_fraction_log, affinity_log,
132
+ sol_log, hemo_log, nf_log,
133
+ permeability_log, output_log_path)
134
+
135
+ plot_data(valid_fraction_log,
136
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/valid_{filename}.png')
137
+
138
+ plot_data_with_distribution_seaborn(log1=affinity_log,
139
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/binding_{filename}.png',
140
+ label1=f"Average Binding Affinity to {prot_name}",
141
+ title=f"Average Binding Affinity to {prot_name} Over Iterations")
142
+
143
+ plot_data_with_distribution_seaborn(log1=sol_log,
144
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/sol_{filename}.png',
145
+ label1="Average Solubility Score",
146
+ title="Average Solubility Score Over Iterations")
147
+ plot_data_with_distribution_seaborn(log1=hemo_log,
148
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/hemo_{filename}.png',
149
+ label1="Average Hemolysis Score",
150
+ title="Average Hemolysis Score Over Iterations")
151
+ plot_data_with_distribution_seaborn(log1=nf_log,
152
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/nf_{filename}.png',
153
+ label1="Average Nonfouling Score",
154
+ title="Average Nonfouling Score Over Iterations")
155
+ plot_data_with_distribution_seaborn(log1=permeability_log,
156
+ save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/perm_{filename}.png',
157
+ label1="Average Permeability Score",
158
+ title="Average Permeability Score Over Iterations")
159
+
160
+ x_eval, affinity, sol, hemo, nf, permeability, valid_fraction, df = policy_model.sample_finetuned(args, reward_model, batch_size=200, dataframe=True)
161
+ df.to_csv(f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/{prot_name}_generation_results.csv', index=False)
162
+
163
+ return batch_losses
164
+
165
+ def save_logs_to_file(valid_fraction_log, affinity_log,
166
+ sol_log, hemo_log, nf_log,
167
+ permeability_log, output_path):
168
+ """
169
+ Saves the per-epoch logs (valid fraction, binding affinity, solubility, hemolysis, nonfouling, permeability) to a CSV file.
170
+
171
+ Parameters:
172
+ valid_fraction_log (list): Log of valid fractions over iterations.
173
+ affinity_log (list): Log of binding affinity over iterations.
174
+ permeability_log (list): Log of membrane permeability over iterations.
175
+ output_path (str): Path to save the log CSV file.
176
+ """
177
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
178
+
179
+ # Combine logs into a DataFrame
180
+ log_data = {
181
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
182
+ "Valid Fraction": valid_fraction_log,
183
+ "Binding Affinity": affinity_log,
184
+ "Solubility": sol_log,
185
+ "Hemolysis": hemo_log,
186
+ "Nonfouling": nf_log,
187
+ "Permeability": permeability_log
188
+ }
189
+
190
+ df = pd.DataFrame(log_data)
191
+
192
+ # Save to CSV
193
+ df.to_csv(output_path, index=False)
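A minimal usage sketch of `save_logs_to_file` above; one entry is appended per fine-tuning epoch (plain floats here for brevity, though in the training loop each entry may itself be a batch of scores), and the output path is a placeholder:

```python
# Writes one CSV row per epoch with columns Iteration, Valid Fraction, Binding Affinity, etc.
save_logs_to_file(
    valid_fraction_log=[0.82, 0.91],
    affinity_log=[0.41, 0.48],
    sol_log=[0.63, 0.66],
    hemo_log=[0.74, 0.79],
    nf_log=[0.52, 0.58],
    permeability_log=[0.35, 0.42],
    output_path="results/demo/log_demo.csv",  # placeholder path
)
```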
tr2d2-pep/finetune_utils.py ADDED
@@ -0,0 +1,138 @@
1
+ import random
2
+ import torch
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ from utils.utils import sample_categorical_logits
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import torch.distributed as dist
8
+ import torch.nn.functional as F
9
+
10
+ def to_one_hot(x_idx, num_classes=4):
11
+ oh = F.one_hot(x_idx.long(), num_classes=num_classes)
12
+ return oh.float()
13
+
14
+ def rnd(model, reward_model, batch_size, scale=1, device='cuda:0'):
15
+ r"""
16
+ Run random order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory
17
+ reward_model: r(X)
18
+
19
+ return:
20
+ - x: the final samples, [B, D]
21
+ - log_rnd: the log RND along this trajectory, [B]
22
+ """
23
+ if hasattr(model, 'module'):
24
+ model = model.module
25
+
26
+ x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
27
+ batch_arange = torch.arange(batch_size, device=device)
28
+ jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
29
+ # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
30
+ # jump_times: Unif[0,1] in increasing order
31
+ # jump_pos: random permutation of range(D)
32
+ log_rnd = torch.zeros(batch_size, device=device) # [B]
33
+ for d in range(model.length-1, -1, -1):
34
+ # jump at time jump_times[:, d] at position jump_pos[:, d]
35
+ logits = model(x)[:, :, :-1] # [B, D, N-1]
36
+ update = sample_categorical_logits(
37
+ logits[batch_arange, jump_pos[:, d]]) # [B]
38
+ if torch.is_grad_enabled(): # avoid issues with in-place operations
39
+ x = x.clone()
40
+ x[batch_arange, jump_pos[:, d]] = update
41
+ log_rnd += -np.log(model.vocab_size-1) - logits[batch_arange, jump_pos[:, d], update]
42
+ log_rnd += scale * reward_model(x) # [B]
43
+ return x, log_rnd
44
+
45
+
46
+ @torch.no_grad()
47
+ def sampling(model, batch_size, rounds=1, device='cuda:0'):
48
+ """Any order autoregressive sampling"""
49
+ if hasattr(model, 'module'):
50
+ model = model.module
51
+ batch_arange = torch.arange(batch_size, device=device)
52
+ all_samples = []
53
+ for _ in tqdm(range(rounds), leave=False):
54
+ x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
55
+ jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
56
+ # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
57
+ # jump_times: Unif[0,1] in increasing order
58
+ # jump_pos: random permutation of range(D)
59
+ for d in tqdm(range(model.length-1, -1, -1), leave=False):
60
+ # jump at time jump_times[:, d] at position jump_pos[:, d]
61
+ logits = model.logits(x)[:, :, :-1] # [B, D, N-1], not log-softmaxed but fine
62
+ update = sample_categorical_logits(
63
+ logits[batch_arange, jump_pos[:, d]]) # [B]
64
+ x[batch_arange, jump_pos[:, d]] = update
65
+ all_samples.append(x)
66
+ return torch.cat(all_samples) # (rounds * B, L)
67
+
68
+
69
+ def loss_ce(log_rnd):
70
+ """Cross entropy loss KL(P^*||P^u)"""
71
+ weights = log_rnd.detach().softmax(dim=-1)
72
+ return (log_rnd * weights).sum()
73
+
74
+
75
+ def loss_lv(log_rnd):
76
+ r"""Log variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}"""
77
+ return log_rnd.var()
78
+
79
+
80
+ def loss_re_rf(log_rnd, const=0):
81
+ r"""Relative entropy loss KL(P^u||P^*) with REINFORCE trick"""
82
+ return (-log_rnd * (-log_rnd.detach() + const)).mean()
83
+
84
+
85
+ def loss_wdce(policy_model, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
86
+ r"""
87
+ Weighted denoising cross entropy loss
88
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
89
+
90
+ log_rnd: [B]; x: [B, L] (no mask)
91
+ num_replicates: R, number of replicates of each row in x
92
+ weight_func: w(lambda) for each sample, 1/lambda by default
93
+ """
94
+ mask_index = policy_model.mask_index
95
+ if hasattr(policy_model, 'module'):
96
+ policy_model = policy_model.module
97
+
98
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
99
+
100
+ batch_weights = log_rnd.detach_().softmax(dim=-1) # [B]; expanded to [B*R] below
101
+ if centering:
102
+ batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
103
+
104
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
105
+
106
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
107
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
108
+
109
+ masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None] # [B*R, D]
110
+ perturbed_batch = torch.where(masked_index, mask_index, batch)
111
+
112
+ # add time conditioning
113
+ t = lamda
114
+ sigma_t = -torch.log1p(-(1 - eps) * t)
115
+ attn_mask = torch.ones_like(perturbed_batch).to(policy_model.device)
116
+
117
+ # compute logits
118
+ logits = policy_model(perturbed_batch, attn_mask=attn_mask, sigma=sigma_t)
119
+ losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype) # [B*R, D]
120
+ losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
121
+ index=batch[masked_index][..., None]).squeeze(-1)
122
+ return - (losses.sum(dim=-1) * lamda_weights * batch_weights).mean()
123
+
124
+
125
+ def loss_dce(model, x, weight_func=lambda l: 1/l):
126
+ r"""
127
+ Denoising cross entropy loss, x [B, D] are ground truth samples
128
+ weight_func: w(lambda) for each sample, 1/lambda by default
129
+ """
130
+ lamda = torch.rand(x.shape[0], device=x.device) # [B]
131
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B]
132
+ masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None] # [B, D]
133
+ perturbed_batch = torch.where(masked_index, model.vocab_size-1, x)
134
+ logits = model(perturbed_batch)
135
+ losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype) # [B, D]
136
+ losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
137
+ index=x[masked_index][..., None]).squeeze(-1)
138
+ return - (losses.sum(dim=-1) * lamda_weights).mean()
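To make the weighting in `loss_wdce` concrete, here is a self-contained toy sketch of the same mask-and-gather pattern, with a random tensor standing in for the policy network's log-probabilities; all shapes and values are illustrative only:

```python
import torch

B, R, L, V = 2, 3, 5, 7                                    # batch, replicates, length, vocab (mask = V - 1)
mask_index = V - 1

x = torch.randint(0, V - 1, (B, L))                        # toy "final" unmasked samples
log_rnd = torch.randn(B)                                   # toy log dP*/dP^u, one value per trajectory

batch = x.repeat_interleave(R, dim=0)                      # [B*R, L]
weights = log_rnd.softmax(dim=-1).repeat_interleave(R)     # self-normalized importance weights, [B*R]

lamda = torch.rand(B * R)                                  # masking level per replicate
masked = torch.rand(B * R, L) < lamda[:, None]             # positions to re-mask
perturbed = torch.where(masked, torch.tensor(mask_index), batch)

# random log-probs stand in for policy_model(perturbed, ...) in this sketch
logits = torch.log_softmax(torch.randn(B * R, L, V), dim=-1)

per_token = torch.zeros(B * R, L)
per_token[masked] = torch.gather(logits[masked], -1, batch[masked][:, None]).squeeze(-1)
loss = -(per_token.sum(dim=-1) * (1.0 / lamda).clamp(max=1e5) * weights).mean()
print(loss)
```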
tr2d2-pep/generate_mcts.py ADDED
@@ -0,0 +1,192 @@
1
+ #!/usr/bin/env python
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import random
6
+ import sys
7
+ from diffusion import Diffusion
8
+ import hydra
9
+ from tqdm import tqdm
10
+ import matplotlib.pyplot as plt
11
+ import os
12
+ import seaborn as sns
13
+ import pandas as pd
14
+ import numpy as np
15
+ import argparse
16
+ # direct reward backpropagation
17
+ from diffusion import Diffusion
18
+ from hydra import initialize, compose
19
+ from hydra.core.global_hydra import GlobalHydra
20
+ import numpy as np
21
+ import torch
22
+ import argparse
23
+ import os
24
+ import datetime
25
+ from utils.utils import str2bool, set_seed
26
+
27
+ # for peptides
28
+ from utils.app import PeptideAnalyzer
29
+ from peptide_mcts import MCTS
30
+
31
+ @torch.no_grad()
32
+ def generate_mcts(args, cfg, policy_model, pretrained, prot=None, prot_name=None, filename=None):
33
+
34
+ score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
35
+
36
+ mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
37
+
38
+ final_x, log_rnd, final_rewards, score_vectors, sequences = mcts.forward()
39
+
40
+ return final_x, log_rnd, final_rewards, score_vectors, sequences
41
+
42
+ def save_logs_to_file(reward_log, logrnd_log,
43
+ valid_fraction_log, affinity1_log,
44
+ sol_log, hemo_log, nf_log,
45
+ permeability_log, output_path):
46
+ """
47
+ Saves the per-iteration logs (reward, log RND, valid fraction, binding affinity, solubility, hemolysis, nonfouling, permeability) to a CSV file.
48
+
49
+ Parameters:
50
+ valid_fraction_log (list): Log of valid fractions over iterations.
51
+ affinity1_log (list): Log of binding affinity over iterations.
52
+ permeability_log (list): Log of membrane permeability over iterations.
53
+ output_path (str): Path to save the log CSV file.
54
+ """
55
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
56
+
57
+ # Combine logs into a DataFrame
58
+ log_data = {
59
+ "Iteration": list(range(1, len(valid_fraction_log) + 1)),
60
+ "Reward": reward_log,
61
+ "Log RND": logrnd_log,
62
+ "Valid Fraction": valid_fraction_log,
63
+ "Binding Affinity": affinity1_log,
64
+ "Solubility": sol_log,
65
+ "Hemolysis": hemo_log,
66
+ "Nonfouling": nf_log,
67
+ "Permeability": permeability_log
68
+ }
69
+
70
+ df = pd.DataFrame(log_data)
71
+
72
+ # Save to CSV
73
+ df.to_csv(output_path, index=False)
74
+
75
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
76
+ argparser.add_argument('--base_path', type=str, default='')
77
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
78
+ argparser.add_argument('--num_epochs', type=int, default=1000)
79
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
80
+ argparser.add_argument('--truncate_steps', type=int, default=50)
81
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
82
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
83
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
84
+ argparser.add_argument('--batch_size', type=int, default=32)
85
+ argparser.add_argument('--name', type=str, default='debug')
86
+ argparser.add_argument('--total_num_steps', type=int, default=128)
87
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
88
+ argparser.add_argument('--save_every_n_epochs', type=int, default=50)
89
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
90
+ argparser.add_argument("--seed", type=int, default=0)
91
+ # new
92
+ argparser.add_argument('--run_name', type=str, default='drakes')
93
+ argparser.add_argument("--device", default="cuda", type=str)
94
+
95
+ # mcts
96
+ argparser.add_argument('--num_sequences', type=int, default=100)
97
+ argparser.add_argument('--num_children', type=int, default=20)
98
+ argparser.add_argument('--num_iter', type=int, default=100) # iterations of mcts
99
+ argparser.add_argument('--seq_length', type=int, default=200)
100
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
101
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
102
+ argparser.add_argument('--buffer_size', type=int, default=100)
103
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
104
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
105
+ argparser.add_argument('--exploration', type=float, default=0.1)
106
+ argparser.add_argument('--reset_every_n_step', type=int, default=100)
107
+ argparser.add_argument('--alpha', type=float, default=0.01)
108
+ argparser.add_argument('--scalarization', type=str, default='sum')
109
+ argparser.add_argument('--no_mcts', action='store_true', default=False)
110
+ argparser.add_argument("--centering", action='store_true', default=False)
111
+ argparser.add_argument('--num_obj', type=int, default=5)
112
+
113
+ argparser.add_argument('--prot_seq', type=str, default=None)
114
+ argparser.add_argument('--prot_name', type=str, default=None)
115
+
116
+ args = argparser.parse_args()
117
+ print(args)
118
+
119
+ # pretrained model path
120
+ ckpt_path = f'{args.base_path}/TR2-D2/tr2d2-pep/pretrained/peptune-pretrained.ckpt'
121
+
122
+ # reinitialize Hydra
123
+ GlobalHydra.instance().clear()
124
+
125
+ # Initialize Hydra and compose the configuration
126
+ initialize(config_path="configs", job_name="load_model")
127
+ cfg = compose(config_name="peptune_config.yaml")
128
+ curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
129
+
130
+ set_seed(args.seed, use_cuda=True)
131
+
132
+ # proteins
133
+ amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
134
+ tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
135
+ gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
136
+ glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
137
+ glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
138
+ ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
139
+ cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
140
+ ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
141
+ skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
142
+
143
+ if args.prot_seq is not None:
144
+ prot = args.prot_seq
145
+ prot_name = args.prot_name
146
+ filename = args.prot_name
147
+ else:
148
+ prot = tfr
149
+ prot_name = "tfr"
150
+ filename = "tfr"
151
+
152
+ # Initialize the model
153
+ new_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)
154
+ old_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)
155
+
156
+ with torch.no_grad():
157
+ final_x, log_rnd, final_rewards, score_vectors, sequences = generate_mcts(args, cfg, new_model, old_model,
158
+ prot=prot, prot_name=prot_name)
159
+
160
+ final_x = final_x.detach().to('cpu') # [B, L] integer tokens
161
+ log_rnd = log_rnd.detach().to('cpu').float().view(-1) # [B]
162
+ #final_rewards = final_rewards.detach().to('cpu').float().view(-1) # [B]
163
+
164
+ print("loaded models...")
165
+ analyzer = PeptideAnalyzer()
166
+
167
+ generation_results = []
168
+
169
+ for i in range(final_x.shape[0]):
170
+ sequence = sequences[i]
171
+ log_rnd_single = log_rnd[i]
172
+ final_reward = final_rewards[i]
173
+
174
+ aa_seq, seq_length = analyzer.analyze_structure(sequence)
175
+
176
+ scores = score_vectors[i]
177
+
178
+ binding1 = scores[0]
179
+ solubility = scores[1]
180
+ hemo = scores[2]
181
+ nonfouling = scores[3]
182
+ permeability = scores[4]
183
+
184
+ generation_results.append([sequence, aa_seq, final_reward, log_rnd_single, binding1, solubility, hemo, nonfouling, permeability])
185
+ print(f"length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
186
+
187
+ sys.stdout.flush()
188
+
189
+ df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Peptide Sequence', 'Final Reward', 'Log RND', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
190
+
191
+
192
+ df.to_csv(f'{args.base_path}/TR2-D2/tr2d2-pep/plots/{prot_name}-peptune-baseline/generation_results.csv', index=False)
tr2d2-pep/metrics.py ADDED
@@ -0,0 +1,71 @@
1
+ import pandas as pd
2
+ from math import sqrt
3
+
4
+ def summarize_metrics(skip, csv_path: str, save_path: str | None = None) -> pd.DataFrame:
5
+ """
6
+ Compute mean and standard deviation for all columns after the first `skip`
7
+ columns (assumed non-numeric identifiers like 'Peptide Sequence').
8
+
9
+ Returns a DataFrame with rows = column names and columns = ['mean','std','count'].
10
+ Uses sample std (ddof=1). Non-numeric cells are coerced to NaN.
11
+ """
12
+ df = pd.read_csv(csv_path)
13
+ vals = df.iloc[:, skip:].apply(pd.to_numeric, errors='coerce') # columns skip..end
14
+ stats = vals.agg(['mean', 'std', 'count']).T # shape: (num_metrics, 3)
15
+ if save_path:
16
+ stats.to_csv(save_path, index=True)
17
+ return stats
18
+
19
+ def summarize_list(xs, ddof = 1):
20
+ # Clean & coerce to float
21
+ vals = []
22
+ for x in xs:
23
+ if x is None or x == "":
24
+ continue
25
+ try:
26
+ vals.append(float(x))
27
+ except (TypeError, ValueError):
28
+ continue
29
+
30
+ n = len(vals)
31
+ if n == 0:
32
+ raise ValueError("No numeric values found.")
33
+ if n <= ddof:
34
+ raise ValueError(f"Need at least {ddof + 1} numeric values; got {n}.")
35
+
36
+ # Welford’s algorithm (one pass, stable)
37
+ mean = 0.0
38
+ M2 = 0.0
39
+ count = 0
40
+ for v in vals:
41
+ count += 1
42
+ delta = v - mean
43
+ mean += delta / count
44
+ M2 += delta * (v - mean)
45
+
46
+ var = M2 / (count - ddof)
47
+ std = sqrt(var)
48
+
49
+ result = {"mean": mean, "std": std, "count": count}
50
+
51
+ return result
52
+
53
+ def csv_column_to_list(path: str, column: str, *, dropna: bool = True):
54
+ df = pd.read_csv(path)
55
+ if column not in df.columns:
56
+ raise KeyError(f"Column '{column}' not found. Available: {list(df.columns)}")
57
+ s = df[column]
58
+ if dropna:
59
+ s = s.dropna()
60
+ return s.tolist()
61
+
62
+ def main():
63
+ path = ""
64
+ prot_name = ""
65
+ stats = summarize_metrics(skip=1, csv_path=f"{path}/{prot_name}_generation_results.csv",
66
+ save_path=f"{path}/results_summary.csv")
67
+
68
+ print(stats)
69
+
70
+ if __name__ == '__main__':
71
+ main()
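`summarize_list` above is a one-pass Welford implementation of the sample mean and standard deviation. A minimal sanity check on made-up values against a naive two-pass computation (ddof=1):

```python
# Toy check that the one-pass Welford summary matches a two-pass mean/std (ddof=1).
from math import sqrt
from metrics import summarize_list

xs = [0.82, 0.75, None, "", "0.66", 0.91]            # mixed inputs, as summarize_list accepts
vals = [float(x) for x in xs if x not in (None, "")]
mean = sum(vals) / len(vals)
std = sqrt(sum((v - mean) ** 2 for v in vals) / (len(vals) - 1))

print(summarize_list(xs))                              # {'mean': ..., 'std': ..., 'count': 4}
print({"mean": mean, "std": std, "count": len(vals)})  # should agree
```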
tr2d2-pep/noise_schedule.py ADDED
@@ -0,0 +1,150 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ torch._C._jit_set_profiling_mode(False)
7
+ torch._C._jit_set_profiling_executor(False)
8
+ torch._C._jit_override_can_fuse_on_cpu(True)
9
+ torch._C._jit_override_can_fuse_on_gpu(True)
10
+
11
+ def get_noise(config, dtype=torch.float32):
12
+ if config.noise.type == 'geometric':
13
+ return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
14
+ elif config.noise.type == 'loglinear':
15
+ return LogLinearNoise()
16
+ elif config.noise.type == 'cosine':
17
+ return CosineNoise()
18
+ elif config.noise.type == 'cosinesqr':
19
+ return CosineSqrNoise()
20
+ elif config.noise.type == 'linear':
21
+ return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
22
+ else:
23
+ raise ValueError(f'{config.noise.type} is not a valid noise')
24
+
25
+
26
+ def binary_discretization(z):
27
+ z_hard = torch.sign(z)
28
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
29
+ return z_soft + (z_hard - z_soft).detach()
30
+
31
+
32
+ class Noise(abc.ABC, nn.Module):
33
+ """
34
+ Baseline forward method to get the total + rate of noise at a timestep
35
+ """
36
+ def forward(self, t):
37
+ # Assume time goes from 0 to 1
38
+ return self.total_noise(t), self.rate_noise(t)
39
+
40
+
41
+ class CosineNoise(Noise):
42
+ def __init__(self, eps=1e-3):
43
+ super().__init__()
44
+ self.eps = eps
45
+
46
+ def rate_noise(self, t):
47
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
48
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
49
+ scale = torch.pi / 2
50
+ return scale * sin / (cos + self.eps)
51
+
52
+ def total_noise(self, t):
53
+ cos = torch.cos(t * torch.pi / 2)
54
+ return - torch.log(self.eps + (1 - self.eps) * cos)
55
+
56
+
57
+ class CosineSqrNoise(Noise):
58
+ def __init__(self, eps=1e-3):
59
+ super().__init__()
60
+ self.eps = eps
61
+
62
+ def rate_noise(self, t):
63
+ cos = (1 - self.eps) * (
64
+ torch.cos(t * torch.pi / 2) ** 2)
65
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
66
+ scale = torch.pi / 2
67
+ return scale * sin / (cos + self.eps)
68
+
69
+ def total_noise(self, t):
70
+ cos = torch.cos(t * torch.pi / 2) ** 2
71
+ return - torch.log(self.eps + (1 - self.eps) * cos)
72
+
73
+
74
+ class Linear(Noise):
75
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
76
+ super().__init__()
77
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
78
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
79
+
80
+ def rate_noise(self, t):
81
+ return self.sigma_max - self.sigma_min
82
+
83
+ def total_noise(self, t):
84
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
85
+
86
+ def importance_sampling_transformation(self, t):
87
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
88
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
89
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
90
+ return (sigma_t - self.sigma_min) / (
91
+ self.sigma_max - self.sigma_min)
92
+
93
+
94
+ class GeometricNoise(Noise):
95
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
96
+ super().__init__()
97
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
98
+
99
+ def rate_noise(self, t):
100
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
101
+ self.sigmas[1].log() - self.sigmas[0].log())
102
+
103
+ def total_noise(self, t):
104
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
105
+
106
+
107
+ class LogLinearNoise(Noise):
108
+ """Log Linear noise schedule.
109
+
110
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
111
+ ~1 when t varies from 0 to 1. Total noise is
112
+ -log(1 - (1 - eps) * t), so the sigma will be
113
+ (1 - eps) * t.
114
+ """
115
+ def __init__(self, eps=1e-3):
116
+ super().__init__()
117
+ self.eps = eps
118
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
119
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
120
+
121
+ def rate_noise(self, t):
122
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
123
+
124
+ def total_noise(self, t):
125
+ return -torch.log1p(-(1 - self.eps) * t)
126
+
127
+ def importance_sampling_transformation(self, t):
128
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
129
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
130
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
131
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
132
+ return t
133
+
134
+ class LogPolyNoise(Noise):
135
+ """
136
+ Log Polynomial noise schedule for slower masking of peptide bond tokens
137
+ """
138
+ def __init__(self, eps=1e-3):
139
+ super().__init__()
140
+ self.eps = eps
141
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
142
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
143
+
144
+ def rate_noise(self, t):
145
+ # derivative of -log(1 - (1-eps) * t^3) with respect to t
146
+ return (3 * (1 - self.eps) * (t**2)) / (1 - (1 - self.eps) * (t**3))
147
+
148
+ def total_noise(self, t):
149
+ # -log(1 - (1-eps) * t^3)
150
+ return -torch.log1p(-(1 - self.eps) * (t**3))
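For the log-linear schedule above, `rate_noise` is the time derivative of `total_noise`. A small self-contained check of that relationship, with the two formulas copied from `LogLinearNoise`:

```python
# Finite-difference check that rate_noise == d/dt total_noise for the
# log-linear schedule (formulas copied from LogLinearNoise above).
import torch

eps = 1e-3
total = lambda t: -torch.log1p(-(1 - eps) * t)      # sigma(t)
rate = lambda t: (1 - eps) / (1 - (1 - eps) * t)    # d sigma / dt

t, h = 0.5, 1e-4
numeric = (total(torch.tensor(t + h)) - total(torch.tensor(t - h))) / (2 * h)
print(float(rate(torch.tensor(t))), float(numeric))  # the two values agree closely
```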
tr2d2-pep/peptide_mcts.py ADDED
@@ -0,0 +1,648 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
5
+ import random as rd
6
+ from utils.app import PeptideAnalyzer
7
+ from utils.timer import StepTimer
8
+ from scoring.scoring_functions import ScoringFunctions
9
+
10
+ import noise_schedule
11
+
12
+ ### for peptide multi-objective ###
13
+ def dominates(a, b):
14
+ a = np.asarray(a); b = np.asarray(b)
15
+ return np.all(a >= b) and np.any(a > b)
16
+
17
+ def dominated_by(a, b):
18
+ return dominates(b, a)
19
+
20
+
21
+ def updateParetoFront(paretoFront, node, scoreVector, totalSize=None, eps=1e-12):
22
+ """
23
+ Maintain a non-dominated set (Pareto front) of (node -> scoreVector).
24
+
25
+ - Accept 'node' iff it is NOT dominated by any node in the set.
26
+ - Remove any nodes that ARE dominated by 'node'.
27
+ - Skip insertion if an equal point already exists (within eps).
28
+ - If totalSize is given and the archive exceeds it, drop the item
29
+ with the smallest sum(scoreVector) as a simple tie-breaker.
30
+
31
+ Args:
32
+ paretoFront (dict): {node: scoreVector}
33
+ node: candidate node (used as dict key)
34
+ scoreVector (array-like): candidate scores (to be maximized)
35
+ totalSize (int|None): optional max size for the archive
36
+ eps (float): tolerance for equality/inequality checks
37
+
38
+ Returns:
39
+ dict: updated paretoFront
40
+ """
41
+ s = np.asarray(scoreVector, dtype=float)
42
+
43
+ def dominates(a, b):
44
+ # a >= b in all coords and > in at least one (with tolerance)
45
+ return np.all(a >= b - eps) and np.any(a > b + eps)
46
+
47
+ def equal(a, b):
48
+ return np.all(np.abs(a - b) <= eps)
49
+
50
+ # reject if candidate is dominated by any node already in the set
51
+ for v in paretoFront.values():
52
+ v = np.asarray(v, dtype=float)
53
+ if dominates(v, s):
54
+ return paretoFront # no change
55
+
56
+ # remove any nodes dominated by candidate node
57
+ survivors = {}
58
+ #has_equal = False
59
+ for k, v in paretoFront.items():
60
+ v_arr = np.asarray(v, dtype=float)
61
+ if dominates(s, v_arr):
62
+ continue # drop dominated incumbent
63
+ """if equal(s, v_arr):
64
+ has_equal = True # skip duplicate insertion later"""
65
+ survivors[k] = v_arr
66
+
67
+ # if an equal point exists, keep survivors as-is (no duplicate)
68
+ """if has_equal:
69
+ return survivors"""
70
+
71
+ # insert node
72
+ survivors[node] = s
73
+
74
+ # delete nodes if larger than total size
75
+ if totalSize is not None and totalSize > 0 and len(survivors) > totalSize:
76
+ # remove the item with the smallest sum(scoreVector)
77
+ keys = list(survivors.keys())
78
+ sums = np.array([np.sum(np.asarray(survivors[k], dtype=float)) for k in keys])
79
+ drop_idx = int(np.argmin(sums))
80
+ del survivors[keys[drop_idx]]
81
+
82
+ return survivors
83
+
84
+ ### BEGINNING OF NODE CLASS ###
85
+
86
+ class Node:
87
+ """
88
+ Node class: partially unmasked sequence
89
+ - parentNode: Node object at previous time step
90
+ - childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
91
+ - totalReward: vector of cumulative rewards for all K objectives
92
+ - visits: number of times the node has been visited by an iteration
93
+ - path: array of partially unmasked SMILES strings leading to the node from the completely masked root node
94
+ - timestep: the time step where the sequence was sampled
95
+ """
96
+ def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
97
+ self.args = args
98
+ self.parentNode = parentNode
99
+ # fixed child node list creation
100
+ self.childNodes = [] if childNodes is None else childNodes
101
+
102
+ self.log_rnd = log_rnd # stores the log_rnd up to that step
103
+
104
+ #self.log_p0 = 0 # stores the log probability of the unmasking step from the previous iteration
105
+ self.log_policy_step = log_policy_step # stores the log probability of the unmasking step under the current policy
106
+ self.log_pretrained_step = log_pretrained_step
107
+
108
+ # initialize total rewards to the reward of the roll out unmasked sequence
109
+ if totalReward is not None:
110
+ self.totalReward = totalReward # potential reward of the node based on generated children
111
+ else:
112
+ self.totalReward = np.zeros(self.args.num_obj)
113
+
114
+ # set initial visits to 1
115
+ self.visits = 1
116
+
117
+ # set timestep (value between 0 and num_steps)
118
+ self.timestep = timestep
119
+
120
+ # dict with 'seqs' as token array and 'attention_mask'
121
+ self.tokens = tokens
122
+
123
+ def selectNode(self):
124
+ """
125
+ Selects a node to move to among the children nodes based on select score
126
+ """
127
+ # extract the status of the current node
128
+ nodeStatus = self.getExpandStatus()
129
+
130
+ # if the node is a legal non-leaf node
131
+ if (nodeStatus == 3):
132
+ # initialize array that will store select score vectors of each child node
133
+
134
+ paretoFront = {}
135
+
136
+ for childNode in self.childNodes:
137
+ childStatus = childNode.getExpandStatus()
138
+ # only append child if it is legal leaf node (expandable) or legal non-leaf node
139
+ if childStatus == 2 or childStatus == 3:
140
+ selectScore = childNode.calcSelectScore()
141
+ paretoFront = updateParetoFront(paretoFront, childNode, selectScore)
142
+
143
+ selected = rd.choice(list(paretoFront.keys()))
144
+
145
+ # return selected child node and status
146
+ return selected, selected.getExpandStatus()
147
+
148
+ # if node is not valid non-leaf node
149
+ return self, nodeStatus
150
+
151
+ def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
152
+ """
153
+ Adds a child node:
154
+ log_rnd: log_rnd of the path up to the added child node
155
+ log_policy_step: scalar value of the log-prob of sampling the step under the policy
156
+ log_pretrained_step: scalar value of the log-prob of sampling the step under the pretrained model
157
+ """
158
+ child = Node(args=self.args,
159
+ tokens=tokens,
160
+ log_rnd = log_rnd,
161
+ log_policy_step=log_policy_step,
162
+ log_pretrained_step=log_pretrained_step,
163
+ parentNode=self,
164
+ childNodes=[],
165
+ totalReward=totalReward,
166
+ timestep=self.timestep+1)
167
+
168
+ self.childNodes.append(child)
169
+ return child
170
+
171
+ def update_logrnd(self, log_policy_step, log_rnd):
172
+ self.log_policy_step = log_policy_step
173
+ self.log_rnd = log_rnd
174
+
175
+ def updateNode(self, rewards):
176
+ """
177
+ Updates the cumulative rewards vector with the reward vector at a descendant leaf node.
178
+ Increments the number of visits to the node.
179
+ """
180
+ self.visits += 1
181
+
182
+ self.totalReward += rewards # reward vector of shape (num_obj,)
183
+
184
+ def calcSelectScore(self):
185
+ """
186
+ Calculates the select score for the node from the cumulative rewards vector and number of visits.
187
+ - scaling: weight of the second term in the select score (set to 0.1 below),
189
+ which scales the policy log-probability by sqrt(parent visits) / visits
189
+ """
190
+ scaling = 0.1 # scaling of the second term in the select score
191
+
192
+ # K-dimensional vector of normalized rewards for each objective
193
+ normRewards = self.totalReward / self.visits
194
+
195
+ # scales the cumulative reward by the sampling probability
196
+
197
+ return normRewards + (scaling * self.log_policy_step.detach().cpu().item() * np.sqrt(self.parentNode.visits) / self.visits)
198
+
199
+ def getExpandStatus(self):
200
+ """
201
+ Returns an integer indicating whether the node is a:
202
+ 1. terminal node (sequence is fully unmasked)
203
+ 2. legal leaf node (partially unmasked sequence that can be expanded)
204
+ 3. legal non-leaf node (already expanded sequence with M child nodes)
205
+ """
206
+ if self.timestep == self.args.total_num_steps:
207
+ return 1
208
+ elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
209
+ return 2
210
+ return 3
211
+
212
+ ### END OF NODE CLASS ###
213
+
214
+ ### BEGINNING OF MCTS CLASS ###
215
+
216
+ class MCTS:
217
+ def __init__(self, args, config, policy_model, pretrained, score_func_names=[], prot_seqs=None, rootNode=None):
218
+ self.timer = StepTimer(policy_model.device)
219
+
220
+ self.device = policy_model.device
221
+
222
+ self.args = args
223
+ self.config = config
224
+ self.noise = noise_schedule.get_noise(config)
225
+ self.time_conditioning = args.time_conditioning
226
+
227
+ self.num_obj = len(score_func_names)
228
+
229
+ self.mask_index = policy_model.mask_index
230
+ masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
231
+ masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
232
+ if rootNode is None:
233
+ self.rootNode = Node(self.args, tokens = masked_tokens,
234
+ log_rnd=torch.zeros((), device=self.device),
235
+ log_policy_step=torch.zeros((), device=self.device),
236
+ log_pretrained_step=torch.zeros((), device=self.device),
237
+ totalReward=np.zeros(self.num_obj), timestep=0)
238
+ else:
239
+ self.rootNode = rootNode # stores the root node of the tree
240
+
241
+ # dictionary:
242
+ # "seq": final unmasked sequence
243
+ # "traj": list of (N_steps, L)
244
+ # "reward": reward of the trajectory
245
+ self.buffer = [] # List[Dict[str, Any]]
246
+
247
+ self.buffer_size = args.buffer_size
248
+
249
+ self.num_steps = args.total_num_steps
250
+ #self.num_sequences = args.num_sequences
251
+
252
+ # pretrained model
253
+ self.pretrained = pretrained
254
+
255
+ # the policy model that we want to finetune
256
+ self.policy_model = policy_model
257
+ #self.tokenizer = policy_model.tokenizer
258
+ self.device = policy_model.device
259
+
260
+ self.sequence_length = args.seq_length
261
+
262
+ self.num_iter = args.num_iter
263
+
264
+ self.num_children = args.num_children
265
+
266
+ # score functions
267
+
268
+ self.rewardFunc = ScoringFunctions(score_func_names, prot_seqs, device=args.device)
269
+
270
+ self.iter_num = 0
271
+
272
+ self.reward_log = [] # stores scalarized total rewards
273
+ self.logrnd_log = []
274
+ # stores each objective
275
+ self.valid_fraction_log = []
276
+ self.affinity1_log = []
277
+ self.affinity2_log = []
278
+ self.permeability_log = []
279
+ self.sol_log = []
280
+ self.hemo_log = []
281
+ self.nf_log = []
282
+
283
+ self.policy_model.eval()
284
+ self.pretrained.eval()
285
+
286
+ # for peptides
287
+ self.analyzer = PeptideAnalyzer()
288
+ self.tokenizer = policy_model.tokenizer
289
+
290
+
291
+ def reset(self, resetTree):
292
+ self.iter_num = 0
293
+ self.buffer = []
294
+ self.reward_log = []
295
+ self.logrnd_log = []
296
+
297
+ # reset logs for each objective
298
+ self.valid_fraction_log = []
299
+ self.affinity1_log = []
300
+ self.affinity2_log = []
301
+ self.permeability_log = []
302
+ self.sol_log = []
303
+ self.hemo_log = []
304
+ self.nf_log = []
305
+
306
+ # add option to continue with the same tree
307
+ if resetTree:
308
+ masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
309
+ masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
310
+ self.rootNode = Node(self.args, tokens = masked_tokens,
311
+ log_rnd=torch.zeros((), device=self.device),
312
+ log_policy_step=torch.zeros((), device=self.device),
313
+ log_pretrained_step=torch.zeros((), device=self.device),
314
+ totalReward=np.zeros(self.num_obj), timestep=0)
315
+
316
+ def forward(self, resetTree=False):
317
+
318
+ self.reset(resetTree)
319
+
320
+ while (self.iter_num < self.num_iter):
321
+ self.iter_num += 1
322
+
323
+ # traverse the tree from the root node until a leaf node
324
+ with self.timer.section("select"):
325
+ leafNode, _ = self.select(self.rootNode)
326
+
327
+ # expand leaf node into num_children partially unmasked sequences at the next timestep
328
+ with self.timer.section("expand"):
329
+ self.expand(leafNode)
330
+
331
+ final_x, log_rnd, final_rewards, score_vectors, sequences = self.consolidateBuffer()
332
+ # return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )
333
+
334
+ rows = self.timer.summary()
335
+ print("\n=== Timing summary (by total time) ===")
336
+ for name, cnt, total, mean, p50, p95 in rows:
337
+ print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
338
+ f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")
339
+
340
+ return final_x, log_rnd, final_rewards, score_vectors, sequences
341
+
342
+ # new updateBuffer
343
+ def _debug_buffer_decision(self, sv, reason, extra=None):
344
+ if extra is None: extra = {}
345
+ print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
346
+ f"buf_len={len(self.buffer)} extra={extra}")
347
+
348
+ def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
349
+ B = x_final.shape[0]
350
+ traj_log_rnds, scalar_rewards = [], []
351
+
352
+ for i in range(B):
353
+ sv = np.asarray(score_vectors[i], dtype=float)
354
+
355
+ # determine how to scalarize the multi-objective rewards
356
+ if self.args.scalarization == "normalized":
357
+ pass
358
+ elif self.args.scalarization == "weighted":
359
+ pass
360
+ else:
361
+ scalar_reward = float(np.sum(sv))
362
+
363
+ traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) # scale down by alpha
364
+
365
+ item = {
366
+ "x_final": x_final[i].clone(), # clone?
367
+ "log_rnd": traj_log_rnd.clone(),
368
+ "final_reward": scalar_reward,
369
+ "score_vector": sv.copy(),
370
+ "seq": childSequences[i],
371
+ }
372
+
373
+ # Drop if dominated by any existing
374
+ if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
375
+ # for debugging
376
+ self._debug_buffer_decision(sv, "rejected_dominated")
377
+ continue
378
+
379
+ # Remove any existing that this candidate dominates
380
+ keep = []
381
+ for bi in self.buffer:
382
+ if not dominates(sv, bi["score_vector"]):
383
+ keep.append(bi)
384
+ self.buffer = keep
385
+
386
+ # Insert with capacity rule
387
+ if len(self.buffer) < self.buffer_size:
388
+ self.buffer.append(item)
389
+ else:
390
+ # tie-breaker: replace the worst by a simple heuristic (min sum)
391
+ worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
392
+ self.buffer[worst_i] = item
393
+
394
+ # for debugging
395
+ self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})
396
+
397
+ traj_log_rnds.append(traj_log_rnd)
398
+ scalar_rewards.append(scalar_reward)
399
+
400
+ traj_log_rnds = torch.stack(traj_log_rnds, dim=0) if traj_log_rnds else torch.empty(0)
401
+ scalar_rewards = np.asarray(scalar_rewards, dtype=float)
402
+ return traj_log_rnds, scalar_rewards
403
+
404
+ def consolidateBuffer(self):
405
+ """
406
+ returns x_final, log_rnd, and final_rewards in tensors
407
+ """
408
+ x_final = []
409
+ log_rnd = []
410
+ final_rewards = []
411
+ score_vectors = []
412
+ sequences = []
413
+ for item in self.buffer:
414
+ x_final.append(item["x_final"])
415
+ log_rnd.append(item["log_rnd"])
416
+ final_rewards.append(item["final_reward"])
417
+ score_vectors.append(item["score_vector"])
418
+ sequences.append(item["seq"])
419
+
420
+ x_final = torch.stack(x_final, dim=0) # (B, L)
421
+ log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (B)
422
+ final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
423
+ score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)
424
+
425
+ return x_final, log_rnd, final_rewards, score_vectors, sequences
426
+
427
+
428
+ def isPathEnd(self, path, maxDepth):
429
+ """
430
+ Checks if the node is completely unmasked (ie. end of path)
431
+ or if the path is at the max depth
432
+ """
433
+ if (path[-1] != self.mask_index).all():
434
+ return True
435
+ elif len(path) >= maxDepth:
436
+ return True
437
+ return False
438
+
439
+ def select(self, currNode, eps=1e-5):
440
+ """
441
+ Traverse the tree from the root node until reaching a legal leaf node
442
+ """
443
+ updated_log_rnd = torch.zeros((), device=self.device)
444
+ while True:
445
+ currNode, nodeStatus = currNode.selectNode()
446
+
447
+ if currNode.parentNode is not None:
448
+ # compute new log_policy
449
+ child_tokens = currNode.tokens['seqs'].to(self.device)
450
+ attn_mask = currNode.tokens['attention_mask'].to(self.device)
451
+ parent = currNode.parentNode
452
+ parent_tokens = parent.tokens['seqs'].to(self.device)
453
+ t = torch.ones(1, device = self.device)
454
+ dt = (1 - eps) / self.num_steps
455
+ with torch.no_grad():
456
+ with self.timer.section("select.compute_log_policy"):
457
+ updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
458
+ child_tokens,
459
+ t=t, dt=dt)
460
+ updated_log_rnd += updated_log_policy_step
461
+
462
+ currNode.update_logrnd(updated_log_policy_step, updated_log_rnd) # update log_rnd
463
+
464
+ if nodeStatus != 3:
465
+ return currNode, nodeStatus
466
+
467
+ def expand(self, parentNode, eps=1e-5):
468
+ """
469
+ Sample unmasking steps from the pre-trained MDLM
470
+ adds num_children partially unmasked sequences to the children of the parentNode
471
+ """
472
+
473
+ num_children = self.num_children
474
+ # initialize child rewards that will be added to total rewards
475
+
476
+
477
+ # compute number of rollout steps
478
+ # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
479
+ num_rollout_steps = self.num_steps - parentNode.timestep
480
+ # array of rollout timesteps from the timestep of parent node to 0
481
+ rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
482
+ dt = (1 - eps) / self.num_steps
483
+
484
+ # initialize x and attn_mask
485
+ x = parentNode.tokens['seqs'].to(self.device)
486
+ attn_mask = parentNode.tokens['attention_mask'].to(self.device)
487
+ parent_log_rnd = parentNode.log_rnd # stores the log_rnd up to parent node
488
+
489
+ t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device = self.device)
490
+
491
+ # sample M child sequences and compute their log probabilities
492
+ with torch.no_grad():
493
+ with self.timer.section("expand.batch_mcts_reverse_step"):
494
+ _, x_children, child_log_policy_step, child_log_pretrained_step = \
495
+ self.policy_model.batch_mcts_reverse_step(token_array=x,
496
+ t=t, dt=dt,
497
+ batch_size=num_children,
498
+ pretrained=self.pretrained)
499
+
500
+ # compute weight of the step (num_children, 1)
501
+
502
+ child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)
503
+
504
+ x_rollout = x_children
505
+
506
+ traj_log_rnd = child_log_rnd # initialize log_rnd for entire rolled out trajectory
507
+
508
+ # rollout under the policy and compute the log ratio at each step
509
+ with self.timer.section("expand.rollout_total"):
510
+ for i in range(1, num_rollout_steps):
511
+ t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device = self.device)
512
+
513
+ with torch.no_grad():
514
+ _, x_next, log_policy_step, log_pretrained_step = \
515
+ self.policy_model.mcts_reverse_step(x_rollout,
516
+ t=t, dt=dt,
517
+ pretrained=self.pretrained)
518
+
519
+ # add the rollout step
520
+ traj_log_rnd += log_pretrained_step - log_policy_step
521
+
522
+ x_rollout = x_next
523
+
524
+
525
+ # if mask token remains, fully unmask
526
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
527
+
528
+ # does **any** mask remain in any sequence
529
+ any_mask_global = mask_positions.any().item() # true if mask remains
530
+ if any_mask_global:
531
+ with torch.no_grad():
532
+ with self.timer.section("expand.noise_removal"):
533
+ log_p, x_next, log_policy_step, log_pretrained_step = \
534
+ self.policy_model.mcts_noise_removal(x_rollout,
535
+ t=t, dt=dt,
536
+ pretrained=self.pretrained)
537
+
538
+ traj_log_rnd += log_pretrained_step - log_policy_step
539
+
540
+ x_rollout = x_next
541
+
542
+ # stores the string sequences for reward evaluation
543
+ with self.timer.section("expand.decode"):
544
+ childSequences = self.tokenizer.batch_decode(x_rollout)
545
+
546
+ ## FOR PEPTIDES ONLY ##
547
+ valid_x_children = []
548
+ valid_x_final = []
549
+ validSequences = []
550
+ valid_traj_log_rnd = []
551
+
552
+ with self.timer.section("expand.filter_is_peptide"):
553
+ for i in range(num_children):
554
+ # string sequence
555
+ childSeq = childSequences[i]
556
+
557
+ # check if the peptide is valid
558
+ if self.analyzer.is_peptide(childSeq):
559
+ valid_x_children.append(x_children[i])
560
+ valid_x_final.append(x_rollout[i])
561
+ validSequences.append(childSeq)
562
+ valid_traj_log_rnd.append(traj_log_rnd[i])
563
+ else:
564
+ childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
565
+ parentNode.addChildNode(tokens=childTokens,
566
+ log_rnd=child_log_rnd[i],
567
+ log_policy_step=child_log_policy_step[i],
568
+ log_pretrained_step=child_log_pretrained_step[i],
569
+ totalReward=np.zeros(self.num_obj))
570
+
571
+ del traj_log_rnd
572
+
573
+ if (len(validSequences) != 0):
574
+ # add scores to log
575
+ with self.timer.section("expand.scoring_functions"):
576
+ score_vectors = self.rewardFunc(input_seqs=validSequences) # (num_children, num_objectives)
577
+
578
+ average_scores = score_vectors.T
579
+
580
+ self.affinity1_log.append(average_scores[0])
581
+ self.sol_log.append(average_scores[1])
582
+ self.hemo_log.append(average_scores[2])
583
+ self.nf_log.append(average_scores[3])
584
+ self.permeability_log.append(average_scores[4])
585
+
586
+ else:
587
+ # set the values added to log as 0s if there are no valid sequences
588
+ self.affinity1_log.append(np.zeros((self.num_obj, self.num_children)))
589
+ self.sol_log.append(np.zeros((self.num_obj, self.num_children)))
590
+ self.hemo_log.append(np.zeros((self.num_obj, self.num_children)))
591
+ self.nf_log.append(np.zeros((self.num_obj, self.num_children)))
592
+ self.permeability_log.append(np.zeros((self.num_obj, self.num_children)))
593
+
594
+ # convert to tensor
595
+ if len(valid_x_final) == 0:
596
+ # log and bail out gracefully for this expansion
597
+ self.valid_fraction_log.append(0.0)
598
+ return
599
+
600
+ valid_x_final = torch.stack(valid_x_final, dim=0)
601
+ valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
602
+ # update buffer and get rewards
603
+ with self.timer.section("expand.update_buffer"):
604
+ traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, childSequences)
605
+
606
+ allChildReward = np.zeros_like(score_vectors[0])
607
+
608
+ for i in range(len(score_vectors)):
609
+ reward = score_vectors[i]
610
+
611
+ # add to all child reward vector for backprop
612
+ allChildReward += reward # (num_objectives,)
613
+
614
+ # create node for sequence and add to the children node of parent
615
+ childTokens = {'seqs': valid_x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
616
+ parentNode.addChildNode(tokens=childTokens,
617
+ log_rnd=child_log_rnd[i],
618
+ log_policy_step=child_log_policy_step[i],
619
+ log_pretrained_step=child_log_pretrained_step[i],
620
+ totalReward=reward)
621
+
622
+ ### END OF FOR PEPTIDES ONLY ###
623
+
624
+ valid_fraction = len(validSequences) / num_children
625
+ self.valid_fraction_log.append(valid_fraction)
626
+
627
+ # debugging
628
+ print(f"[EXPAND] iter={self.iter_num} parent_t={parentNode.timestep} "
629
+ f"num_children={num_children} valid={len(validSequences)} any_mask={any_mask_global}")
630
+ if score_vectors is not None:
631
+ print(f"[SCORES] min={np.min(score_vectors,0)} max={np.max(score_vectors,0)} "
632
+ f"nan_any={np.isnan(score_vectors).any()}")
633
+ # end debugging
634
+
635
+ self.reward_log.append(scalar_rewards)
636
+ self.logrnd_log.append(traj_log_rnds.detach().cpu().numpy())
637
+
638
+ allChildReward = allChildReward / len(validSequences) # normalize by number of valid children
639
+ # backpropagate all child rewards
640
+ with self.timer.section("expand.backprop"):
641
+ self.backprop(parentNode, allChildReward)
642
+
643
+
644
+ def backprop(self, node, allChildReward):
645
+ # backpropogate rewards through the path leading to the leaf node from the root
646
+ while node:
647
+ node.updateNode(allChildReward)
648
+ node = node.parentNode
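The selection and buffer logic above both rely on non-dominated (Pareto) filtering of score vectors. A standalone toy illustration of that filtering rule; the candidate names and scores below are made up:

```python
# Toy illustration of the Pareto filtering used by updateParetoFront / updateBuffer.
import numpy as np

def dominates(a, b, eps=1e-12):
    a, b = np.asarray(a, float), np.asarray(b, float)
    return np.all(a >= b - eps) and np.any(a > b + eps)

front = {}
candidates = {"pep_a": [0.9, 0.2], "pep_b": [0.5, 0.5], "pep_c": [0.4, 0.4]}
for name, sv in candidates.items():
    if any(dominates(v, sv) for v in front.values()):
        continue                                           # dominated -> rejected
    front = {k: v for k, v in front.items() if not dominates(sv, v)}
    front[name] = sv

print(front)  # keeps pep_a and pep_b; pep_c is dominated by pep_b
```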
tr2d2-pep/plotting.py ADDED
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env python
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
6
+ import numpy as np
7
8
+ import torch
9
+ import os
10
+
11
+
12
+ def plot_data_with_distribution_seaborn(log1, log2=None,
13
+ save_path=None,
14
+ label1=None,
15
+ label2=None,
16
+ title=None):
17
+ """
18
+ Plots one or two datasets with the average values and distributions over iterations using Seaborn.
19
+
20
+ Parameters:
21
+ log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
22
+ log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
23
+ save_path (str): Path to save the plot. Defaults to None.
24
+ label1 (str): Label for the first dataset. Defaults to "Fraction of Valid Peptide SMILES".
25
+ label2 (str, optional): Label for the second dataset. Defaults to None.
26
+ title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
27
+ """
28
+ # Prepare data for log1
29
+ data1 = pd.DataFrame({
30
+ "Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
31
+ label1: [score for scores in log1 for score in scores],
32
+ "Dataset": label1,
33
+ "Style": "Log1"
34
+ })
35
+
36
+ # Prepare data for log2 if provided
37
+ if log2 is not None:
38
+ data2 = pd.DataFrame({
39
+ "Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
40
+ label2: [score for scores in log2 for score in scores],
41
+ "Dataset": label2,
42
+ "Style": "Log2"
43
+ })
44
+ data = pd.concat([data1, data2], ignore_index=True)
45
+ else:
46
+ data = data1
47
+
48
+ palette = {
49
+ label1: "#8181ED", # Default color for log1
50
+ label2: "#D577FF" # Default color for log2 (if provided)
51
+ }
52
+
53
+ # Set Seaborn theme
54
+ sns.set_theme()
55
+ sns.set_context("paper")
56
+
57
+ # Create the plot
58
+ sns.relplot(
59
+ data=data,
60
+ kind="line",
61
+ x="Iteration",
62
+ y=label1,
63
+ hue="Dataset",
64
+ style="Style",
65
+ markers=True,
66
+ dashes=True,
67
+ ci="sd", # Show standard deviation
68
+ height=5,
69
+ aspect=1.5,
70
+ palette=palette
71
+ )
72
+
73
+ # Titles and labels
74
+ plt.title(title)
75
+ plt.xlabel("Iteration")
76
+ plt.ylabel(label1)
77
+
78
+ if save_path:
79
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
80
+ print(f"Plot saved to {save_path}")
81
+ plt.show()
82
+
83
+ def plot_data(log1, log2=None,
84
+ save_path=None,
85
+ label1="Log 1",
86
+ label2=None,
87
+ title="Fraction of Valid Peptides Over Iterations",
88
+ palette=None):
89
+ """
90
+ Plots one or two datasets with their mean values over iterations.
91
+
92
+ Parameters:
93
+ log1 (list): The first list of mean values for each iteration.
94
+ log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
95
+ save_path (str): Path to save the plot. Defaults to None.
96
+ label1 (str): Label for the first dataset. Defaults to "Log 1".
97
+ label2 (str, optional): Label for the second dataset. Defaults to None.
98
+ title (str): Title of the plot. Defaults to "Mean Values Over Iterations".
99
+ palette (dict, optional): A dictionary defining custom colors for datasets. Defaults to None.
100
+ """
101
+ # Prepare data for log1
102
+ data1 = pd.DataFrame({
103
+ "Iteration": range(1, len(log1) + 1),
104
+ "Fraction of Valid Peptides": log1,
105
+ "Dataset": label1
106
+ })
107
+
108
+ # Prepare data for log2 if provided
109
+ if log2 is not None:
110
+ data2 = pd.DataFrame({
111
+ "Iteration": range(1, len(log2) + 1),
112
+ "Fraction of Valid Peptides": log2,
113
+ "Dataset": label2
114
+ })
115
+ data = pd.concat([data1, data2], ignore_index=True)
116
+ else:
117
+ data = data1
118
+
119
+ palette = {
120
+ label1: "#8181ED", # Default color for log1
121
+ label2: "#D577FF" # Default color for log2 (if provided)
122
+ }
123
+
124
+ # Set Seaborn theme
125
+ sns.set_theme()
126
+ sns.set_context("paper")
127
+
128
+ # Create the plot
129
+ sns.lineplot(
130
+ data=data,
131
+ x="Iteration",
132
+ y="Fraction of Valid Peptides",
133
+ hue="Dataset",
134
+ style="Dataset",
135
+ markers=True,
136
+ dashes=False,
137
+ palette=palette
138
+ )
139
+
140
+ # Titles and labels
141
+ plt.title(title)
142
+ plt.xlabel("Iteration")
143
+ plt.ylabel("Fraction of Valid Peptides")
144
+
145
+ if save_path:
146
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
147
+ print(f"Plot saved to {save_path}")
148
+ plt.show()
tr2d2-pep/roformer.py ADDED
@@ -0,0 +1,74 @@
1
+ from transformers import RoFormerConfig, RoFormerForMaskedLM
2
+ import torch.nn as nn
3
+ from torch.nn.parallel import DistributedDataParallel as DDP
4
+ import torch
5
+
6
+ class Roformer(nn.Module):
7
+ def __init__(self, config, tokenizer, device=None):
8
+ super(Roformer, self).__init__()
9
+
10
+ self.tokenizer = tokenizer
11
+ self.vocab_size = self.tokenizer.vocab_size
12
+
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
14
+
15
+
16
+ roformer_config = RoFormerConfig(
17
+ vocab_size=self.tokenizer.vocab_size,
18
+ embedding_size=config.roformer.hidden_size,
19
+ hidden_size=config.roformer.hidden_size,
20
+ num_hidden_layers=config.roformer.n_layers,
21
+ num_attention_heads=config.roformer.n_heads,
22
+ intermediate_size=config.roformer.hidden_size * 4,
23
+ max_position_embeddings=config.roformer.max_position_embeddings,
24
+ hidden_dropout_prob=0.1,
25
+ attention_probs_dropout_prob=0.1,
26
+ pad_token_id=0,
27
+ rotary_value=False
28
+ )
29
+
30
+ self.model = RoFormerForMaskedLM(roformer_config).to(self.device)
31
+
32
+ def freeze_model(self):
33
+ for param in self.model.parameters():
34
+ param.requires_grad = False
35
+
36
+ def unfreeze_all_layers(self):
37
+ for param in self.model.parameters():
38
+ param.requires_grad = True
39
+
40
+ def unfreeze_n_layers(self, n):
41
+ num_layers = 8
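+ # note: assumes an 8-layer encoder; this should match config.roformer.n_layers if the config changes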
42
+
43
+ for i, layer in enumerate(self.model.roformer.encoder.layer):
44
+ # finetune final n layers
45
+ if i >= num_layers - n:
46
+ # unfreeze query weights
47
+ for module in layer.attention.self.query.modules():
48
+ for param in module.parameters():
49
+ param.requires_grad = True
50
+ # unfreeze key weights
51
+ for module in layer.attention.self.key.modules():
52
+ for param in module.parameters():
53
+ param.requires_grad = True
54
+
55
+ def forward(self, input_ids, attn_mask):
56
+
57
+ input_ids = input_ids.to(self.device)
58
+ attn_mask = attn_mask.to(self.device)
59
+
60
+ # get logits embeddings
61
+ logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
62
+ # return logits
63
+ #print(logits.logits)
64
+ return logits.logits
65
+
66
+ def save_model(self, save_dir):
67
+ self.model.save_pretrained(save_dir)
68
+ self.tokenizer.save_pretrained(save_dir)
69
+
70
+ @classmethod
71
+ def load_model(cls, save_dir, config, tokenizer):
72
+ roformer = cls(config, tokenizer)
73
+ roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
74
+ return roformer
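A short sketch of how one might check what `freeze_model` / `unfreeze_n_layers` leave trainable; it assumes `model` is an already-constructed `Roformer` (built from the repo's config and tokenizer):

```python
# Sketch: count trainable parameters after partial unfreezing.
# `model` is assumed to be a constructed Roformer instance.
def count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

model.freeze_model()
model.unfreeze_n_layers(2)   # unfreezes only the Q/K projections of the last 2 encoder layers
print(f"trainable parameters: {count_trainable(model):,}")
```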
tr2d2-pep/run_mcts.sh ADDED
@@ -0,0 +1,29 @@
1
+ #!/bin/bash
2
+
3
+ HOME_LOC=/path/to/your/home
4
+ ENV_PATH=/path/to/your/env
5
+ SCRIPT_LOC=$HOME_LOC/tr2d2/peptides
6
+ LOG_LOC=$HOME_LOC/tr2d2/peptides/logs
7
+ DATE=$(date +%m_%d)
8
+ SPECIAL_PREFIX='tfr-peptune-baseline'
9
+ # set 3 have skip connection
10
+ PYTHON_EXECUTABLE=$ENV_PATH/bin/python
11
+
12
+ # ===================================================================
13
+
14
+ source "$(conda info --base)/etc/profile.d/conda.sh"
15
+ conda activate $ENV_PATH
16
+
17
+ $PYTHON_EXECUTABLE $SCRIPT_LOC/generate_mcts.py \
18
+ --base_path $HOME_LOC \
19
+ --device "cuda:0" \
20
+ --noise_removal \
21
+ --run_name "tfr-peptune-baseline" \
22
+ --num_children 50 \
23
+ --num_iter 100 \
24
+ --buffer_size 100 \
25
+ --seq_length 200 \
26
+ --total_num_steps 128 \
27
+ --exploration 0.1 > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
28
+
29
+ conda deactivate
tr2d2-pep/scoring/functions/binding.py ADDED
@@ -0,0 +1,178 @@
1
+ import sys
2
+ import os, torch
3
+ import numpy as np
4
5
+ import pandas as pd
6
+ import torch.nn as nn
7
+ import esm
8
+ from transformers import AutoModelForMaskedLM
9
+
10
+ class ImprovedBindingPredictor(nn.Module):
11
+ def __init__(self,
12
+ esm_dim=1280,
13
+ smiles_dim=768,
14
+ hidden_dim=512,
15
+ n_heads=8,
16
+ n_layers=3,
17
+ dropout=0.1):
18
+ super().__init__()
19
+
20
+ # Define binding thresholds
21
+ self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
22
+ self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
23
+
24
+ # Project to same dimension
25
+ self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
26
+ self.protein_projection = nn.Linear(esm_dim, hidden_dim)
27
+ self.protein_norm = nn.LayerNorm(hidden_dim)
28
+ self.smiles_norm = nn.LayerNorm(hidden_dim)
29
+
30
+ # Cross attention blocks with layer norm
31
+ self.cross_attention_layers = nn.ModuleList([
32
+ nn.ModuleDict({
33
+ 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
34
+ 'norm1': nn.LayerNorm(hidden_dim),
35
+ 'ffn': nn.Sequential(
36
+ nn.Linear(hidden_dim, hidden_dim * 4),
37
+ nn.ReLU(),
38
+ nn.Dropout(dropout),
39
+ nn.Linear(hidden_dim * 4, hidden_dim)
40
+ ),
41
+ 'norm2': nn.LayerNorm(hidden_dim)
42
+ }) for _ in range(n_layers)
43
+ ])
44
+
45
+ # Prediction heads
46
+ self.shared_head = nn.Sequential(
47
+ nn.Linear(hidden_dim * 2, hidden_dim),
48
+ nn.ReLU(),
49
+ nn.Dropout(dropout),
50
+ )
51
+
52
+ # Regression head
53
+ self.regression_head = nn.Linear(hidden_dim, 1)
54
+
55
+ # Classification head (3 classes: tight, medium, loose binding)
56
+ self.classification_head = nn.Linear(hidden_dim, 3)
57
+
58
+ def get_binding_class(self, affinity):
59
+ """Convert affinity values to class indices
60
+ 0: tight binding (>= 7.5)
61
+ 1: medium binding (6.0-7.5)
62
+ 2: weak binding (< 6.0)
63
+ """
64
+ if isinstance(affinity, torch.Tensor):
65
+ tight_mask = affinity >= self.tight_threshold
66
+ weak_mask = affinity < self.weak_threshold
67
+ medium_mask = ~(tight_mask | weak_mask)
68
+
69
+ classes = torch.zeros_like(affinity, dtype=torch.long)
70
+ classes[medium_mask] = 1
71
+ classes[weak_mask] = 2
72
+ return classes
73
+ else:
74
+ if affinity >= self.tight_threshold:
75
+ return 0 # tight binding
76
+ elif affinity < self.weak_threshold:
77
+ return 2 # weak binding
78
+ else:
79
+ return 1 # medium binding
80
+
81
+ def forward(self, protein_emb, smiles_emb):
82
+ protein = self.protein_norm(self.protein_projection(protein_emb))
83
+ smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
84
+
85
+ #protein = protein.transpose(0, 1)
86
+ #smiles = smiles.transpose(0, 1)
87
+
88
+ # Cross attention layers
89
+ for layer in self.cross_attention_layers:
90
+ # Protein attending to SMILES
91
+ attended_protein = layer['attention'](
92
+ protein, smiles, smiles
93
+ )[0]
94
+ protein = layer['norm1'](protein + attended_protein)
95
+ protein = layer['norm2'](protein + layer['ffn'](protein))
96
+
97
+ # SMILES attending to protein
98
+ attended_smiles = layer['attention'](
99
+ smiles, protein, protein
100
+ )[0]
101
+ smiles = layer['norm1'](smiles + attended_smiles)
102
+ smiles = layer['norm2'](smiles + layer['ffn'](smiles))
103
+
104
+ # Get sequence-level representations
105
+ protein_pool = torch.mean(protein, dim=0)
106
+ smiles_pool = torch.mean(smiles, dim=0)
107
+
108
+ # Concatenate both representations
109
+ combined = torch.cat([protein_pool, smiles_pool], dim=-1)
110
+
111
+ # Shared features
112
+ shared_features = self.shared_head(combined)
113
+
114
+ regression_output = self.regression_head(shared_features)
115
+ classification_logits = self.classification_head(shared_features)
116
+
117
+ return regression_output, classification_logits
118
+
119
+ class BindingAffinity:
120
+ def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
121
+ super().__init__()
122
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
123
+
124
+ # peptide embeddings
125
+ if emb_model is not None:
126
+ self.pep_model = emb_model.to(self.device).eval()
127
+ else:
128
+ self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
129
+
130
+ self.pep_tokenizer = tokenizer
131
+
132
+ self.model = ImprovedBindingPredictor().to(self.device)
133
+ checkpoint = torch.load(f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
134
+ map_location=self.device,
135
+ weights_only=False)
136
+ self.model.load_state_dict(checkpoint['model_state_dict'])
137
+
138
+ self.model.eval()
139
+
140
+ self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
141
+ self.esm_model = self.esm_model.to(self.device).eval()
142
+ self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer
143
+
144
+ data = [("target", prot_seq)]
145
+ # get tokenized protein
146
+ _, _, prot_tokens = self.prot_tokenizer(data)
147
+ prot_tokens = prot_tokens.to(self.device)
148
+ with torch.no_grad():
149
+ results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2
150
+ prot_emb = results["representations"][33]
151
+
152
+ self.prot_emb = prot_emb[0].to(self.device)
153
+ self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)
154
+
155
+
156
+ def forward(self, input_seqs):
157
+ with torch.no_grad():
158
+ scores = []
159
+ for seq in input_seqs:
160
+ pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
161
+
162
+ pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
163
+
164
+ with torch.no_grad():
165
+ emb = self.pep_model(input_ids=pep_tokens['input_ids'],
166
+ attention_mask=pep_tokens['attention_mask'],
167
+ output_hidden_states=True)
168
+
169
+ #emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
170
+ pep_emb = emb.last_hidden_state.squeeze(0)
171
+ pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
172
+
173
+ score, logits = self.model.forward(self.prot_emb, pep_emb)
174
+ scores.append(score.item())
175
+ return scores
176
+
177
+ def __call__(self, input_seqs: list):
178
+ return self.forward(input_seqs)
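Both embeddings fed to `ImprovedBindingPredictor` are mean-pooled over the sequence dimension, so the predictor sees one vector per protein and one per peptide. A tiny shape sketch with random stand-in tensors:

```python
# Shape sketch of the mean pooling used above (random stand-in tensors only).
import torch

prot_tok_emb = torch.randn(300, 1280)  # stand-in for ESM-2 per-token output (layer 33)
pep_tok_emb = torch.randn(40, 768)     # stand-in for PeptideCLM last_hidden_state

prot_emb = prot_tok_emb.mean(dim=0, keepdim=True)  # [1, 1280]
pep_emb = pep_tok_emb.mean(dim=0, keepdim=True)    # [1, 768]
print(prot_emb.shape, pep_emb.shape)
```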
tr2d2-pep/scoring/functions/binding_utils.py ADDED
@@ -0,0 +1,290 @@
1
+ from torch import nn
2
+ import torch
+ import torch.nn.functional as F
3
+ import numpy as np
4
+
5
+ def to_var(x):
6
+ if torch.cuda.is_available():
7
+ x = x.cuda()
8
+ return x
9
+
10
+ class MultiHeadAttentionSequence(nn.Module):
11
+
12
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
13
+
14
+ super().__init__()
15
+
16
+ self.n_head = n_head
17
+ self.d_model = d_model
18
+ self.d_k = d_k
19
+ self.d_v = d_v
20
+
21
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
22
+ self.W_K = nn.Linear(d_model, n_head*d_k)
23
+ self.W_V = nn.Linear(d_model, n_head*d_v)
24
+ self.W_O = nn.Linear(n_head*d_v, d_model)
25
+
26
+ self.layer_norm = nn.LayerNorm(d_model)
27
+
28
+ self.dropout = nn.Dropout(dropout)
29
+
30
+ def forward(self, q, k, v):
31
+
32
+ batch, len_q, _ = q.size()
33
+ batch, len_k, _ = k.size()
34
+ batch, len_v, _ = v.size()
35
+
36
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
37
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
38
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
39
+
40
+ Q = Q.transpose(1, 2)
41
+ K = K.transpose(1, 2).transpose(2, 3)
42
+ V = V.transpose(1, 2)
43
+
44
+ attention = torch.matmul(Q, K)
45
+
46
+ attention = attention / np.sqrt(self.d_k)
47
+
48
+ attention = F.softmax(attention, dim=-1)
49
+
50
+ output = torch.matmul(attention, V)
51
+
52
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
53
+
54
+ output = self.W_O(output)
55
+
56
+ output = self.dropout(output)
57
+
58
+ output = self.layer_norm(output + q)
59
+
60
+ return output, attention
61
+
62
+ class MultiHeadAttentionReciprocal(nn.Module):
63
+
64
+
65
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
66
+
67
+ super().__init__()
68
+
69
+ self.n_head = n_head
70
+ self.d_model = d_model
71
+ self.d_k = d_k
72
+ self.d_v = d_v
73
+
74
+ self.W_Q = nn.Linear(d_model, n_head*d_k)
75
+ self.W_K = nn.Linear(d_model, n_head*d_k)
76
+ self.W_V = nn.Linear(d_model, n_head*d_v)
77
+ self.W_O = nn.Linear(n_head*d_v, d_model)
78
+ self.W_V_2 = nn.Linear(d_model, n_head*d_v)
79
+ self.W_O_2 = nn.Linear(n_head*d_v, d_model)
80
+
81
+ self.layer_norm = nn.LayerNorm(d_model)
82
+
83
+ self.dropout = nn.Dropout(dropout)
84
+
85
+ self.layer_norm_2 = nn.LayerNorm(d_model)
86
+
87
+ self.dropout_2 = nn.Dropout(dropout)
88
+
89
+ def forward(self, q, k, v, v_2):
90
+
91
+ batch, len_q, _ = q.size()
92
+ batch, len_k, _ = k.size()
93
+ batch, len_v, _ = v.size()
94
+ batch, len_v_2, _ = v_2.size()
95
+
96
+ Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
97
+ K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
98
+ V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
99
+ V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
100
+
101
+ Q = Q.transpose(1, 2)
102
+ K = K.transpose(1, 2).transpose(2, 3)
103
+ V = V.transpose(1, 2)
104
+ V_2 = V_2.transpose(1,2)
105
+
106
+ attention = torch.matmul(Q, K)
107
+
108
+
109
+ attention = attention /np.sqrt(self.d_k)
110
+
111
+ attention_2 = attention.transpose(-2, -1)
112
+
113
+
114
+
115
+ attention = F.softmax(attention, dim=-1)
116
+
117
+ attention_2 = F.softmax(attention_2, dim=-1)
118
+
119
+
120
+ output = torch.matmul(attention, V)
121
+
122
+ output_2 = torch.matmul(attention_2, V_2)
123
+
124
+ output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
125
+
126
+ output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
127
+
128
+ output = self.W_O(output)
129
+
130
+ output_2 = self.W_O_2(output_2)
131
+
132
+ output = self.dropout(output)
133
+
134
+ output = self.layer_norm(output + q)
135
+
136
+ output_2 = self.dropout_2(output_2)
137
+
138
+ output_2 = self.layer_norm_2(output_2 + k)
139
+
140
+
141
+ return output, output_2, attention, attention_2
142
+
143
+
144
+ class FFN(nn.Module):
145
+
146
+ def __init__(self, d_in, d_hid, dropout=0.1):
147
+ super().__init__()
148
+
149
+ self.layer_1 = nn.Conv1d(d_in, d_hid,1)
150
+ self.layer_2 = nn.Conv1d(d_hid, d_in,1)
151
+ self.relu = nn.ReLU()
152
+ self.layer_norm = nn.LayerNorm(d_in)
153
+
154
+ self.dropout = nn.Dropout(dropout)
155
+
156
+ def forward(self, x):
157
+
158
+ residual = x
159
+ output = self.layer_1(x.transpose(1, 2))
160
+
161
+ output = self.relu(output)
162
+
163
+ output = self.layer_2(output)
164
+
165
+ output = self.dropout(output)
166
+
167
+ output = self.layer_norm(output.transpose(1, 2)+residual)
168
+
169
+ return output
170
+
171
+ class ConvLayer(nn.Module):
172
+ def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
173
+ super(ConvLayer, self).__init__()
174
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
175
+ self.relu = nn.ReLU()
176
+
177
+ def forward(self, x):
178
+ out = self.conv(x)
179
+ out = self.relu(out)
180
+ return out
181
+
182
+
183
+ class DilatedCNN(nn.Module):
184
+ def __init__(self, d_model, d_hidden):
185
+ super(DilatedCNN, self).__init__()
186
+ self.first_ = nn.ModuleList()
187
+ self.second_ = nn.ModuleList()
188
+ self.third_ = nn.ModuleList()
189
+
190
+ dilation_tuple = (1, 2, 3)
191
+ dim_in_tuple = (d_model, d_hidden, d_hidden)
192
+ dim_out_tuple = (d_hidden, d_hidden, d_hidden)
193
+
194
+ for i, dilation_rate in enumerate(dilation_tuple):
195
+ self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
196
+ dilation=dilation_rate))
197
+
198
+ for i, dilation_rate in enumerate(dilation_tuple):
199
+ self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
200
+ dilation=dilation_rate))
201
+
202
+ for i, dilation_rate in enumerate(dilation_tuple):
203
+ self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
204
+ dilation=dilation_rate))
205
+
206
+ def forward(self, protein_seq_enc):
207
+ # pdb.set_trace()
208
+ protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
209
+
210
+ first_embedding = protein_seq_enc
211
+ second_embedding = protein_seq_enc
212
+ third_embedding = protein_seq_enc
213
+
214
+ for i in range(len(self.first_)):
215
+ first_embedding = self.first_[i](first_embedding)
216
+
217
+ for i in range(len(self.second_)):
218
+ second_embedding = self.second_[i](second_embedding)
219
+
220
+ for i in range(len(self.third_)):
221
+ third_embedding = self.third_[i](third_embedding)
222
+
223
+ # pdb.set_trace()
224
+
225
+ protein_seq_enc = first_embedding + second_embedding + third_embedding
226
+
227
+ return protein_seq_enc.transpose(1, 2)
228
+
229
+
230
+ class ReciprocalLayerwithCNN(nn.Module):
231
+
232
+ def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
233
+ super().__init__()
234
+
235
+ self.cnn = DilatedCNN(d_model, d_hidden)
236
+
237
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
238
+
239
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
240
+
241
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
242
+
243
+ self.ffn_seq = FFN(d_hidden, d_inner)
244
+
245
+ self.ffn_protein = FFN(d_hidden, d_inner)
246
+
247
+ def forward(self, sequence_enc, protein_seq_enc):
248
+ # pdb.set_trace() # protein_seq_enc.shape = B * L * d_model
249
+ protein_seq_enc = self.cnn(protein_seq_enc)
250
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
251
+
252
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
253
+
254
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
255
+
256
+ prot_enc = self.ffn_protein(prot_enc)
257
+
258
+ seq_enc = self.ffn_seq(seq_enc)
259
+
260
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
261
+
262
+
263
+ class ReciprocalLayer(nn.Module):
264
+
265
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v):
266
+
267
+ super().__init__()
268
+
269
+ self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
270
+
271
+ self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
272
+
273
+ self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
274
+
275
+ self.ffn_seq = FFN(d_model, d_inner)
276
+
277
+ self.ffn_protein = FFN(d_model, d_inner)
278
+
279
+ def forward(self, sequence_enc, protein_seq_enc):
280
+ prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
281
+
282
+ seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
283
+
284
+
285
+ prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
286
+ prot_enc = self.ffn_protein(prot_enc)
287
+
288
+ seq_enc = self.ffn_seq(seq_enc)
289
+
290
+ return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
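
A minimal shape check for the dilated-CNN branch added above (the module path, `d_model`/`d_hidden`, and the batch/length values are illustrative assumptions, not values from the config): each of the three stacks uses kernel sizes 3/5/7 with dilations 1/2/3 and matching padding, so the sequence length is preserved and the three outputs can be summed elementwise before being passed to the reciprocal attention layers.

```python
# Sketch only: the import path and all dimensions are assumptions.
import torch
from scoring.functions.binding_utils import DilatedCNN

cnn = DilatedCNN(d_model=640, d_hidden=256)
x = torch.randn(2, 100, 640)   # (batch, length, d_model) protein encoding
out = cnn(x)                   # padding = dilation * (kernel_size - 1) / 2 keeps the length fixed
print(out.shape)               # torch.Size([2, 100, 256])
```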
tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
3
+ size 45249160
tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json ADDED
The diff for this file is too large to render. See raw diff
 
tr2d2-pep/scoring/functions/hemolysis.py ADDED
@@ -0,0 +1,63 @@
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
+ class Hemolysis:
15
+
16
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
17
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
18
+ self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json')
19
+ self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
20
+ self.tokenizer = tokenizer
21
+
22
+ def generate_embeddings(self, sequences):
23
+ embeddings = []
24
+ for sequence in sequences:
25
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
26
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
27
+ with torch.no_grad():
28
+ output = self.emb_model(**tokenized)
29
+ # Mean pooling across sequence length
30
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
31
+ embeddings.append(embedding)
32
+ return np.array(embeddings)
33
+
34
+ def get_scores(self, input_seqs: list):
35
+ scores = np.ones(len(input_seqs))
36
+ features = self.generate_embeddings(input_seqs)
37
+
38
+ if len(features) == 0:
39
+ return scores
40
+
41
+ features = np.nan_to_num(features, nan=0.)
42
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
43
+
44
+ features = xgb.DMatrix(features)
45
+
46
+ probs = self.predictor.predict(features)
47
+ # return the probability of it being not hemolytic
48
+ return scores - probs
49
+
50
+ def __call__(self, input_seqs: list):
51
+ scores = self.get_scores(input_seqs)
52
+ return scores
53
+
54
+ def unittest():
55
+ hemo = Hemolysis()
56
+ seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
57
+ print(hemo.tokenizer.vocab_size)
58
+ scores = hemo(input_seqs=seq)
59
+ print(scores)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ unittest()
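
The unittest above calls `Hemolysis()` without arguments, but the constructor requires a tokenizer and the repository base path. A hedged usage sketch; the paths mirror the `/path/to/your/home` placeholder used in `scoring_functions.py` and are assumptions about your local layout (run from `tr2d2-pep/` so the imports resolve):

```python
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from scoring.functions.hemolysis import Hemolysis

base_path = '/path/to/your/home'  # placeholder, as in scoring_functions.py
tokenizer = SMILES_SPE_Tokenizer(
    f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
    f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt',
)

hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path)
scores = hemo(["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)O"])  # returns 1 - P(hemolytic) per sequence
print(scores)
```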
tr2d2-pep/scoring/functions/nonfouling.py ADDED
@@ -0,0 +1,66 @@
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit import Chem, rdBase, DataStructs
10
+
11
+
12
+ rdBase.DisableLog('rdApp.error')
13
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+
17
+ class Nonfouling:
18
+
19
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
21
+ self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json')
22
+ self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
23
+ self.tokenizer = tokenizer
24
+
25
+ def generate_embeddings(self, sequences):
26
+ embeddings = []
27
+ for sequence in sequences:
28
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
29
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
30
+ with torch.no_grad():
31
+ output = self.emb_model(**tokenized)
32
+ # Mean pooling across sequence length
33
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
34
+ embeddings.append(embedding)
35
+ return np.array(embeddings)
36
+
37
+ def get_scores(self, input_seqs: list):
38
+ scores = np.zeros(len(input_seqs))
39
+ features = self.generate_embeddings(input_seqs)
40
+
41
+ if len(features) == 0:
42
+ return scores
43
+
44
+ features = np.nan_to_num(features, nan=0.)
45
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
46
+
47
+ features = xgb.DMatrix(features)
48
+
49
+ scores = self.predictor.predict(features)
50
+ # return the probability of the peptide being nonfouling
51
+ return scores
52
+
53
+ def __call__(self, input_seqs: list):
54
+ scores = self.get_scores(input_seqs)
55
+ return scores
56
+
57
+ def unittest():
58
+ nf = Nonfouling()
59
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
60
+
61
+ scores = nf(input_seqs=seq)
62
+ print(scores)
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest()
tr2d2-pep/scoring/functions/permeability.py ADDED
@@ -0,0 +1,171 @@
1
+ import sys
2
+ import os
3
+ import xgboost as xgb
4
+ import torch
5
+ import numpy as np
6
+ from transformers import AutoModelForMaskedLM
7
+ import warnings
8
+ import numpy as np
9
+ from rdkit.Chem import Descriptors, rdMolDescriptors
10
+ from rdkit import Chem, rdBase, DataStructs
11
+ from rdkit.Chem import AllChem
12
+ from typing import List
13
+ from transformers import AutoModelForMaskedLM
14
+
15
+
16
+ rdBase.DisableLog('rdApp.error')
17
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
18
+ warnings.filterwarnings("ignore", category=UserWarning)
19
+ warnings.filterwarnings("ignore", category=FutureWarning)
20
+
21
+ def fingerprints_from_smiles(smiles: List, size=2048):
22
+ """ Create ECFP fingerprints of smiles, with validity check """
23
+ fps = []
24
+ valid_mask = []
25
+ for i, smile in enumerate(smiles):
26
+ mol = Chem.MolFromSmiles(smile)
27
+ valid_mask.append(int(mol is not None))
28
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
29
+ fps.append(fp)
30
+
31
+ fps = np.concatenate(fps, axis=0)
32
+ return fps, valid_mask
33
+
34
+
35
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
36
+ """ Create ECFP fingerprint of a molecule """
37
+ if hashed:
38
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
39
+ else:
40
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
41
+ fp_np = np.zeros((1,))
42
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
43
+ return fp_np.reshape(1, -1)
44
+
45
+ def getMolDescriptors(mol, missingVal=0):
46
+ """ calculate the full list of descriptors for a molecule """
47
+
48
+ values, names = [], []
49
+ for nm, fn in Descriptors._descList:
50
+ try:
51
+ val = fn(mol)
52
+ except:
53
+ val = missingVal
54
+ values.append(val)
55
+ names.append(nm)
56
+
57
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
58
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
59
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
60
+
61
+ for nm, fn in custom_descriptors.items():
62
+ try:
63
+ val = fn(mol)
64
+ except:
65
+ val = missingVal
66
+ values.append(val)
67
+ names.append(nm)
68
+ return values, names
69
+
70
+ def get_pep_dps_from_smi(smi):
71
+ try:
72
+ mol = Chem.MolFromSmiles(smi)
73
+ except:
74
+ print(f"convert smi {smi} to molecule failed!")
75
+ mol = None
76
+
77
+ dps, _ = getMolDescriptors(mol)
78
+ return np.array(dps)
79
+
80
+
81
+ def get_pep_dps(smi_list):
82
+ if len(smi_list) == 0:
83
+ return np.zeros((0, 213))
84
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
85
+
86
+ def check_smi_validity(smiles: list):
87
+ valid_smi, valid_idx = [], []
88
+ for idx, smi in enumerate(smiles):
89
+ try:
90
+ mol = Chem.MolFromSmiles(smi) if smi else None
91
+ if mol:
92
+ valid_smi.append(smi)
93
+ valid_idx.append(idx)
94
+ except Exception as e:
95
+ # logger.debug(f'Error: {e} in smiles {smi}')
96
+ pass
97
+ return valid_smi, valid_idx
98
+
99
+ class Permeability:
100
+
101
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
102
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
103
+ self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json')
104
+ if emb_model is not None:
105
+ self.emb_model = emb_model.to(self.device).eval()
106
+ else:
107
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
108
+
109
+ self.tokenizer = tokenizer
110
+
111
+ def generate_embeddings(self, sequences):
112
+ embeddings = []
113
+ for sequence in sequences:
114
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
115
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
116
+ with torch.no_grad():
117
+ output = self.emb_model(**tokenized)
118
+ # Mean pooling across sequence length
119
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
120
+ embeddings.append(embedding)
121
+ return np.array(embeddings)
122
+
123
+ def get_features(self, input_seqs: list, dps=False, fps=False):
124
+ #valid_smiles, valid_idxes = check_smi_validity(input_seqs)
125
+
126
+
127
+ if fps:
128
+ fingerprints = fingerprints_from_smiles(input_seqs)[0]
129
+ else:
130
+ fingerprints = torch.empty((len(input_seqs), 0))
131
+
132
+ if dps:
133
+ descriptors = get_pep_dps(input_seqs)
134
+ else:
135
+ descriptors = torch.empty((len(input_seqs), 0))
136
+
137
+ embeddings = self.generate_embeddings(input_seqs)
138
+ # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')
139
+
140
+ features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)
141
+
142
+ return features
143
+
144
+ def get_scores(self, input_seqs: list):
145
+ scores = -10 * np.ones(len(input_seqs))
146
+ features = self.get_features(input_seqs)
147
+
148
+ if len(features) == 0:
149
+ return scores
150
+
151
+ features = np.nan_to_num(features, nan=0.)
152
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
153
+
154
+ features = xgb.DMatrix(features)
155
+
156
+ scores = self.predictor.predict(features)
157
+ return scores
158
+
159
+ def __call__(self, input_seqs: list):
160
+ scores = self.get_scores(input_seqs)
161
+ return scores
162
+
163
+ def unittest():
164
+ permeability = Permeability()
165
+ seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
166
+ scores = permeability(input_seqs=seq)
167
+ print(scores)
168
+
169
+
170
+ if __name__ == '__main__':
171
+ unittest()
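
By default `Permeability.get_scores()` builds features from the PeptideCLM embeddings only; ECFP fingerprints and RDKit descriptors are concatenated only when `get_features` is called with `fps=True` / `dps=True`. A small sketch of the feature layout, assuming `tokenizer` and `base_path` are set up as in the hemolysis example above (shapes depend on the embedding model):

```python
perm = Permeability(tokenizer=tokenizer, base_path=base_path)

seqs = ["NCC(=O)N[C@H](CS)C(=O)O"]
emb_only = perm.get_features(seqs)                      # (1, embedding_dim)
full     = perm.get_features(seqs, fps=True, dps=True)  # (1, 2048 + n_descriptors + embedding_dim)
scores   = perm(seqs)                                   # permeability predicted by the pretrained XGBoost model
```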
tr2d2-pep/scoring/functions/scoring_utils.py ADDED
@@ -0,0 +1,94 @@
1
+ import warnings
2
+ import numpy as np
3
+ from loguru import logger
4
+ from sklearn.ensemble import RandomForestRegressor
5
+ from rdkit.Chem import Descriptors, rdMolDescriptors
6
+ import joblib
7
+ from rdkit import Chem, rdBase, DataStructs
8
+ from rdkit.Chem import AllChem
9
+ from typing import List
10
+
11
+
12
+ def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
13
+ """
14
+ Create ECFP fingerprint of a molecule
15
+ """
16
+ if hashed:
17
+ fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
18
+ else:
19
+ fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
20
+ fp_np = np.zeros((1,))
21
+ DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
22
+ return fp_np.reshape(1, -1)
23
+
24
+
25
+ def fingerprints_from_smiles(smiles: List, size=2048):
26
+ """ Create ECFP fingerprints of smiles, with validity check """
27
+ fps = []
28
+ valid_mask = []
29
+ for i, smile in enumerate(smiles):
30
+ mol = Chem.MolFromSmiles(smile)
31
+ valid_mask.append(int(mol is not None))
32
+ fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
33
+ fps.append(fp)
34
+
35
+ fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
36
+ return fps, valid_mask
37
+
38
+
39
+ def getMolDescriptors(mol, missingVal=0):
40
+ """ calculate the full list of descriptors for a molecule """
41
+
42
+ values, names = [], []
43
+ for nm, fn in Descriptors._descList:
44
+ try:
45
+ val = fn(mol)
46
+ except:
47
+ val = missingVal
48
+ values.append(val)
49
+ names.append(nm)
50
+
51
+ custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
52
+ 'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
53
+ 'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
54
+
55
+ for nm, fn in custom_descriptors.items():
56
+ try:
57
+ val = fn(mol)
58
+ except:
59
+ val = missingVal
60
+ values.append(val)
61
+ names.append(nm)
62
+ return values, names
63
+
64
+
65
+ def get_pep_dps_from_smi(smi):
66
+ try:
67
+ mol = Chem.MolFromSmiles(smi)
68
+ except:
69
+ print(f"convert smi {smi} to molecule failed!")
70
+ mol = None
71
+
72
+ dps, _ = getMolDescriptors(mol)
73
+ return np.array(dps)
74
+
75
+
76
+ def get_pep_dps(smi_list):
77
+ if len(smi_list) == 0:
78
+ return np.zeros((0, 211))
79
+ return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
80
+
81
+
82
+
83
+ def check_smi_validity(smiles: list):
84
+ valid_smi, valid_idx = [], []
85
+ for idx, smi in enumerate(smiles):
86
+ try:
87
+ mol = Chem.MolFromSmiles(smi) if smi else None
88
+ if mol:
89
+ valid_smi.append(smi)
90
+ valid_idx.append(idx)
91
+ except Exception as e:
92
+ # logger.debug(f'Error: {e} in smiles {smi}')
93
+ pass
94
+ return valid_smi, valid_idx
tr2d2-pep/scoring/functions/solubility.py ADDED
@@ -0,0 +1,63 @@
1
+ import xgboost as xgb
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModelForMaskedLM
5
+ import warnings
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+
9
+ rdBase.DisableLog('rdApp.error')
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+ warnings.filterwarnings("ignore", category=UserWarning)
12
+ warnings.filterwarnings("ignore", category=FutureWarning)
13
+
14
+ class Solubility:
15
+ def __init__(self, tokenizer, base_path, device=None, emb_model=None):
16
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
17
+ self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json')
18
+ if emb_model is not None:
19
+ self.emb_model = emb_model.to(self.device).eval()
20
+ else:
21
+ self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
22
+
23
+ self.tokenizer = tokenizer
24
+
25
+ def generate_embeddings(self, sequences):
26
+ embeddings = []
27
+ for sequence in sequences:
28
+ tokenized = self.tokenizer(sequence, return_tensors='pt')
29
+ tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
30
+ with torch.no_grad():
31
+ output = self.emb_model(**tokenized)
32
+ # Mean pooling across sequence length
33
+ embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
34
+ embeddings.append(embedding)
35
+ return np.array(embeddings)
36
+
37
+ def get_scores(self, input_seqs: list):
38
+ scores = np.zeros(len(input_seqs))
39
+ features = self.generate_embeddings(input_seqs)
40
+
41
+ if len(features) == 0:
42
+ return scores
43
+
44
+ features = np.nan_to_num(features, nan=0.)
45
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
46
+
47
+ features = xgb.DMatrix(features)
48
+
49
+ scores = self.predictor.predict(features)
50
+ return scores
51
+
52
+ def __call__(self, input_seqs: list):
53
+ scores = self.get_scores(input_seqs)
54
+ return scores
55
+
56
+ def unittest():
57
+ solubility = Solubility()
58
+ seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
59
+ scores = solubility(input_seqs=seq)
60
+ print(scores)
61
+
62
+ if __name__ == '__main__':
63
+ unittest()
tr2d2-pep/scoring/scoring_functions.py ADDED
@@ -0,0 +1,77 @@
1
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
2
+ from transformers import AutoModelForMaskedLM
3
+ import numpy as np
4
+ from scoring.functions.binding import BindingAffinity
5
+ from scoring.functions.permeability import Permeability
6
+ from scoring.functions.solubility import Solubility
7
+ from scoring.functions.hemolysis import Hemolysis
8
+ from scoring.functions.nonfouling import Nonfouling
9
+
10
+ base_path = '/path/to/your/home'
11
+
12
+ class ScoringFunctions:
13
+ def __init__(self, score_func_names=None, prot_seqs=None, device=None):
14
+ """
15
+ Class for generating score vectors for generated sequences
16
+
17
+ Args:
18
+ score_func_names: list of scoring function names to be evaluated
19
+ prot_seqs: list of target protein sequences for the binding-affinity scorers
20
+ device: torch device used for the embedding model and classifiers
21
+ """
22
+ emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
23
+ tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
24
+ f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')
25
+ prot_seqs = prot_seqs if prot_seqs is not None else []
26
+
27
+ if score_func_names is None:
28
+ # just do unmasking based on validity of peptide bonds
29
+ self.score_func_names = []
30
+ else:
31
+ self.score_func_names = score_func_names
32
+
33
+ # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)
34
+
35
+ # binding affinities
36
+ self.target_protein = prot_seqs
37
+ print(len(prot_seqs))
38
+
39
+ if ('binding_affinity1' in score_func_names) and (len(prot_seqs) == 1):
40
+ binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
41
+ binding_affinity2 = None
42
+ elif ('binding_affinity1' in score_func_names) and ('binding_affinity2' in score_func_names) and (len(prot_seqs) == 2):
43
+ binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
44
+ binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
45
+ else:
46
+ print("here")
47
+ binding_affinity1 = None
48
+ binding_affinity2 = None
49
+
50
+ permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
51
+ sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
52
+ nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
53
+ hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
54
+
55
+ self.all_funcs = {'binding_affinity1': binding_affinity1,
56
+ 'binding_affinity2': binding_affinity2,
57
+ 'permeability': permeability,
58
+ 'nonfouling': nonfouling,
59
+ 'solubility': sol,
60
+ 'hemolysis': hemo
61
+ }
62
+
63
+ def forward(self, input_seqs):
64
+ scores = []
65
+
66
+ for i, score_func in enumerate(self.score_func_names):
67
+ score = self.all_funcs[score_func](input_seqs = input_seqs)
68
+
69
+ scores.append(score)
70
+
71
+ # convert to numpy arrays with shape (num_sequences, num_functions)
72
+ scores = np.float32(scores).T
73
+
74
+ return scores
75
+
76
+ def __call__(self, input_seqs: list):
77
+ return self.forward(input_seqs)
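
A hedged sketch of wiring the scorers together; the objective list and empty target list are illustrative, and `base_path` inside the module must first be edited from its placeholder. With no protein sequences, both binding-affinity entries stay `None`, matching the else branch above.

```python
from scoring.scoring_functions import ScoringFunctions   # run from tr2d2-pep/ so these imports resolve

score_fn = ScoringFunctions(
    score_func_names=['permeability', 'solubility', 'hemolysis', 'nonfouling'],
    prot_seqs=[],        # no binding targets -> binding_affinity1/2 remain None
    device='cuda',
)

batch = ["NCC(=O)N[C@H](CS)C(=O)O", "NCC(=O)N[C@@H](CO)C(=O)O"]
scores = score_fn(batch)   # shape (num_sequences, num_functions), one column per objective
```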
tr2d2-pep/scoring/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
+ def load_vocab(vocab_file):
10
+ """Loads a vocabulary file into a dictionary."""
11
+ vocab = collections.OrderedDict()
12
+ with open(vocab_file, "r", encoding="utf-8") as reader:
13
+ tokens = reader.readlines()
14
+ for index, token in enumerate(tokens):
15
+ token = token.rstrip("\n")
16
+ vocab[token] = index
17
+ return vocab
18
+
19
+ class Atomwise_Tokenizer(object):
20
+ """Run atom-level SMILES tokenization"""
21
+
22
+ def __init__(self):
23
+ """ Constructs a atom-level Tokenizer.
24
+ """
25
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
26
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
27
+
28
+ self.regex = re.compile(self.regex_pattern)
29
+
30
+ def tokenize(self, text):
31
+ """ Basic Tokenization of a SMILES.
32
+ """
33
+ tokens = [token for token in self.regex.findall(text)]
34
+ return tokens
35
+
36
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
37
+ r"""
38
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
39
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
40
+ should refer to the superclass for more information regarding methods.
41
+ Args:
42
+ vocab_file (:obj:`string`):
43
+ File containing the vocabulary.
44
+ spe_file (:obj:`string`):
45
+ File containing the trained SMILES Pair Encoding vocabulary.
46
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
50
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
51
+ for sequence classification or for a text and a question for question answering.
52
+ It is also used as the last token of a sequence built with special tokens.
53
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
54
+ The token used for padding, for example when batching sequences of different lengths.
55
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
56
+ The classifier token which is used when doing sequence classification (classification of the whole
57
+ sequence instead of per-token classification). It is the first token of the sequence when built with
58
+ special tokens.
59
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
60
+ The token used for masking values. This is the token used when training this model with masked language
61
+ modeling. This is the token which the model will try to predict.
62
+ """
63
+
64
+ def __init__(self, vocab_file, spe_file,
65
+ unk_token="[UNK]",
66
+ sep_token="[SEP]",
67
+ pad_token="[PAD]",
68
+ cls_token="[CLS]",
69
+ mask_token="[MASK]",
70
+ **kwargs):
71
+ if not os.path.isfile(vocab_file):
72
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
73
+ if not os.path.isfile(spe_file):
74
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
75
+
76
+ self.vocab = load_vocab(vocab_file)
77
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
78
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
79
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
80
+
81
+ super().__init__(
82
+ unk_token=unk_token,
83
+ sep_token=sep_token,
84
+ pad_token=pad_token,
85
+ cls_token=cls_token,
86
+ mask_token=mask_token,
87
+ **kwargs)
88
+
89
+ @property
90
+ def vocab_size(self):
91
+ return len(self.vocab)
92
+
93
+ def get_vocab(self):
94
+ return dict(self.vocab, **self.added_tokens_encoder)
95
+
96
+ def _tokenize(self, text):
97
+ return self.spe_tokenizer.tokenize(text).split(' ')
98
+
99
+ def _convert_token_to_id(self, token):
100
+ """ Converts a token (str) in an id using the vocab. """
101
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
102
+
103
+ # changed encode and decode functions
104
+ def encode(self, token_array):
105
+ token_ids = []
106
+ token_ids.append(2)
107
+ for token in token_array:
108
+ id = self._convert_token_to_id(token)
109
+ token_ids.append(id)
110
+ token_ids.append(3)
111
+ token_ids = torch.tensor([token_ids])
112
+ attn_mask = torch.ones_like(token_ids)
113
+ return {'input_ids': token_ids, 'attention_mask': attn_mask}
114
+
115
+ def decode(self, token_ids, skip_special_tokens=True):
116
+ token_ids = token_ids.squeeze(0).cpu().tolist()
117
+ token_array = []
118
+ for idx in token_ids:
119
+ if idx == 3: # Stop decoding when token ID 3 is encountered
120
+ break
121
+ if skip_special_tokens and idx in self.all_special_ids:
122
+ continue
123
+ token = self._convert_id_to_token(idx)
124
+ token_array.append(token)
125
+ sequence = "".join(token_array)
126
+ return sequence
127
+
128
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
129
+ sequences = []
130
+ for token_ids in batch_token_ids:
131
+ sequences.append(self.decode(token_ids))
132
+ return sequences
133
+
134
+ def get_token_split(self, token_ids):
135
+ if isinstance(token_ids, torch.Tensor):
136
+ token_ids = token_ids.cpu().tolist()
137
+
138
+ token_array = []
139
+ for seq_ids in token_ids:
140
+ seq_array = []
141
+ for id in seq_ids:
142
+ token = self._convert_id_to_token(id)
143
+ seq_array.append(token)
144
+ token_array.append(seq_array)
145
+
146
+ return token_array
147
+
148
+ def _convert_id_to_token(self, index):
149
+ """Converts an index (integer) in a token (str) using the vocab."""
150
+ return self.ids_to_tokens.get(index, self.unk_token)
151
+
152
+ def convert_tokens_to_string(self, tokens):
153
+ """ Converts a sequence of tokens (string) in a single string. """
154
+ out_string = " ".join(tokens).replace(" ##", "").strip()
155
+ return out_string
156
+
157
+ def build_inputs_with_special_tokens(
158
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
159
+ ) -> List[int]:
160
+ """
161
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
162
+ by concatenating and adding special tokens.
163
+ A BERT sequence has the following format:
164
+ - single sequence: ``[CLS] X [SEP]``
165
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
166
+ Args:
167
+ token_ids_0 (:obj:`List[int]`):
168
+ List of IDs to which the special tokens will be added
169
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
170
+ Optional second list of IDs for sequence pairs.
171
+ Returns:
172
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
173
+ """
174
+ if token_ids_1 is None:
175
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
176
+ cls = [self.cls_token_id]
177
+ sep = [self.sep_token_id]
178
+ return cls + token_ids_0 + sep + token_ids_1 + sep
179
+
180
+ def get_special_tokens_mask(
181
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
182
+ ) -> List[int]:
183
+ """
184
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
185
+ special tokens using the tokenizer ``prepare_for_model`` method.
186
+ Args:
187
+ token_ids_0 (:obj:`List[int]`):
188
+ List of ids.
189
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
190
+ Optional second list of IDs for sequence pairs.
191
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
192
+ Set to True if the token list is already formatted with special tokens for the model
193
+ Returns:
194
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
195
+ """
196
+
197
+ if already_has_special_tokens:
198
+ if token_ids_1 is not None:
199
+ raise ValueError(
200
+ "You should not supply a second sequence if the provided sequence of "
201
+ "ids is already formated with special tokens for the model."
202
+ )
203
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
204
+
205
+ if token_ids_1 is not None:
206
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
207
+ return [1] + ([0] * len(token_ids_0)) + [1]
208
+
209
+ def create_token_type_ids_from_sequences(
210
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
211
+ ) -> List[int]:
212
+ """
213
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
214
+ A BERT sequence pair mask has the following format:
215
+ ::
216
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
217
+ | first sequence | second sequence |
218
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
219
+ Args:
220
+ token_ids_0 (:obj:`List[int]`):
221
+ List of ids.
222
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
223
+ Optional second list of IDs for sequence pairs.
224
+ Returns:
225
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
226
+ sequence(s).
227
+ """
228
+ sep = [self.sep_token_id]
229
+ cls = [self.cls_token_id]
230
+ if token_ids_1 is None:
231
+ return len(cls + token_ids_0 + sep) * [0]
232
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
233
+
234
+ def save_vocabulary(self, vocab_path):
235
+ """
236
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
237
+ Args:
238
+ vocab_path (:obj:`str`):
239
+ The directory in which to save the vocabulary.
240
+ Returns:
241
+ :obj:`Tuple(str)`: Paths to the files saved.
242
+ """
243
+ index = 0
244
+ vocab_file = vocab_path
245
+ with open(vocab_file, "w", encoding="utf-8") as writer:
246
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
247
+ if index != token_index:
248
+ index = token_index
249
+ writer.write(token + "\n")
250
+ index += 1
251
+ return (vocab_file,)
252
+
253
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
254
+ r"""
255
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
256
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
257
+ should refer to the superclass for more information regarding methods.
258
+ Args:
259
+ vocab_file (:obj:`string`):
260
+ File containing the vocabulary.
261
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
262
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
263
+ token instead.
264
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
265
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
266
+ for sequence classification or for a text and a question for question answering.
267
+ It is also used as the last token of a sequence built with special tokens.
268
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
269
+ The token used for padding, for example when batching sequences of different lengths.
270
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
271
+ The classifier token which is used when doing sequence classification (classification of the whole
272
+ sequence instead of per-token classification). It is the first token of the sequence when built with
273
+ special tokens.
274
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
275
+ The token used for masking values. This is the token used when training this model with masked language
276
+ modeling. This is the token which the model will try to predict.
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ vocab_file,
282
+ unk_token="[UNK]",
283
+ sep_token="[SEP]",
284
+ pad_token="[PAD]",
285
+ cls_token="[CLS]",
286
+ mask_token="[MASK]",
287
+ **kwargs
288
+ ):
289
+ super().__init__(
290
+ unk_token=unk_token,
291
+ sep_token=sep_token,
292
+ pad_token=pad_token,
293
+ cls_token=cls_token,
294
+ mask_token=mask_token,
295
+ **kwargs,
296
+ )
297
+
298
+ if not os.path.isfile(vocab_file):
299
+ raise ValueError(
300
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
301
+ )
302
+ self.vocab = load_vocab(vocab_file)
303
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
304
+ self.tokenizer = Atomwise_Tokenizer()
305
+
306
+ @property
307
+ def vocab_size(self):
308
+ return len(self.vocab)
309
+
310
+ def get_vocab(self):
311
+ return dict(self.vocab, **self.added_tokens_encoder)
312
+
313
+
314
+ def _tokenize(self, text):
315
+ return self.tokenizer.tokenize(text)
316
+
317
+ def _convert_token_to_id(self, token):
318
+ """ Converts a token (str) in an id using the vocab. """
319
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
320
+
321
+ def _convert_id_to_token(self, index):
322
+ """Converts an index (integer) in a token (str) using the vocab."""
323
+ return self.ids_to_tokens.get(index, self.unk_token)
324
+
325
+ def convert_tokens_to_string(self, tokens):
326
+ """ Converts a sequence of tokens (string) in a single string. """
327
+ out_string = " ".join(tokens).replace(" ##", "").strip()
328
+ return out_string
329
+
330
+ def build_inputs_with_special_tokens(
331
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
332
+ ) -> List[int]:
333
+ """
334
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
335
+ by concatenating and adding special tokens.
336
+ A BERT sequence has the following format:
337
+ - single sequence: ``[CLS] X [SEP]``
338
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
339
+ Args:
340
+ token_ids_0 (:obj:`List[int]`):
341
+ List of IDs to which the special tokens will be added
342
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
343
+ Optional second list of IDs for sequence pairs.
344
+ Returns:
345
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
346
+ """
347
+ if token_ids_1 is None:
348
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
349
+ cls = [self.cls_token_id]
350
+ sep = [self.sep_token_id]
351
+ return cls + token_ids_0 + sep + token_ids_1 + sep
352
+
353
+ def get_special_tokens_mask(
354
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
355
+ ) -> List[int]:
356
+ """
357
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
358
+ special tokens using the tokenizer ``prepare_for_model`` method.
359
+ Args:
360
+ token_ids_0 (:obj:`List[int]`):
361
+ List of ids.
362
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
363
+ Optional second list of IDs for sequence pairs.
364
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
365
+ Set to True if the token list is already formatted with special tokens for the model
366
+ Returns:
367
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
368
+ """
369
+
370
+ if already_has_special_tokens:
371
+ if token_ids_1 is not None:
372
+ raise ValueError(
373
+ "You should not supply a second sequence if the provided sequence of "
374
+ "ids is already formated with special tokens for the model."
375
+ )
376
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
377
+
378
+ if token_ids_1 is not None:
379
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
380
+ return [1] + ([0] * len(token_ids_0)) + [1]
381
+
382
+ def create_token_type_ids_from_sequences(
383
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
384
+ ) -> List[int]:
385
+ """
386
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
387
+ A BERT sequence pair mask has the following format:
388
+ ::
389
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
390
+ | first sequence | second sequence |
391
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
392
+ Args:
393
+ token_ids_0 (:obj:`List[int]`):
394
+ List of ids.
395
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
396
+ Optional second list of IDs for sequence pairs.
397
+ Returns:
398
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
399
+ sequence(s).
400
+ """
401
+ sep = [self.sep_token_id]
402
+ cls = [self.cls_token_id]
403
+ if token_ids_1 is None:
404
+ return len(cls + token_ids_0 + sep) * [0]
405
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
406
+
407
+ def save_vocabulary(self, vocab_path):
408
+ """
409
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
410
+ Args:
411
+ vocab_path (:obj:`str`):
412
+ The directory in which to save the vocabulary.
413
+ Returns:
414
+ :obj:`Tuple(str)`: Paths to the files saved.
415
+ """
416
+ index = 0
417
+ vocab_file = vocab_path
418
+ with open(vocab_file, "w", encoding="utf-8") as writer:
419
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
420
+ if index != token_index:
421
+ index = token_index
422
+ writer.write(token + "\n")
423
+ index += 1
424
+ return (vocab_file,)
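
The overridden encode/decode pair hard-codes the `[CLS]`/`[SEP]` ids (2 and 3 under the `new_vocab.txt` ordering, where `[PAD]`=0 and `[UNK]`=1) and expects a list of pre-split tokens rather than a raw string. A round-trip sketch; the relative vocab paths are assumptions about the working directory:

```python
tok = SMILES_SPE_Tokenizer('new_vocab.txt', 'new_splits.txt')

pieces = tok._tokenize("NCC(=O)NCC(=O)O")    # SPE split into vocabulary tokens
batch = tok.encode(pieces)                   # {'input_ids': tensor([[2, ..., 3]]), 'attention_mask': ...}
print(tok.decode(batch['input_ids']))        # decoding stops at the [SEP] id (3)
```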
tr2d2-pep/scoring/tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
tr2d2-pep/scoring/tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
tr2d2-pep/tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
+ def load_vocab(vocab_file):
10
+ """Loads a vocabulary file into a dictionary."""
11
+ vocab = collections.OrderedDict()
12
+ with open(vocab_file, "r", encoding="utf-8") as reader:
13
+ tokens = reader.readlines()
14
+ for index, token in enumerate(tokens):
15
+ token = token.rstrip("\n")
16
+ vocab[token] = index
17
+ return vocab
18
+
19
+ class Atomwise_Tokenizer(object):
20
+ """Run atom-level SMILES tokenization"""
21
+
22
+ def __init__(self):
23
+ """ Constructs a atom-level Tokenizer.
24
+ """
25
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
26
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
27
+
28
+ self.regex = re.compile(self.regex_pattern)
29
+
30
+ def tokenize(self, text):
31
+ """ Basic Tokenization of a SMILES.
32
+ """
33
+ tokens = [token for token in self.regex.findall(text)]
34
+ return tokens
35
+
36
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
37
+ r"""
38
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
39
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
40
+ should refer to the superclass for more information regarding methods.
41
+ Args:
42
+ vocab_file (:obj:`string`):
43
+ File containing the vocabulary.
44
+ spe_file (:obj:`string`):
45
+ File containing the trained SMILES Pair Encoding vocabulary.
46
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
50
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
51
+ for sequence classification or for a text and a question for question answering.
52
+ It is also used as the last token of a sequence built with special tokens.
53
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
54
+ The token used for padding, for example when batching sequences of different lengths.
55
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
56
+ The classifier token which is used when doing sequence classification (classification of the whole
57
+ sequence instead of per-token classification). It is the first token of the sequence when built with
58
+ special tokens.
59
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
60
+ The token used for masking values. This is the token used when training this model with masked language
61
+ modeling. This is the token which the model will try to predict.
62
+ """
63
+
64
+ def __init__(self, vocab_file, spe_file,
65
+ unk_token="[UNK]",
66
+ sep_token="[SEP]",
67
+ pad_token="[PAD]",
68
+ cls_token="[CLS]",
69
+ mask_token="[MASK]",
70
+ **kwargs):
71
+ if not os.path.isfile(vocab_file):
72
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
73
+ if not os.path.isfile(spe_file):
74
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
75
+
76
+ self.vocab = load_vocab(vocab_file)
77
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
78
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
79
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
80
+
81
+ super().__init__(
82
+ unk_token=unk_token,
83
+ sep_token=sep_token,
84
+ pad_token=pad_token,
85
+ cls_token=cls_token,
86
+ mask_token=mask_token,
87
+ **kwargs)
88
+
89
+ @property
90
+ def vocab_size(self):
91
+ return len(self.vocab)
92
+
93
+ def get_vocab(self):
94
+ return dict(self.vocab, **self.added_tokens_encoder)
95
+
96
+ def _tokenize(self, text):
97
+ return self.spe_tokenizer.tokenize(text).split(' ')
98
+
99
+ def _convert_token_to_id(self, token):
100
+ """ Converts a token (str) in an id using the vocab. """
101
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
102
+
103
+ # changed encode and decode functions
104
+ def encode(self, token_array):
105
+ token_ids = []
106
+ token_ids.append(2)
107
+ for token in token_array:
108
+ id = self._convert_token_to_id(token)
109
+ token_ids.append(id)
110
+ token_ids.append(3)
111
+ token_ids = torch.tensor([token_ids])
112
+ attn_mask = torch.ones_like(token_ids)
113
+ return {'input_ids': token_ids, 'attention_mask': attn_mask}
114
+
115
+ def decode(self, token_ids, skip_special_tokens=True):
116
+ token_ids = token_ids.squeeze(0).cpu().tolist()
117
+ token_array = []
118
+ for idx in token_ids:
119
+ if idx == 3: # Stop decoding when token ID 3 is encountered
120
+ break
121
+ if skip_special_tokens and idx in self.all_special_ids:
122
+ continue
123
+ token = self._convert_id_to_token(idx)
124
+ token_array.append(token)
125
+ sequence = "".join(token_array)
126
+ return sequence
127
+
128
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
129
+ sequences = []
130
+ for token_ids in batch_token_ids:
131
+ sequences.append(self.decode(token_ids, skip_special_tokens=skip_special_tokens))
132
+ return sequences
133
+
134
+ def get_token_split(self, token_ids):
135
+ if isinstance(token_ids, torch.Tensor):
136
+ token_ids = token_ids.cpu().tolist()
137
+
138
+ token_array = []
139
+ for seq_ids in token_ids:
140
+ seq_array = []
141
+ for id in seq_ids:
142
+ token = self._convert_id_to_token(id)
143
+ seq_array.append(token)
144
+ token_array.append(seq_array)
145
+
146
+ return token_array
147
+
148
+ def _convert_id_to_token(self, index):
149
+ """Converts an index (integer) in a token (str) using the vocab."""
150
+ return self.ids_to_tokens.get(index, self.unk_token)
151
+
152
+ def convert_tokens_to_string(self, tokens):
153
+ """ Converts a sequence of tokens (string) in a single string. """
154
+ out_string = " ".join(tokens).replace(" ##", "").strip()
155
+ return out_string
156
+
157
+ def build_inputs_with_special_tokens(
158
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
159
+ ) -> List[int]:
160
+ """
161
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
162
+ by concatenating and adding special tokens.
163
+ A BERT sequence has the following format:
164
+ - single sequence: ``[CLS] X [SEP]``
165
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
166
+ Args:
167
+ token_ids_0 (:obj:`List[int]`):
168
+ List of IDs to which the special tokens will be added
169
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
170
+ Optional second list of IDs for sequence pairs.
171
+ Returns:
172
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
173
+ """
174
+ if token_ids_1 is None:
175
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
176
+ cls = [self.cls_token_id]
177
+ sep = [self.sep_token_id]
178
+ return cls + token_ids_0 + sep + token_ids_1 + sep
179
+
180
+ def get_special_tokens_mask(
181
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
182
+ ) -> List[int]:
183
+ """
184
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
185
+ special tokens using the tokenizer ``prepare_for_model`` method.
186
+ Args:
187
+ token_ids_0 (:obj:`List[int]`):
188
+ List of ids.
189
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
190
+ Optional second list of IDs for sequence pairs.
191
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
192
+ Set to True if the token list is already formatted with special tokens for the model
193
+ Returns:
194
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
195
+ """
196
+
197
+ if already_has_special_tokens:
198
+ if token_ids_1 is not None:
199
+ raise ValueError(
200
+ "You should not supply a second sequence if the provided sequence of "
201
+ "ids is already formated with special tokens for the model."
202
+ )
203
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
204
+
205
+ if token_ids_1 is not None:
206
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
207
+ return [1] + ([0] * len(token_ids_0)) + [1]
208
+
209
+ def create_token_type_ids_from_sequences(
210
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
211
+ ) -> List[int]:
212
+ """
213
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
214
+ A BERT sequence pair mask has the following format:
215
+ ::
216
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
217
+ | first sequence | second sequence |
218
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
219
+ Args:
220
+ token_ids_0 (:obj:`List[int]`):
221
+ List of ids.
222
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
223
+ Optional second list of IDs for sequence pairs.
224
+ Returns:
225
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
226
+ sequence(s).
227
+ """
228
+ sep = [self.sep_token_id]
229
+ cls = [self.cls_token_id]
230
+ if token_ids_1 is None:
231
+ return len(cls + token_ids_0 + sep) * [0]
232
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
233
+
234
+ def save_vocabulary(self, vocab_path):
235
+ """
236
+ Save the tokenizer vocabulary to the given file path.
237
+ Args:
238
+ vocab_path (:obj:`str`):
239
+ Path of the file in which to save the vocabulary.
240
+ Returns:
241
+ :obj:`Tuple(str)`: Paths to the files saved.
242
+ """
243
+ index = 0
244
+ vocab_file = vocab_path
245
+ with open(vocab_file, "w", encoding="utf-8") as writer:
246
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
247
+ if index != token_index:
248
+ index = token_index
249
+ writer.write(token + "\n")
250
+ index += 1
251
+ return (vocab_file,)
252
+
253
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
254
+ r"""
255
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
256
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
257
+ should refer to the superclass for more information regarding methods.
258
+ Args:
259
+ vocab_file (:obj:`string`):
260
+ File containing the vocabulary.
261
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
262
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
263
+ token instead.
264
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
265
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
266
+ for sequence classification or for a text and a question for question answering.
267
+ It is also used as the last token of a sequence built with special tokens.
268
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
269
+ The token used for padding, for example when batching sequences of different lengths.
270
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
271
+ The classifier token which is used when doing sequence classification (classification of the whole
272
+ sequence instead of per-token classification). It is the first token of the sequence when built with
273
+ special tokens.
274
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
275
+ The token used for masking values. This is the token used when training this model with masked language
276
+ modeling. This is the token which the model will try to predict.
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ vocab_file,
282
+ unk_token="[UNK]",
283
+ sep_token="[SEP]",
284
+ pad_token="[PAD]",
285
+ cls_token="[CLS]",
286
+ mask_token="[MASK]",
287
+ **kwargs
288
+ ):
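+ # note: unlike SMILES_SPE_Tokenizer above, self.vocab is only set after super().__init__(); transformers versions that touch the vocab during that call may fail here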
289
+ super().__init__(
290
+ unk_token=unk_token,
291
+ sep_token=sep_token,
292
+ pad_token=pad_token,
293
+ cls_token=cls_token,
294
+ mask_token=mask_token,
295
+ **kwargs,
296
+ )
297
+
298
+ if not os.path.isfile(vocab_file):
299
+ raise ValueError(
300
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
301
+ )
302
+ self.vocab = load_vocab(vocab_file)
303
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
304
+ self.tokenizer = Atomwise_Tokenizer()
305
+
306
+ @property
307
+ def vocab_size(self):
308
+ return len(self.vocab)
309
+
310
+ def get_vocab(self):
311
+ return dict(self.vocab, **self.added_tokens_encoder)
312
+
313
+
314
+ def _tokenize(self, text):
315
+ return self.tokenizer.tokenize(text)
316
+
317
+ def _convert_token_to_id(self, token):
318
+ """ Converts a token (str) in an id using the vocab. """
319
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
320
+
321
+ def _convert_id_to_token(self, index):
322
+ """Converts an index (integer) in a token (str) using the vocab."""
323
+ return self.ids_to_tokens.get(index, self.unk_token)
324
+
325
+ def convert_tokens_to_string(self, tokens):
326
+ """ Converts a sequence of tokens (string) in a single string. """
327
+ out_string = " ".join(tokens).replace(" ##", "").strip()
328
+ return out_string
329
+
330
+ def build_inputs_with_special_tokens(
331
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
332
+ ) -> List[int]:
333
+ """
334
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
335
+ by concatenating and adding special tokens.
336
+ A BERT sequence has the following format:
337
+ - single sequence: ``[CLS] X [SEP]``
338
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
339
+ Args:
340
+ token_ids_0 (:obj:`List[int]`):
341
+ List of IDs to which the special tokens will be added
342
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
343
+ Optional second list of IDs for sequence pairs.
344
+ Returns:
345
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
346
+ """
347
+ if token_ids_1 is None:
348
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
349
+ cls = [self.cls_token_id]
350
+ sep = [self.sep_token_id]
351
+ return cls + token_ids_0 + sep + token_ids_1 + sep
352
+
353
+ def get_special_tokens_mask(
354
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
355
+ ) -> List[int]:
356
+ """
357
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
358
+ special tokens using the tokenizer ``prepare_for_model`` method.
359
+ Args:
360
+ token_ids_0 (:obj:`List[int]`):
361
+ List of ids.
362
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
363
+ Optional second list of IDs for sequence pairs.
364
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
365
+ Set to True if the token list is already formatted with special tokens for the model
366
+ Returns:
367
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
368
+ """
369
+
370
+ if already_has_special_tokens:
371
+ if token_ids_1 is not None:
372
+ raise ValueError(
373
+ "You should not supply a second sequence if the provided sequence of "
374
+ "ids is already formated with special tokens for the model."
375
+ )
376
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
377
+
378
+ if token_ids_1 is not None:
379
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
380
+ return [1] + ([0] * len(token_ids_0)) + [1]
381
+
382
+ def create_token_type_ids_from_sequences(
383
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
384
+ ) -> List[int]:
385
+ """
386
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
387
+ A BERT sequence pair mask has the following format:
388
+ ::
389
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
390
+ | first sequence | second sequence |
391
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
392
+ Args:
393
+ token_ids_0 (:obj:`List[int]`):
394
+ List of ids.
395
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
396
+ Optional second list of IDs for sequence pairs.
397
+ Returns:
398
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
399
+ sequence(s).
400
+ """
401
+ sep = [self.sep_token_id]
402
+ cls = [self.cls_token_id]
403
+ if token_ids_1 is None:
404
+ return len(cls + token_ids_0 + sep) * [0]
405
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
406
+
407
+ def save_vocabulary(self, vocab_path):
408
+ """
409
+ Save the tokenizer vocabulary to the given file path.
410
+ Args:
411
+ vocab_path (:obj:`str`):
412
+ Path of the file in which to save the vocabulary.
413
+ Returns:
414
+ :obj:`Tuple(str)`: Paths to the files saved.
415
+ """
416
+ index = 0
417
+ vocab_file = vocab_path
418
+ with open(vocab_file, "w", encoding="utf-8") as writer:
419
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
420
+ if index != token_index:
421
+ index = token_index
422
+ writer.write(token + "\n")
423
+ index += 1
424
+ return (vocab_file,)
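`SMILES_SPE_Tokenizer` above replaces the stock `encode`/`decode` with list-in, tensor-out variants. A minimal usage sketch follows; the paths, the example SMILES, and the import path (assumed to be `tokenizer.my_tokenizers` when run from inside `tr2d2-pep/`) are illustrative only.

```python
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

# vocab and merge files shipped in this commit (tr2d2-pep/tokenizer/)
tok = SMILES_SPE_Tokenizer(vocab_file="tokenizer/new_vocab.txt",
                           spe_file="tokenizer/new_splits.txt")

# encode() expects a list of already-split tokens, not a raw string;
# _tokenize() produces such a list via SMILES Pair Encoding.
pieces = tok._tokenize("CC(=O)NC")   # e.g. ['CC', '(=O)', 'NC'], depending on the learned merges
batch = tok.encode(pieces)           # {'input_ids': tensor of shape [1, L], 'attention_mask': ...}

# decode() skips special tokens and stops at id 3 ([SEP]), returning a SMILES string.
smiles = tok.decode(batch["input_ids"])
```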
tr2d2-pep/tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
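`new_splits.txt` is the SMILES Pair Encoding merge table: each line is a pair of adjacent tokens (`left right`) that is merged into a single token on top of an atom-level split. The sketch below illustrates what applying such a merge table looks like; it is not the SmilesPE implementation (in the repo this is handled by `SPE_Tokenizer`), just a rough equivalent under that assumption.

```python
# Illustrative sketch of applying a merge table like new_splits.txt (not the SmilesPE code).
def apply_merges(tokens, merges):
    """Greedily merge adjacent token pairs, in the priority order given by `merges`."""
    ranks = {pair: i for i, pair in enumerate(merges)}
    while True:
        # find the highest-priority adjacent pair currently present
        best = None
        for i in range(len(tokens) - 1):
            pair = (tokens[i], tokens[i + 1])
            if pair in ranks and (best is None or ranks[pair] < ranks[best[1]]):
                best = (i, pair)
        if best is None:
            return tokens
        i, pair = best
        tokens = tokens[:i] + [pair[0] + pair[1]] + tokens[i + 2:]

# e.g. merges read from new_splits.txt: [('=', 'O'), ('C', 'C'), ...]
print(apply_merges(list("CC=O"), [('=', 'O'), ('C', 'C')]))  # ['CC', '=O']
```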
tr2d2-pep/tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
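`new_vocab.txt` lists one token per line, with the special tokens `[PAD] [UNK] [CLS] [SEP] [MASK]` first. Assuming a BERT-style `load_vocab` that numbers lines from 0, this ordering is what makes the hard-coded ids 2 and 3 in `SMILES_SPE_Tokenizer.encode`/`decode` line up with `[CLS]` and `[SEP]`; a quick check under that assumption:

```python
# assumes a BERT-style loader: token id = line index, starting at 0
with open("tokenizer/new_vocab.txt", encoding="utf-8") as f:
    vocab = {line.rstrip("\n"): i for i, line in enumerate(f)}

print(vocab["[PAD]"], vocab["[CLS]"], vocab["[SEP]"])  # 0 2 3 -> matches encode()/decode() above
```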
tr2d2-pep/utils/app.py ADDED
@@ -0,0 +1,1255 @@
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from io import StringIO
5
+ import rdkit
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem, Draw
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.patches as patches
12
+ from io import BytesIO
13
+ import tempfile
14
+ from rdkit import Chem
15
+
16
+ class PeptideAnalyzer:
17
+ def __init__(self):
18
+ self.bond_patterns = [
19
+ (r'OC\(=O\)', 'ester'), # Ester bond
20
+ (r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
21
+ (r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
22
+ (r'NC\(=O\)', 'peptide'), # Standard peptide bond
23
+ (r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
24
+ (r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
25
+ ]
26
+ # Three to one letter code mapping
27
+ self.three_to_one = {
28
+ 'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
29
+ 'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
30
+ 'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
31
+ 'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
32
+ 'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
33
+ }
34
+
35
+ def is_peptide(self, smiles):
36
+ """Check if the SMILES represents a peptide structure"""
37
+ mol = Chem.MolFromSmiles(smiles)
38
+ if mol is None:
39
+ return False
40
+
41
+ # Look for peptide bonds: NC(=O) pattern
42
+ peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
43
+ if mol.HasSubstructMatch(peptide_bond_pattern):
44
+ return True
45
+
46
+ # Look for N-methylated peptide bonds: N(C)C(=O) pattern
47
+ n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
48
+ if mol.HasSubstructMatch(n_methyl_pattern):
49
+ return True
50
+
51
+ return False
52
+
53
+ def is_cyclic(self, smiles):
54
+ """Improved cyclic peptide detection"""
55
+ # Check for C-terminal carboxyl
56
+ if smiles.endswith('C(=O)O'):
57
+ return False, [], []
58
+
59
+ # Find all numbers used in ring closures
60
+ ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
61
+
62
+ # Find aromatic ring numbers
63
+ aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
64
+ aromatic_cycles = []
65
+ for match in aromatic_matches:
66
+ numbers = re.findall(r'[0-9]', match)
67
+ aromatic_cycles.extend(numbers)
68
+
69
+ # Numbers that aren't part of aromatic rings are peptide cycles
70
+ peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
71
+
72
+ is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
73
+ return is_cyclic, peptide_cycles, aromatic_cycles
74
+
75
+ def split_on_bonds(self, smiles):
76
+ """Split SMILES into segments with simplified Pro handling"""
77
+ positions = []
78
+ used = set()
79
+
80
+ # Find Gly pattern first
81
+ gly_pattern = r'NCC\(=O\)'
82
+ for match in re.finditer(gly_pattern, smiles):
83
+ if not any(p in range(match.start(), match.end()) for p in used):
84
+ positions.append({
85
+ 'start': match.start(),
86
+ 'end': match.end(),
87
+ 'type': 'gly',
88
+ 'pattern': match.group()
89
+ })
90
+ used.update(range(match.start(), match.end()))
91
+
92
+ for pattern, bond_type in self.bond_patterns:
93
+ for match in re.finditer(pattern, smiles):
94
+ if not any(p in range(match.start(), match.end()) for p in used):
95
+ positions.append({
96
+ 'start': match.start(),
97
+ 'end': match.end(),
98
+ 'type': bond_type,
99
+ 'pattern': match.group()
100
+ })
101
+ used.update(range(match.start(), match.end()))
102
+
103
+ # Sort by position
104
+ positions.sort(key=lambda x: x['start'])
105
+
106
+ # Create segments
107
+ segments = []
108
+
109
+ if positions:
110
+ # First segment
111
+ if positions[0]['start'] > 0:
112
+ segments.append({
113
+ 'content': smiles[0:positions[0]['start']],
114
+ 'bond_after': positions[0]['pattern']
115
+ })
116
+
117
+ # Process segments
118
+ for i in range(len(positions)-1):
119
+ current = positions[i]
120
+ next_pos = positions[i+1]
121
+
122
+ if current['type'] == 'gly':
123
+ segments.append({
124
+ 'content': 'NCC(=O)',
125
+ 'bond_before': positions[i-1]['pattern'] if i > 0 else None,
126
+ 'bond_after': next_pos['pattern']
127
+ })
128
+ else:
129
+ content = smiles[current['end']:next_pos['start']]
130
+ if content:
131
+ segments.append({
132
+ 'content': content,
133
+ 'bond_before': current['pattern'],
134
+ 'bond_after': next_pos['pattern']
135
+ })
136
+
137
+ # Last segment
138
+ if positions[-1]['end'] < len(smiles):
139
+ segments.append({
140
+ 'content': smiles[positions[-1]['end']:],
141
+ 'bond_before': positions[-1]['pattern']
142
+ })
143
+
144
+ return segments
145
+
146
+ def clean_terminal_carboxyl(self, segment):
147
+ """Remove C-terminal carboxyl only if it's the true terminus"""
148
+ content = segment['content']
149
+
150
+ # Only clean if:
151
+ # 1. Contains C(=O)O
152
+ # 2. No bond_after exists (meaning it's the last segment)
153
+ # 3. C(=O)O is at the end of the content
154
+ if 'C(=O)O' in content and not segment.get('bond_after'):
155
+ print('recognized?')
156
+ # Remove C(=O)O pattern regardless of position
157
+ cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
158
+ # Remove any leftover empty parentheses
159
+ cleaned = re.sub(r'\(\)', '', cleaned)
160
+ print(cleaned)
161
+ return cleaned
162
+ return content
163
+
164
+ def identify_residue(self, segment):
165
+ """Identify residue with Pro reconstruction"""
166
+ # Only clean terminal carboxyl if this is the last segment
167
+ content = self.clean_terminal_carboxyl(segment)
168
+ mods = self.get_modifications(segment)
169
+
170
+ # UAA pattern matching section - before regular residues
171
+ # Phenylglycine and derivatives
172
+ if 'c1ccccc1' in content:
173
+ if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
174
+ return '4', mods # Base phenylglycine
175
+
176
+ # 4-substituted phenylalanines
177
+ if 'Cc1ccc' in content:
178
+ if 'OMe' in content or 'OCc1ccc' in content:
179
+ return '0A1', mods # 4-methoxy-Phenylalanine
180
+ elif 'Clc1ccc' in content:
181
+ return '200', mods # 4-chloro-Phenylalanine
182
+ elif 'Brc1ccc' in content:
183
+ return '4BF', mods # 4-Bromo-phenylalanine
184
+ elif 'C#Nc1ccc' in content:
185
+ return '4CF', mods # 4-cyano-phenylalanine
186
+ elif 'Ic1ccc' in content:
187
+ return 'PHI', mods # 4-Iodo-phenylalanine
188
+ elif 'Fc1ccc' in content:
189
+ return 'PFF', mods # 4-Fluoro-phenylalanine
190
+
191
+ # Modified tryptophans
192
+ if 'c[nH]c2' in content:
193
+ if 'Oc2cccc2' in content:
194
+ return '0AF', mods # 7-hydroxy-tryptophan
195
+ elif 'Fc2cccc2' in content:
196
+ return '4FW', mods # 4-fluoro-tryptophan
197
+ elif 'Clc2cccc2' in content:
198
+ return '6CW', mods # 6-chloro-tryptophan
199
+ elif 'Brc2cccc2' in content:
200
+ return 'BTR', mods # 6-bromo-tryptophan
201
+ elif 'COc2cccc2' in content:
202
+ return 'MOT5', mods # 5-Methoxy-tryptophan
203
+ elif 'Cc2cccc2' in content:
204
+ return 'MTR5', mods # 5-Methyl-tryptophan
205
+
206
+ # Special amino acids
207
+ if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
208
+ return 'BUG', mods # tert-Leucine
209
+
210
+ if 'CCCNC(=N)N' in content:
211
+ return 'CIR', mods # Citrulline
212
+
213
+ if '[SeH]' in content:
214
+ return 'CSE', mods # Selenocysteine
215
+
216
+ if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
217
+ return 'DAB', mods # Diaminobutyric acid
218
+
219
+ if 'C1CCCCC1' in content:
220
+ if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
221
+ return 'CHG', mods # Cyclohexylglycine
222
+ elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
223
+ return 'ALC', mods # 3-cyclohexyl-alanine
224
+
225
+ # Naphthalene derivatives
226
+ if 'c1cccc2c1cccc2' in content:
227
+ if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
228
+ return 'NAL', mods # 2-Naphthyl-alanine
229
+
230
+ # Heteroaromatic derivatives
231
+ if 'c1cncc' in content:
232
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
233
+ if 'c1cscc' in content:
234
+ return 'THA3', mods # 3-(3-thienyl)-alanine
235
+ if 'c1nnc' in content:
236
+ return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
237
+
238
+ # Modified serines and threonines
239
+ if 'OP(O)(O)O' in content:
240
+ if '[C@@H](COP' in content or '[C@H](COP' in content:
241
+ return 'SEP', mods # phosphoserine
242
+ elif '[C@@H](OP' in content or '[C@H](OP' in content:
243
+ return 'TPO', mods # phosphothreonine
244
+
245
+ # Specialized ring systems
246
+ if 'c1c2ccccc2cc2c1cccc2' in content:
247
+ return 'ANTH', mods # 3-(9-anthryl)-alanine
248
+ if 'c1csc2c1cccc2' in content:
249
+ return 'BTH3', mods # 3-(3-benzothienyl)-alanine
250
+ if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
251
+ return 'ADAM', mods # Adamantane
252
+
253
+ # Fluorinated derivatives
254
+ if 'FC(F)(F)' in content:
255
+ if 'CC(F)(F)F' in content:
256
+ return 'FLA', mods # Trifluoro-alanine
257
+ if 'C(F)(F)F)c1' in content:
258
+ if 'c1ccccc1C(F)(F)F' in content:
259
+ return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
260
+ if 'c1cccc(c1)C(F)(F)F' in content:
261
+ return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
262
+ if 'c1ccc(cc1)C(F)(F)F' in content:
263
+ return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
264
+
265
+ # Multiple halogen patterns
266
+ if 'F' in content and 'c1' in content:
267
+ if 'c1ccc(c(c1)F)F' in content:
268
+ return 'F2F', mods # 3,4-Difluoro-phenylalanine
269
+ if 'cc(F)cc(c1)F' in content:
270
+ return 'WFP', mods # 3,5-Difluoro-phenylalanine
271
+ if 'Cl' in content and 'c1' in content:
272
+ if 'c1ccc(cc1Cl)Cl' in content:
273
+ return 'CP24', mods # 2,4-dichloro-phenylalanine
274
+ if 'c1ccc(c(c1)Cl)Cl' in content:
275
+ return 'CP34', mods # 3,4-dichloro-phenylalanine
276
+
277
+ # Hydroxy and amino derivatives
278
+ if 'O' in content and 'c1' in content:
279
+ if 'c1cc(O)cc(c1)O' in content:
280
+ return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
281
+ if 'c1ccc(c(c1)O)O' in content:
282
+ return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
283
+
284
+ # Cyclic amino acids
285
+ if 'C1CCCC1' in content:
286
+ return 'CPA3', mods # 3-Cyclopentyl-alanine
287
+ if 'C1CCCCC1' in content:
288
+ if 'CC1CCCCC1' in content:
289
+ return 'ALC', mods # 3-cyclohexyl-alanine
290
+ else:
291
+ return 'CHG', mods # Cyclohexylglycine
292
+
293
+ # Chain-length variants
294
+ if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
295
+ return 'NLE', mods # Norleucine
296
+ if 'CC[C@@H]' in content or 'CC[C@H]' in content:
297
+ if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
298
+ return 'ABA', mods # 2-Aminobutyric acid
299
+
300
+ # Modified histidines
301
+ if 'c1cnc' in content:
302
+ if '[C@@H]1CN[C@@H](N1)F' in content:
303
+ return '2HF', mods # 2-fluoro-l-histidine
304
+ if 'c1cnc([nH]1)F' in content:
305
+ return '2HF1', mods # 2-fluoro-l-histidine variant
306
+ if 'c1c[nH]c(n1)F' in content:
307
+ return '2HF2', mods # 2-fluoro-l-histidine variant
308
+
309
+ # Sulfur and selenium containing
310
+ if '[SeH]' in content:
311
+ return 'CSE', mods # Selenocysteine
312
+ if 'S' in content:
313
+ if 'CSCc1ccccc1' in content:
314
+ return 'BCS', mods # benzylcysteine
315
+ if 'CCSC' in content:
316
+ return 'ESC', mods # Ethionine
317
+ if 'CCS' in content:
318
+ return 'HCS', mods # homocysteine
319
+
320
+ # Additional modifications
321
+ if 'CN=[N]=N' in content:
322
+ return 'AZDA', mods # azido-alanine
323
+ if '[NH]=[C](=[NH2])=[NH2]' in content:
324
+ if 'CCC[NH]=' in content:
325
+ return 'AGM', mods # 5-methyl-arginine
326
+ if 'CC[NH]=' in content:
327
+ return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
328
+
329
+ if 'CCON' in content:
330
+ return 'CAN', mods # canaline
331
+ if '[C@@H]1C=C[C@@H](C=C1)' in content:
332
+ return 'ACZ', mods # cis-amiclenomycin
333
+ if 'CCC(=O)[NH3]' in content:
334
+ return 'ONL', mods # 5-oxo-l-norleucine
335
+ if 'c1ccncc1' in content:
336
+ return 'PYR4', mods # 3-(4-Pyridyl)-alanine
337
+ if 'c1ccco1' in content:
338
+ return 'FUA2', mods # (2-furyl)-alanine
339
+
340
+ if 'c1ccc' in content:
341
+ if 'c1ccc(cc1)c1ccccc1' in content:
342
+ return 'BIF', mods # 4,4-biphenylalanine
343
+ if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
344
+ return 'PBF', mods # 4-benzoyl-phenylalanine
345
+ if 'c1ccc(cc1)C(C)(C)C' in content:
346
+ return 'TBP4', mods # 4-tert-butyl-phenylalanine
347
+ if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
348
+ return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
349
+ if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
350
+ return 'APM', mods # m-amidinophenyl-3-alanine
351
+
352
+ # Multiple hydroxy patterns
353
+ if 'O' in content:
354
+ if '[C@H]([C@H](C)O)O' in content:
355
+ return 'ILX', mods # 4,5-dihydroxy-isoleucine
356
+ if '[C@H]([C@@H](C)O)O' in content:
357
+ return 'ALO', mods # Allo-threonine
358
+ if '[C@H](COP(O)(O)O)' in content:
359
+ return 'SEP', mods # phosphoserine
360
+ if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
361
+ return 'TPO', mods # phosphothreonine
362
+ if '[C@H](c1ccc(O)cc1)O' in content:
363
+ return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
364
+ if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
365
+ return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
366
+
367
+ # Heterocyclic patterns
368
+ if 'n1' in content:
369
+ if 'n1cccn1' in content:
370
+ return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
371
+ if 'n1nncn1' in content:
372
+ return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
373
+ if 'c2c(n1)cccc2' in content:
374
+ return 'QU32', mods # 3-(2-Quinolyl)-alanine
375
+ if 'c1cnc2c(c1)cccc2' in content:
376
+ return 'QU33', mods # 3-(3-quinolyl)-alanine
377
+ if 'c1ccnc2c1cccc2' in content:
378
+ return 'QU34', mods # 3-(4-quinolyl)-alanine
379
+ if 'c1ccc2c(c1)nccc2' in content:
380
+ return 'QU35', mods # 3-(5-Quinolyl)-alanine
381
+ if 'c1ccc2c(c1)cncc2' in content:
382
+ return 'QU36', mods # 3-(6-Quinolyl)-alanine
383
+ if 'c1cnc2c(n1)cccc2' in content:
384
+ return 'QX32', mods # 3-(2-quinoxalyl)-alanine
385
+
386
+ # Multiple nitrogen patterns
387
+ if 'N' in content:
388
+ if '[NH3]CC[C@@H]' in content:
389
+ return 'DAB', mods # Diaminobutyric acid
390
+ if '[NH3]C[C@@H]' in content:
391
+ return 'DPP', mods # 2,3-Diaminopropanoic acid
392
+ if '[NH3]CCCCCC[C@@H]' in content:
393
+ return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
394
+ if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
395
+ return 'GBUT', mods # 2-Amino-4-guanidinobutyric acid
396
+ if '[NH]=[C](=S)=[NH2]' in content:
397
+ return 'THIC', mods # Thio-citrulline
398
+
399
+ # Chain modified amino acids
400
+ if 'CC' in content:
401
+ if 'CCCC[C@@H]' in content:
402
+ return 'AHP', mods # 2-Aminoheptanoic acid
403
+ if 'CCC([C@@H])(C)C' in content:
404
+ return 'I2M', mods # 3-methyl-l-alloisoleucine
405
+ if 'CC[C@H]([C@@H])C' in content:
406
+ return 'IIL', mods # Allo-Isoleucine
407
+ if '[C@H](CCC(C)C)' in content:
408
+ return 'HLEU', mods # Homoleucine
409
+ if '[C@@H]([C@@H](C)O)C' in content:
410
+ return 'HLU', mods # beta-hydroxyleucine
411
+
412
+ # Modified glutamate/aspartate patterns
413
+ if '[C@@H]' in content:
414
+ if '[C@@H](C[C@@H](F))' in content:
415
+ return 'FGA4', mods # 4-Fluoro-glutamic acid
416
+ if '[C@@H](C[C@@H](O))' in content:
417
+ return '3GL', mods # 4-hydroxy-glutamic-acid
418
+ if '[C@@H](C[C@H](C))' in content:
419
+ return 'LME', mods # (3r)-3-methyl-l-glutamic acid
420
+ if '[C@@H](CC[C@H](C))' in content:
421
+ return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
422
+
423
+ # Sulfur and selenium modifications
424
+ if 'S' in content:
425
+ if 'SCC[C@@H]' in content:
426
+ return 'HSER', mods # homoserine
427
+ if 'SCCN' in content:
428
+ return 'SLZ', mods # thialysine
429
+ if 'SC(=O)' in content:
430
+ return 'CSA', mods # s-acetonylcysteine
431
+ if '[S@@](=O)' in content:
432
+ return 'SME', mods # Methionine sulfoxide
433
+ if 'S(=O)(=O)' in content:
434
+ return 'OMT', mods # Methionine sulfone
435
+
436
+ # Double bond containing
437
+ if 'C=' in content:
438
+ if 'C=C[C@@H]' in content:
439
+ return '2AG', mods # 2-Allyl-glycine
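+ # note: the check below repeats the pattern above, so the 'LVG' branch can never be reached as written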
440
+ if 'C=C[C@@H]' in content:
441
+ return 'LVG', mods # vinylglycine
442
+ if 'C=Cc1ccccc1' in content:
443
+ return 'STYA', mods # Styrylalanine
444
+
445
+ # Special cases
446
+ if '[C@@H]1Cc2c(C1)cccc2' in content:
447
+ return 'IGL', mods # alpha-amino-2-indanacetic acid
448
+ if '[C](=[C](=O)=O)=O' in content:
449
+ return '26P', mods # 2-amino-6-oxopimelic acid
450
+ if '[C](=[C](=O)=O)=C' in content:
451
+ return '2NP', mods # l-2-amino-6-methylene-pimelic acid
452
+ if 'c2cnc[nH]2' in content:
453
+ return 'HIS', mods # histidine core
454
+ if 'c1cccc2c1cc(O)cc2' in content:
455
+ return 'NAO1', mods # 5-hydroxy-1-naphthalene
456
+ if 'c1ccc2c(c1)cc(O)cc2' in content:
457
+ return 'NAO2', mods # 6-hydroxy-2-naphthalene
458
+
459
+ # Proline (P) - flexible ring numbers
460
+ if any([
461
+ # Check for any ring number in bond patterns
462
+ (segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
463
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
464
+ for n in '123456789'
465
+ ]) or any([
466
+ # Check ending patterns with any ring number
467
+ (f'CCCN{n}' in content and content.endswith('=O') and
468
+ any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
469
+ for n in '123456789'
470
+ ]) or any([
471
+ # Handle CCC[C@H]n patterns
472
+ (content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
473
+ (content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
474
+ # N-terminal Pro with any ring number
475
+ (f'N{n}CCC[C@H]{n}' in content) or
476
+ (f'N{n}CCC[C@@H]{n}' in content)
477
+ for n in '123456789'
478
+ ]):
479
+ return 'Pro', mods
480
+
481
+ # Tryptophan (W) - more specific indole pattern
482
+ if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
483
+ 'c[nH]c' in content.replace(' ', ''):
484
+ return 'Trp', mods
485
+
486
+ # Lysine (K) - both patterns
487
+ if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
488
+ return 'Lys', mods
489
+
490
+ # Arginine (R) - both patterns
491
+ if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
492
+ return 'Arg', mods
493
+
494
+ if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
495
+ return 'Nle', mods
496
+
497
+ # Ornithine (Orn) - 3-carbon chain with NH2
498
+ if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
499
+ return 'Orn', mods
500
+
501
+ # 2-Naphthylalanine (2Nal) - distinct from Phe pattern
502
+ if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
503
+ return '2Nal', mods
504
+
505
+ # Cyclohexylalanine (Cha) - already in your code but moved here for clarity
506
+ if 'N2CCCCC2' in content or 'CCCCC2' in content:
507
+ return 'Cha', mods
508
+
509
+ # Aminobutyric acid (Abu) - 2-carbon chain
510
+ if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
511
+ return 'Abu', mods
512
+
513
+ # Pipecolic acid (Pip) - 6-membered ring like Pro
514
+ if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
515
+ return 'Pip', mods
516
+
517
+ # Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
518
+ if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
519
+ return 'Chg', mods
520
+
521
+ # 4-Fluorophenylalanine (4F-Phe)
522
+ if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
523
+ return '4F-Phe', mods
524
+
525
+ # Regular residue identification
526
+ if ('NCC(=O)' in content) or (content == 'C'):
527
+ # Middle case - between bonds
528
+ if segment.get('bond_before') and segment.get('bond_after'):
529
+ if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
530
+ return 'Gly', mods
531
+ # Terminal case - at the end
532
+ elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
533
+ return 'Gly', mods
534
+
535
+ if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
536
+ return 'Leu', mods
537
+ if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
538
+ return 'Leu', mods
539
+
540
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
541
+ return 'Thr', mods
542
+
543
+ if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
544
+ return 'Phe', mods
545
+
546
+ if ('[C@H](C(C)C)' in content or # With outer parentheses
547
+ '[C@@H](C(C)C)' in content or # With outer parentheses
548
+ '[C@H]C(C)C' in content or # Without outer parentheses
549
+ '[C@@H]C(C)C' in content): # Without outer parentheses
550
+ if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
551
+ return 'Val', mods
552
+
553
+ if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
554
+ return 'O-tBu', mods
555
+
556
+ if any([
557
+ 'CC[C@H](C)' in content,
558
+ 'CC[C@@H](C)' in content,
559
+ 'C(C)C[C@H]' in content and 'CC(C)C' not in content,
560
+ 'C(C)C[C@@H]' in content and 'CC(C)C' not in content
561
+ ]):
562
+ return 'Ile', mods
563
+
564
+ if ('[C@H](C)' in content or '[C@@H](C)' in content):
565
+ if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
566
+ return 'Ala', mods
567
+
568
+ # Tyrosine (Tyr) - 4-hydroxybenzyl side chain
569
+ if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
570
+ return 'Tyr', mods
571
+
572
+
573
+ # Serine (Ser) - Hydroxymethyl side chain
574
+ if '[C@H](CO)' in content or '[C@@H](CO)' in content:
575
+ if not ('C(C)O' in content or 'COC' in content):
576
+ return 'Ser', mods
577
+
578
+ # Threonine (Thr) - 1-hydroxyethyl side chain
579
+ if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
580
+ return 'Thr', mods
581
+
582
+ # Cysteine (Cys) - Thiol side chain
583
+ if '[C@H](CS)' in content or '[C@@H](CS)' in content:
584
+ return 'Cys', mods
585
+
586
+ # Methionine (Met) - Methylthioethyl side chain
587
+ if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
588
+ return 'Met', mods
589
+
590
+ # Asparagine (Asn) - Carbamoylmethyl side chain
591
+ if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
592
+ return 'Asn', mods
593
+
594
+ # Glutamine (Gln) - Carbamoylethyl side chain
595
+ if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
596
+ return 'Gln', mods
597
+
598
+ # Aspartic acid (Asp) - Carboxymethyl side chain
599
+ if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
600
+ return 'Asp', mods
601
+
602
+ # Glutamic acid (Glu) - Carboxyethyl side chain
603
+ if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
604
+ return 'Glu', mods
605
+
606
+ # Arginine (Arg) - 3-guanidinopropyl side chain
607
+ if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
608
+ return 'Arg', mods
609
+
610
+ # Histidine (His) - Imidazole side chain
611
+ if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
612
+ return 'His', mods
613
+
614
+ return None, mods
615
+
616
+ def get_modifications(self, segment):
617
+ """Get modifications based on bond types"""
618
+ mods = []
619
+ if segment.get('bond_after'):
620
+ if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
621
+ mods.append('N-Me')
622
+ if 'OC(=O)' in segment['bond_after']:
623
+ mods.append('O-linked')
624
+ return mods
625
+
626
+ def analyze_structure(self, smiles):
627
+ """Main analysis function with debug output"""
628
+ print("\nAnalyzing structure:", smiles)
629
+
630
+ # Split into segments
631
+ segments = self.split_on_bonds(smiles)
632
+
633
+ print("\nSegment Analysis:")
634
+ sequence = []
635
+ for i, segment in enumerate(segments):
636
+ print(f"\nSegment {i}:")
637
+ print(f"Content: {segment['content']}")
638
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
639
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
640
+
641
+ residue, mods = self.identify_residue(segment)
642
+ if residue:
643
+ if mods:
644
+ sequence.append(f"{residue}({','.join(mods)})")
645
+ else:
646
+ sequence.append(residue)
647
+ print(f"Identified as: {residue}")
648
+ print(f"Modifications: {mods}")
649
+ else:
650
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
651
+
652
+ # Check if cyclic
653
+ is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
654
+ three_letter = '-'.join(sequence)
655
+ one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)
656
+
657
+ if is_cyclic:
658
+ three_letter = f"cyclo({three_letter})"
659
+ one_letter = f"cyclo({one_letter})"
660
+
661
+ print(f"\nFinal sequence: {three_letter}")
662
+ print(f"One-letter code: {one_letter}")
663
+ print(f"Is cyclic: {is_cyclic}")
664
+ #print(f"Peptide cycles: {peptide_cycles}")
665
+ #print(f"Aromatic cycles: {aromatic_cycles}")
666
+
667
+ return three_letter, len(segments)
668
+ """return {
669
+ 'three_letter': three_letter,
670
+ #'one_letter': one_letter,
671
+ 'is_cyclic': is_cyclic
672
+ }"""
673
+
674
+ def return_sequence(self, smiles):
675
+ """Main analysis function with debug output"""
676
+ print("\nAnalyzing structure:", smiles)
677
+
678
+ # Split into segments
679
+ segments = self.split_on_bonds(smiles)
680
+
681
+ print("\nSegment Analysis:")
682
+ sequence = []
683
+ for i, segment in enumerate(segments):
684
+ print(f"\nSegment {i}:")
685
+ print(f"Content: {segment['content']}")
686
+ print(f"Bond before: {segment.get('bond_before', 'None')}")
687
+ print(f"Bond after: {segment.get('bond_after', 'None')}")
688
+
689
+ residue, mods = self.identify_residue(segment)
690
+ if residue:
691
+ if mods:
692
+ sequence.append(f"{residue}({','.join(mods)})")
693
+ else:
694
+ sequence.append(residue)
695
+ print(f"Identified as: {residue}")
696
+ print(f"Modifications: {mods}")
697
+ else:
698
+ print(f"Warning: Could not identify residue in segment: {segment['content']}")
699
+
700
+ return sequence
701
+
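+ # Example usage of PeptideAnalyzer (illustrative; the SMILES below is an arbitrary Ala-Gly dipeptide):
+ #   analyzer = PeptideAnalyzer()
+ #   if analyzer.is_peptide("C[C@H](N)C(=O)NCC(=O)O"):
+ #       seq, n_segments = analyzer.analyze_structure("C[C@H](N)C(=O)NCC(=O)O")  # prints per-segment debug info
+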
702
+ """
703
+ def annotate_cyclic_structure(mol, sequence):
704
+ '''Create annotated 2D structure with clear, non-overlapping residue labels'''
705
+ # Generate 2D coordinates
706
+ # Generate 2D coordinates
707
+ AllChem.Compute2DCoords(mol)
708
+
709
+ # Create drawer with larger size for annotations
710
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
711
+
712
+ # Get residue list and reverse it to match structural representation
713
+ if sequence.startswith('cyclo('):
714
+ residues = sequence[6:-1].split('-')
715
+ else:
716
+ residues = sequence.split('-')
717
+ residues = list(reversed(residues)) # Reverse the sequence
718
+
719
+ # Draw molecule first to get its bounds
720
+ drawer.drawOptions().addAtomIndices = False
721
+ drawer.DrawMolecule(mol)
722
+ drawer.FinishDrawing()
723
+
724
+ # Convert to PIL Image
725
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
726
+ draw = ImageDraw.Draw(img)
727
+
728
+ try:
729
+ # Try to use DejaVuSans as it's commonly available on Linux systems
730
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
731
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
732
+ except OSError:
733
+ try:
734
+ # Fallback to Arial if available (common on Windows)
735
+ font = ImageFont.truetype("arial.ttf", 60)
736
+ small_font = ImageFont.truetype("arial.ttf", 60)
737
+ except OSError:
738
+ # If no TrueType fonts are available, fall back to default
739
+ print("Warning: TrueType fonts not available, using default font")
740
+ font = ImageFont.load_default()
741
+ small_font = ImageFont.load_default()
742
+ # Get molecule bounds
743
+ conf = mol.GetConformer()
744
+ positions = []
745
+ for i in range(mol.GetNumAtoms()):
746
+ pos = conf.GetAtomPosition(i)
747
+ positions.append((pos.x, pos.y))
748
+
749
+ x_coords = [p[0] for p in positions]
750
+ y_coords = [p[1] for p in positions]
751
+ min_x, max_x = min(x_coords), max(x_coords)
752
+ min_y, max_y = min(y_coords), max(y_coords)
753
+
754
+ # Calculate scaling factors
755
+ scale = 150 # Increased scale factor
756
+ center_x = 1000 # Image center
757
+ center_y = 1000
758
+
759
+ # Add residue labels in a circular arrangement around the structure
760
+ n_residues = len(residues)
761
+ radius = 700 # Distance of labels from center
762
+
763
+ # Start from the rightmost point (3 o'clock position) and go counterclockwise
764
+ # Offset by -3 positions to align with structure
765
+ offset = 0 # Adjust this value to match the structure alignment
766
+ for i, residue in enumerate(residues):
767
+ # Calculate position in a circle around the structure
768
+ # Start from 0 (3 o'clock) and go counterclockwise
769
+ angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
770
+
771
+ # Calculate label position
772
+ label_x = center_x + radius * np.cos(angle)
773
+ label_y = center_y + radius * np.sin(angle)
774
+
775
+ # Draw residue label
776
+ text = f"{i+1}. {residue}"
777
+ bbox = draw.textbbox((label_x, label_y), text, font=font)
778
+ padding = 10
779
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
780
+ bbox[2]+padding, bbox[3]+padding],
781
+ fill='white', outline='white')
782
+ draw.text((label_x, label_y), text,
783
+ font=font, fill='black', anchor="mm")
784
+
785
+ # Add sequence at the top with white background
786
+ seq_text = f"Sequence: {sequence}"
787
+ bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
788
+ padding = 10
789
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
790
+ bbox[2]+padding, bbox[3]+padding],
791
+ fill='white', outline='white')
792
+ draw.text((center_x, 100), seq_text,
793
+ font=small_font, fill='black', anchor="mm")
794
+
795
+ return img
796
+
797
+ """
798
+ def annotate_cyclic_structure(mol, sequence):
799
+ """Create structure visualization with just the sequence header"""
800
+ # Generate 2D coordinates
801
+ AllChem.Compute2DCoords(mol)
802
+
803
+ # Create drawer with larger size for annotations
804
+ drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
805
+
806
+ # Draw molecule first
807
+ drawer.drawOptions().addAtomIndices = False
808
+ drawer.DrawMolecule(mol)
809
+ drawer.FinishDrawing()
810
+
811
+ # Convert to PIL Image
812
+ img = Image.open(BytesIO(drawer.GetDrawingText()))
813
+ draw = ImageDraw.Draw(img)
814
+ try:
815
+ small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
816
+ except OSError:
817
+ try:
818
+ small_font = ImageFont.truetype("arial.ttf", 60)
819
+ except OSError:
820
+ print("Warning: TrueType fonts not available, using default font")
821
+ small_font = ImageFont.load_default()
822
+
823
+ # Add just the sequence header at the top
824
+ seq_text = f"Sequence: {sequence}"
825
+ bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
826
+ padding = 10
827
+ draw.rectangle([bbox[0]-padding, bbox[1]-padding,
828
+ bbox[2]+padding, bbox[3]+padding],
829
+ fill='white', outline='white')
830
+ draw.text((1000, 100), seq_text,
831
+ font=small_font, fill='black', anchor="mm")
832
+
833
+ return img
834
+
835
+ def create_enhanced_linear_viz(sequence, smiles):
836
+ """Create an enhanced linear representation using PeptideAnalyzer"""
837
+ analyzer = PeptideAnalyzer() # Create analyzer instance
838
+
839
+ # Create figure with two subplots
840
+ fig = plt.figure(figsize=(15, 10))
841
+ gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
842
+ ax_struct = fig.add_subplot(gs[0])
843
+ ax_detail = fig.add_subplot(gs[1])
844
+
845
+ # Parse sequence and get residues
846
+ if sequence.startswith('cyclo('):
847
+ residues = sequence[6:-1].split('-')
848
+ else:
849
+ residues = sequence.split('-')
850
+
851
+ # Get segments using analyzer
852
+ segments = analyzer.split_on_bonds(smiles)
853
+
854
+ # Debug print
855
+ print(f"Number of residues: {len(residues)}")
856
+ print(f"Number of segments: {len(segments)}")
857
+
858
+ # Top subplot - Basic structure
859
+ ax_struct.set_xlim(0, 10)
860
+ ax_struct.set_ylim(0, 2)
861
+
862
+ num_residues = len(residues)
863
+ spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0
864
+
865
+ # Draw basic structure
866
+ y_pos = 1.5
867
+ for i in range(num_residues):
868
+ x_pos = 0.5 + i * spacing
869
+
870
+ # Draw amino acid box
871
+ rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
872
+ facecolor='lightblue', edgecolor='black')
873
+ ax_struct.add_patch(rect)
874
+
875
+ # Draw connecting bonds if not the last residue
876
+ if i < num_residues - 1:
877
+ segment = segments[i] if i < len(segments) else None
878
+ if segment:
879
+ # Determine bond type from segment info
880
+ bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
881
+ is_n_methylated = 'N-Me' in segment.get('bond_after', '')
882
+
883
+ bond_color = 'red' if bond_type == 'ester' else 'black'
884
+ linestyle = '--' if bond_type == 'ester' else '-'
885
+
886
+ # Draw bond line
887
+ ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
888
+ color=bond_color, linestyle=linestyle, linewidth=2)
889
+
890
+ # Add bond type label
891
+ mid_x = x_pos + spacing/2
892
+ bond_label = f"{bond_type}"
893
+ if is_n_methylated:
894
+ bond_label += "\n(N-Me)"
895
+ ax_struct.text(mid_x, y_pos+0.1, bond_label,
896
+ ha='center', va='bottom', fontsize=10,
897
+ color=bond_color)
898
+
899
+ # Add residue label
900
+ ax_struct.text(x_pos, y_pos-0.5, residues[i],
901
+ ha='center', va='top', fontsize=14)
902
+
903
+ # Bottom subplot - Detailed breakdown
904
+ ax_detail.set_ylim(0, len(segments)+1)
905
+ ax_detail.set_xlim(0, 1)
906
+
907
+ # Create detailed breakdown
908
+ segment_y = len(segments) # Start from top
909
+ for i, segment in enumerate(segments):
910
+ y = segment_y - i
911
+
912
+ # Check if this is a bond or residue
913
+ residue, mods = analyzer.identify_residue(segment)
914
+ if residue:
915
+ text = f"Residue {i+1}: {residue}"
916
+ if mods:
917
+ text += f" ({', '.join(mods)})"
918
+ color = 'blue'
919
+ else:
920
+ # Must be a bond
921
+ text = f"Bond {i}: "
922
+ if 'O-linked' in segment.get('bond_after', ''):
923
+ text += "ester"
924
+ elif 'N-Me' in segment.get('bond_after', ''):
925
+ text += "peptide (N-methylated)"
926
+ else:
927
+ text += "peptide"
928
+ color = 'red'
929
+
930
+ # Add segment analysis
931
+ ax_detail.text(0.05, y, text, fontsize=12, color=color)
932
+ ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')
933
+
934
+ # If cyclic, add connection indicator
935
+ if sequence.startswith('cyclo('):
936
+ ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
937
+ arrowprops=dict(arrowstyle='<->', color='red', lw=2))
938
+ ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
939
+ ha='center', color='red', fontsize=14)
940
+
941
+ # Add titles and adjust layout
942
+ ax_struct.set_title("Peptide Structure Overview", pad=20)
943
+ ax_detail.set_title("Segment Analysis Breakdown", pad=20)
944
+
945
+ # Remove axes
946
+ for ax in [ax_struct, ax_detail]:
947
+ ax.set_xticks([])
948
+ ax.set_yticks([])
949
+ ax.axis('off')
950
+
951
+ plt.tight_layout()
952
+ return fig
953
+
954
+ class PeptideStructureGenerator:
955
+ """A class to generate 3D structures of peptides using different embedding methods"""
956
+
957
+ @staticmethod
958
+ def prepare_molecule(smiles):
959
+ """Prepare molecule with proper hydrogen handling"""
960
+ mol = Chem.MolFromSmiles(smiles, sanitize=False)
961
+ if mol is None:
962
+ raise ValueError("Failed to create molecule from SMILES")
963
+
964
+ # Calculate valence for each atom
965
+ for atom in mol.GetAtoms():
966
+ atom.UpdatePropertyCache(strict=False)
967
+
968
+ # Sanitize with reduced requirements
969
+ Chem.SanitizeMol(mol,
970
+ sanitizeOps=Chem.SANITIZE_FINDRADICALS|
971
+ Chem.SANITIZE_KEKULIZE|
972
+ Chem.SANITIZE_SETAROMATICITY|
973
+ Chem.SANITIZE_SETCONJUGATION|
974
+ Chem.SANITIZE_SETHYBRIDIZATION|
975
+ Chem.SANITIZE_CLEANUPCHIRALITY)
976
+
977
+ mol = Chem.AddHs(mol)
978
+ return mol
979
+
980
+ @staticmethod
981
+ def get_etkdg_params(attempt=0):
982
+ """Get ETKDG parameters with optional modifications based on attempt number"""
983
+ params = AllChem.ETKDGv3()
984
+ params.randomSeed = -1
985
+ params.maxIterations = 200
986
+ params.numThreads = 4 # Reduced for web interface
987
+ params.useBasicKnowledge = True
988
+ params.enforceChirality = True
989
+ params.useExpTorsionAnglePrefs = True
990
+ params.useSmallRingTorsions = True
991
+ params.useMacrocycleTorsions = True
992
+ params.ETversion = 2
993
+ params.pruneRmsThresh = -1
994
+ params.embedRmsThresh = 0.5
995
+
996
+ if attempt > 10:
997
+ params.bondLength = 1.5 + (attempt - 10) * 0.02
998
+ params.useExpTorsionAnglePrefs = False
999
+
1000
+ return params
1001
+
1002
+ def generate_structure_etkdg(self, smiles, max_attempts=20):
1003
+ """Generate 3D structure using ETKDG without UFF optimization"""
1004
+ success = False
1005
+ mol = None
1006
+
1007
+ for attempt in range(max_attempts):
1008
+ try:
1009
+ mol = self.prepare_molecule(smiles)
1010
+ params = self.get_etkdg_params(attempt)
1011
+
1012
+ if AllChem.EmbedMolecule(mol, params) == 0:
1013
+ success = True
1014
+ break
1015
+ except Exception:
1016
+ continue
1017
+
1018
+ if not success:
1019
+ raise ValueError("Failed to generate structure with ETKDG")
1020
+
1021
+ return mol
1022
+
1023
+ def generate_structure_uff(self, smiles, max_attempts=20):
1024
+ """Generate 3D structure using ETKDG followed by UFF optimization"""
1025
+ best_mol = None
1026
+ lowest_energy = float('inf')
1027
+
1028
+ for attempt in range(max_attempts):
1029
+ try:
1030
+ test_mol = self.prepare_molecule(smiles)
1031
+ params = self.get_etkdg_params(attempt)
1032
+
1033
+ if AllChem.EmbedMolecule(test_mol, params) == 0:
1034
+ res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
1035
+ vdwThresh=10.0, confId=0,
1036
+ ignoreInterfragInteractions=True)
1037
+
1038
+ if res == 0:
1039
+ ff = AllChem.UFFGetMoleculeForceField(test_mol)
1040
+ if ff:
1041
+ current_energy = ff.CalcEnergy()
1042
+ if current_energy < lowest_energy:
1043
+ lowest_energy = current_energy
1044
+ best_mol = Chem.Mol(test_mol)
1045
+ except Exception:
1046
+ continue
1047
+
1048
+ if best_mol is None:
1049
+ raise ValueError("Failed to generate optimized structure")
1050
+
1051
+ return best_mol
1052
+
1053
+ @staticmethod
1054
+ def mol_to_sdf_bytes(mol):
1055
+ """Convert RDKit molecule to SDF file bytes"""
1056
+ # First write to StringIO in text mode
1057
+ sio = StringIO()
1058
+ writer = Chem.SDWriter(sio)
1059
+ writer.write(mol)
1060
+ writer.close()
1061
+
1062
+ # Convert the string to bytes
1063
+ return sio.getvalue().encode('utf-8')
1064
+
1065
+ def process_input(smiles_input=None, file_obj=None, show_linear=False,
1066
+ show_segment_details=False, generate_3d=False, use_uff=False):
1067
+ """Process input and create visualizations using PeptideAnalyzer"""
1068
+ analyzer = PeptideAnalyzer()
1069
+ temp_dir = tempfile.mkdtemp() if generate_3d else None
1070
+ structure_files = []
1071
+
1072
+ # Handle direct SMILES input
1073
+ if smiles_input:
1074
+ smiles = smiles_input.strip()
1075
+
1076
+ # First check if it's a peptide using analyzer's method
1077
+ if not analyzer.is_peptide(smiles):
1078
+ return "Error: Input SMILES does not appear to be a peptide structure.", None, None
1079
+
1080
+ try:
1081
+ # Create molecule
1082
+ mol = Chem.MolFromSmiles(smiles)
1083
+ if mol is None:
1084
+ return "Error: Invalid SMILES notation.", None, None
1085
+
1086
+ # Generate 3D structures if requested
1087
+ if generate_3d:
1088
+ generator = PeptideStructureGenerator()
1089
+
1090
+ try:
1091
+ # Generate ETKDG structure
1092
+ mol_etkdg = generator.generate_structure_etkdg(smiles)
1093
+ etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
1094
+ writer = Chem.SDWriter(etkdg_path)
1095
+ writer.write(mol_etkdg)
1096
+ writer.close()
1097
+ structure_files.append(etkdg_path)
1098
+
1099
+ # Generate UFF structure if requested
1100
+ if use_uff:
1101
+ mol_uff = generator.generate_structure_uff(smiles)
1102
+ uff_path = os.path.join(temp_dir, "structure_uff.sdf")
1103
+ writer = Chem.SDWriter(uff_path)
1104
+ writer.write(mol_uff)
1105
+ writer.close()
1106
+ structure_files.append(uff_path)
1107
+
1108
+ except Exception as e:
1109
+ return f"Error generating 3D structures: {str(e)}", None, None, None
1110
+
1111
+ # Use analyzer to get sequence
1112
+ segments = analyzer.split_on_bonds(smiles)
1113
+
1114
+ # Process segments and build sequence
1115
+ sequence_parts = []
1116
+ output_text = ""
1117
+
1118
+ # Only include segment analysis in output if requested
1119
+ if show_segment_details:
1120
+ output_text += "Segment Analysis:\n"
1121
+ for i, segment in enumerate(segments):
1122
+ output_text += f"\nSegment {i}:\n"
1123
+ output_text += f"Content: {segment['content']}\n"
1124
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1125
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1126
+
1127
+ residue, mods = analyzer.identify_residue(segment)
1128
+ if residue:
1129
+ if mods:
1130
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1131
+ else:
1132
+ sequence_parts.append(residue)
1133
+ output_text += f"Identified as: {residue}\n"
1134
+ output_text += f"Modifications: {mods}\n"
1135
+ else:
1136
+ output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
1137
+ output_text += "\n"
1138
+ else:
1139
+ # Just build sequence without detailed analysis in output
1140
+ for segment in segments:
1141
+ residue, mods = analyzer.identify_residue(segment)
1142
+ if residue:
1143
+ if mods:
1144
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1145
+ else:
1146
+ sequence_parts.append(residue)
1147
+
1148
+ # Check if cyclic using analyzer's method
1149
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1150
+ three_letter = '-'.join(sequence_parts)
1151
+ one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)
1152
+
1153
+ if is_cyclic:
1154
+ three_letter = f"cyclo({three_letter})"
1155
+ one_letter = f"cyclo({one_letter})"
1156
+
1157
+ # Create cyclic structure visualization
1158
+ img_cyclic = annotate_cyclic_structure(mol, three_letter)
1159
+
1160
+ # Create linear representation if requested
1161
+ img_linear = None
1162
+ if show_linear:
1163
+ fig_linear = create_enhanced_linear_viz(three_letter, smiles)
1164
+ buf = BytesIO()
1165
+ fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
1166
+ buf.seek(0)
1167
+ img_linear = Image.open(buf)
1168
+ plt.close(fig_linear)
1169
+
1170
+ # Add summary to output
1171
+ summary = "Summary:\n"
1172
+ summary += f"Sequence: {three_letter}\n"
1173
+ summary += f"One-letter code: {one_letter}\n"
1174
+ summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1175
+ #if is_cyclic:
1176
+ #summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1177
+ #summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1178
+
1179
+ if structure_files:
1180
+ summary += "\n3D Structures Generated:\n"
1181
+ for filepath in structure_files:
1182
+ summary += f"- {os.path.basename(filepath)}\n"
1183
+
1184
+ return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None
1185
+
1186
+ except Exception as e:
1187
+ return f"Error processing SMILES: {str(e)}", None, None, None
1188
+
1189
+ # Handle file input
1190
+ if file_obj is not None:
1191
+ try:
1192
+ # Handle file content
1193
+ if hasattr(file_obj, 'name'):
1194
+ with open(file_obj.name, 'r') as f:
1195
+ content = f.read()
1196
+ else:
1197
+ content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)
1198
+
1199
+ output_text = ""
1200
+ for line in content.splitlines():
1201
+ smiles = line.strip()
1202
+ if smiles:
1203
+ # Check if it's a peptide
1204
+ if not analyzer.is_peptide(smiles):
1205
+ output_text += f"Skipping non-peptide SMILES: {smiles}\n"
1206
+ continue
1207
+
1208
+ # Process this SMILES
1209
+ segments = analyzer.split_on_bonds(smiles)
1210
+ sequence_parts = []
1211
+
1212
+ # Add segment details if requested
1213
+ if show_segment_details:
1214
+ output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
1215
+ for i, segment in enumerate(segments):
1216
+ output_text += f"\nSegment {i}:\n"
1217
+ output_text += f"Content: {segment['content']}\n"
1218
+ output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
1219
+ output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
1220
+ residue, mods = analyzer.identify_residue(segment)
1221
+ if residue:
1222
+ if mods:
1223
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1224
+ else:
1225
+ sequence_parts.append(residue)
1226
+ output_text += f"Identified as: {residue}\n"
1227
+ output_text += f"Modifications: {mods}\n"
1228
+ else:
1229
+ for segment in segments:
1230
+ residue, mods = analyzer.identify_residue(segment)
1231
+ if residue:
1232
+ if mods:
1233
+ sequence_parts.append(f"{residue}({','.join(mods)})")
1234
+ else:
1235
+ sequence_parts.append(residue)
1236
+
1237
+ # Get cyclicity and create sequence
1238
+ is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
1239
+ sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)
1240
+
1241
+ output_text += f"\nSummary for SMILES: {smiles}\n"
1242
+ output_text += f"Sequence: {sequence}\n"
1243
+ output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
1244
+ if is_cyclic:
1245
+ output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
1246
+ #output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
1247
+ output_text += "-" * 50 + "\n"
1248
+
1249
+ return output_text, None, None, None
1250
+
1251
+ except Exception as e:
1252
+ return f"Error processing file: {str(e)}", None, None
1253
+
1254
+ return "No input provided.", None, None
1255
+
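A minimal usage sketch of the visualization and 3D-generation helpers above, assuming `tr2d2-pep/utils/app.py` is importable as `app` and RDKit is installed; the cyclic triglycine SMILES is only illustrative, and whether it passes `PeptideAnalyzer.is_peptide` depends on the analyzer's heuristics:

```python
# Hypothetical driver; the import path and SMILES below are illustrative assumptions.
from app import process_input, PeptideStructureGenerator

smiles = "C1NC(=O)CNC(=O)CNC1=O"  # cyclo(Gly-Gly-Gly), placeholder input

# Text summary plus 2D images; structure_files stays None unless generate_3d=True.
summary, img_cyclic, img_linear, structure_files = process_input(
    smiles_input=smiles, show_linear=True, generate_3d=False)
print(summary)

# Direct 3D embedding without going through process_input.
generator = PeptideStructureGenerator()
mol3d = generator.generate_structure_etkdg(smiles, max_attempts=5)
with open("structure_etkdg.sdf", "wb") as fh:
    fh.write(PeptideStructureGenerator.mol_to_sdf_bytes(mol3d))
```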
tr2d2-pep/utils/timer.py ADDED
@@ -0,0 +1,34 @@
1
+ import time, torch
2
+ from collections import defaultdict
3
+ from contextlib import contextmanager
4
+
5
+ class StepTimer:
6
+ def __init__(self, device=None):
7
+ self.times = defaultdict(list)
8
+ self.device = device
9
+ self._use_cuda_sync = (
10
+ isinstance(device, torch.device) and device.type == "cuda"
11
+ ) or (isinstance(device, str) and "cuda" in device)
12
+
13
+ @contextmanager
14
+ def section(self, name):
15
+ if self._use_cuda_sync:
16
+ torch.cuda.synchronize()
17
+ t0 = time.perf_counter()
18
+ try:
19
+ yield
20
+ finally:
21
+ if self._use_cuda_sync:
22
+ torch.cuda.synchronize()
23
+ dt = time.perf_counter() - t0
24
+ self.times[name].append(dt)
25
+
26
+ def summary(self, top_k=None):
27
+ # returns (name, count, total, mean, p50, p95)
28
+ import numpy as np
29
+ rows = []
30
+ for k, v in self.times.items():
31
+ a = np.array(v, dtype=float)
32
+ rows.append((k, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95)))
33
+ rows.sort(key=lambda r: r[2], reverse=True) # by total time
34
+ return rows[:top_k] if top_k else rows
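A short sketch of how `StepTimer` could be used to profile parts of a training step; the section name and device string are illustrative:

```python
import torch
from timer import StepTimer  # assumes tr2d2-pep/utils is on the import path

device = "cuda" if torch.cuda.is_available() else "cpu"
timer = StepTimer(device=device)

for _ in range(10):
    with timer.section("matmul"):
        a = torch.randn(256, 256, device=device)
        b = torch.randn(256, 256, device=device)
        _ = a @ b

# Each row is (name, count, total, mean, p50, p95), sorted by total time.
for name, count, total, mean, p50, p95 in timer.summary(top_k=5):
    print(f"{name}: n={count} total={total:.3f}s mean={mean * 1e3:.2f}ms p95={p95 * 1e3:.2f}ms")
```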
tr2d2-pep/utils/utils.py ADDED
@@ -0,0 +1,135 @@
1
+ """Console logger utilities.
2
+
3
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
4
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
5
+ """
6
+
7
+ import logging
8
+ import fsspec
9
+ import lightning
10
+ import torch
11
+ from timm.scheduler import CosineLRScheduler
12
+ import argparse
13
+ import numpy as np
14
+ import random
15
+ import os
16
+
17
+ def sample_categorical_logits(logits, dtype=torch.float64):
18
+ # do not require logits to be log-softmaxed
19
+ gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
20
+ return (logits + gumbel_noise).argmax(dim=-1)
21
+
22
+ def fsspec_exists(filename):
23
+ """Check if a file exists using fsspec."""
24
+ fs, _ = fsspec.core.url_to_fs(filename)
25
+ return fs.exists(filename)
26
+
27
+
28
+ def fsspec_listdir(dirname):
29
+ """Listdir in manner compatible with fsspec."""
30
+ fs, _ = fsspec.core.url_to_fs(dirname)
31
+ return fs.ls(dirname)
32
+
33
+
34
+ def fsspec_mkdirs(dirname, exist_ok=True):
35
+ """Mkdirs in manner compatible with fsspec."""
36
+ fs, _ = fsspec.core.url_to_fs(dirname)
37
+ fs.makedirs(dirname, exist_ok=exist_ok)
38
+
39
+
40
+ def print_nans(tensor, name):
41
+ if torch.isnan(tensor).any():
42
+ print(name, tensor)
43
+
44
+
45
+ class CosineDecayWarmupLRScheduler(
46
+ CosineLRScheduler,
47
+ torch.optim.lr_scheduler._LRScheduler):
48
+
49
+ def __init__(self, *args, **kwargs):
50
+ super().__init__(*args, **kwargs)
51
+ self._last_epoch = -1
52
+ self.step(epoch=0)
53
+
54
+ def step(self, epoch=None):
55
+ if epoch is None:
56
+ self._last_epoch += 1
57
+ else:
58
+ self._last_epoch = epoch
59
+ # We call either step or step_update, depending on
60
+ # whether we're using the scheduler every epoch or every
61
+ # step.
62
+ # Otherwise, lightning will always call step (i.e.,
63
+ # meant for each epoch), and if we set scheduler
64
+ # interval to "step", then the learning rate update will
65
+ # be wrong.
66
+ if self.t_in_epochs:
67
+ super().step(epoch=self._last_epoch)
68
+ else:
69
+ super().step_update(num_updates=self._last_epoch)
70
+
71
+
72
+ class LoggingContext:
73
+ """Context manager for selective logging."""
74
+ def __init__(self, logger, level=None, handler=None, close=True):
75
+ self.logger = logger
76
+ self.level = level
77
+ self.handler = handler
78
+ self.close = close
79
+
80
+ def __enter__(self):
81
+ if self.level is not None:
82
+ self.old_level = self.logger.level
83
+ self.logger.setLevel(self.level)
84
+ if self.handler:
85
+ self.logger.addHandler(self.handler)
86
+
87
+ def __exit__(self, et, ev, tb):
88
+ if self.level is not None:
89
+ self.logger.setLevel(self.old_level)
90
+ if self.handler:
91
+ self.logger.removeHandler(self.handler)
92
+ if self.handler and self.close:
93
+ self.handler.close()
94
+
95
+
96
+ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
97
+ """Initializes multi-GPU-friendly python logger."""
98
+
99
+ logger = logging.getLogger(name)
100
+ logger.setLevel(level)
101
+
102
+ # this ensures all logging levels get marked with the rank zero decorator
103
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
104
+ for level in ('debug', 'info', 'warning', 'error',
105
+ 'exception', 'fatal', 'critical'):
106
+ setattr(logger,
107
+ level,
108
+ lightning.pytorch.utilities.rank_zero_only(
109
+ getattr(logger, level)))
110
+
111
+ return logger
112
+
113
+
114
+ def str2bool(v):
115
+ if isinstance(v, bool):
116
+ return v
117
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
118
+ return True
119
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
120
+ return False
121
+ else:
122
+ raise argparse.ArgumentTypeError('Boolean value expected.')
123
+
124
+
125
+ def set_seed(seed, use_cuda):
126
+ os.environ['PYTHONHASHSEED'] = str(seed)
127
+ np.random.seed(seed)
128
+ random.seed(seed)
129
+ torch.manual_seed(seed)
130
+ # torch.backends.cudnn.deterministic = True
131
+ if use_cuda:
132
+ torch.cuda.manual_seed(seed)
133
+ torch.cuda.manual_seed_all(seed)
134
+ print(f'=> Seed of the run set to {seed}')
135
+
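A brief sketch of how these utilities might be wired into a training script, assuming the module is importable as `utils` and its dependencies (`lightning`, `timm`, `fsspec`) are installed; the argument names are illustrative:

```python
import argparse
import torch
from utils import get_logger, sample_categorical_logits, set_seed, str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--use_cuda", type=str2bool, default=torch.cuda.is_available())
args = parser.parse_args()

logger = get_logger(__name__)
set_seed(args.seed, use_cuda=args.use_cuda)

# Gumbel-max sampling from unnormalized logits: one categorical draw per row.
logits = torch.randn(4, 10)
tokens = sample_categorical_logits(logits)
logger.info(f"sampled tokens: {tokens.tolist()}")
```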