Sophia Tang committed on
Commit · 5e90249
Parent(s): 9aa9a1f
Initial commit
- .gitattributes +4 -0
- .gitignore +17 -0
- README.md +46 -0
- assets/anim-good.gif +3 -0
- assets/peptides.png +3 -0
- tr2d2-pep/README.md +41 -0
- tr2d2-pep/configs/peptune_config.yaml +159 -0
- tr2d2-pep/diffusion.py +1526 -0
- tr2d2-pep/finetune.py +133 -0
- tr2d2-pep/finetune.sh +37 -0
- tr2d2-pep/finetune_peptides.py +193 -0
- tr2d2-pep/finetune_utils.py +138 -0
- tr2d2-pep/generate_mcts.py +192 -0
- tr2d2-pep/metrics.py +71 -0
- tr2d2-pep/noise_schedule.py +150 -0
- tr2d2-pep/peptide_mcts.py +648 -0
- tr2d2-pep/plotting.py +148 -0
- tr2d2-pep/roformer.py +74 -0
- tr2d2-pep/run_mcts.sh +29 -0
- tr2d2-pep/scoring/functions/binding.py +178 -0
- tr2d2-pep/scoring/functions/binding_utils.py +290 -0
- tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json +0 -0
- tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json +0 -0
- tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json +3 -0
- tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json +0 -0
- tr2d2-pep/scoring/functions/hemolysis.py +63 -0
- tr2d2-pep/scoring/functions/nonfouling.py +66 -0
- tr2d2-pep/scoring/functions/permeability.py +171 -0
- tr2d2-pep/scoring/functions/scoring_utils.py +94 -0
- tr2d2-pep/scoring/functions/solubility.py +63 -0
- tr2d2-pep/scoring/scoring_functions.py +77 -0
- tr2d2-pep/scoring/tokenizer/my_tokenizers.py +424 -0
- tr2d2-pep/scoring/tokenizer/new_splits.txt +159 -0
- tr2d2-pep/scoring/tokenizer/new_vocab.txt +587 -0
- tr2d2-pep/tokenizer/my_tokenizers.py +424 -0
- tr2d2-pep/tokenizer/new_splits.txt +159 -0
- tr2d2-pep/tokenizer/new_vocab.txt +587 -0
- tr2d2-pep/utils/app.py +1255 -0
- tr2d2-pep/utils/timer.py +34 -0
- tr2d2-pep/utils/utils.py +135 -0
.gitattributes
ADDED
```
tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json filter=lfs diff=lfs merge=lfs -text
assets/peptides.png filter=lfs diff=lfs merge=lfs -text
assets/tr2d2-anim.gif filter=lfs diff=lfs merge=lfs -text
assets/anim-good.gif filter=lfs diff=lfs merge=lfs -text
```
.gitignore
ADDED
```
__pycache__/
*.pyc
.venv/
.env
.DS_Store
tr2d2-pep/wandb/
tr2d2-pep/pretrained/
tr2d2-pep/logs/
tr2d2-pep/__pycache__/
tr2d2-pep/scoring/__pycache__/
tr2d2-pep/tokenizer/__pycache__/
tr2d2-pep/utils/__pycache__/
*.pyc
*.log
*.ipynb
*.pt
tr2d2-pep/scoring/functions/classifiers/best_model.pt
```
README.md
ADDED
# [TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion](https://arxiv.org/abs/2509.25171) 🤖🌳

[**Sophia Tang**](https://sophtang.github.io/)\*, [**Yuchen Zhu**](https://yuchen-zhu-zyc.github.io/)\*, [**Molei Tao**](https://mtao8.math.gatech.edu/), and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)



This is the repository for **[TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion](https://arxiv.org/abs/2509.25171)** 🤖🌳. It is partially built on the **[PepTune repo](https://github.com/programmablebio/peptune)** ([Tang et al. 2024](https://arxiv.org/abs/2412.17780)) and **MDNS** ([Zhu et al. 2025](https://arxiv.org/abs/2508.10684)).

Inspired by the incredible success of off-policy reinforcement learning (RL), **TR2-D2** introduces a general framework that enhances off-policy RL with tree search for discrete diffusion fine-tuning.

🤖 Off-policy RL enables learning from diffusion trajectories generated by the non-gradient-tracking policy model by storing the sampled trajectories in a replay buffer for repeated use.

🌳 Tree search balances exploration and exploitation to generate optimal diffusion trajectories and stores the optimal sequences in the buffer.

We use this framework to develop an efficient discrete diffusion fine-tuning strategy that leverages **Monte-Carlo Tree Search (MCTS)** to curate a replay buffer of optimal trajectories, combined with an **off-policy control-based RL algorithm grounded in stochastic optimal control theory**, yielding theoretically guaranteed convergence to the optimal distribution. 🌟
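The loop below is a minimal conceptual sketch of this two-phase recipe, not the repository's implementation (see `tr2d2-pep/peptide_mcts.py` and `tr2d2-pep/finetune.py` for the actual code); `search_fn`, `loss_fn`, and `ReplayBuffer` are hypothetical placeholders for the MCTS rollout routine, the control-based off-policy objective, and the trajectory buffer.

```python
# Conceptual sketch only: hypothetical names, not the repository's API.
import random


class ReplayBuffer:
    """Keeps the highest-reward denoising trajectories found by tree search."""

    def __init__(self, capacity=1024):
        self.items = []          # list of (trajectory, reward) pairs
        self.capacity = capacity

    def add(self, trajectory, reward):
        self.items.append((trajectory, float(reward)))
        self.items.sort(key=lambda pair: pair[1], reverse=True)
        del self.items[self.capacity:]

    def sample(self, n):
        return random.sample(self.items, min(n, len(self.items)))


def tr2d2_round(optimizer, search_fn, loss_fn, buffer, batch_size=16):
    """One fine-tuning round: tree search fills the buffer, then the policy
    takes an off-policy gradient step on replayed trajectories."""
    # 1) Tree search (no gradient tracking): MCTS expands denoising
    #    trajectories and returns the best (trajectory, reward) pairs.
    for trajectory, reward in search_fn():
        buffer.add(trajectory, reward)
    # 2) Off-policy update: replay stored trajectories and minimize the
    #    control-based RL objective.
    batch = buffer.sample(batch_size)
    loss = loss_fn(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```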

### Regulatory DNA Sequence Design 🧬

---

In this experiment, we fine-tune the pre-trained **DNA enhancer MDM from DRAKES** (Wang et al. 2025), trained on **~700k HepG2 sequences**, to optimize the measured enhancer activity using the reward oracles from DRAKES. Code and instructions to reproduce our results are provided in `/tr2d2-dna`.

### Multi-Objective Therapeutic Peptide Design 🧫

---

In this experiment, we fine-tune the pre-trained **unconditional peptide SMILES MDM from PepTune** ([Tang et al. 2024](https://arxiv.org/abs/2412.17780)) to optimize **multiple therapeutic properties**, including target protein binding affinity, solubility, non-hemolysis, non-fouling, and permeability. We show that one-shot generation from the fine-tuned policy outperforms inference-time multi-objective guidance, marking a significant advance over prior fine-tuning methods. Code and instructions to reproduce our results are provided in `/tr2d2-pep`.



## Citation

If you find this repository helpful for your publications, please consider citing our paper:

```bibtex
@article{tang2024tr2d2,
  title={TR2-D2: Tree Search Guided Trajectory-Aware Fine-Tuning for Discrete Diffusion},
  author={Sophia Tang and Yuchen Zhu and Molei Tao and Pranam Chatterjee},
  journal={arXiv preprint arXiv:2509.25171},
  year={2025}
}
```

To use this repository, you agree to abide by the [PepTune License](https://drive.google.com/file/d/1Hsu91wTmxyoJLNJzfPDw5_nTbxVySP5x/view?usp=sharing).
assets/anim-good.gif
ADDED
Git LFS Details

assets/peptides.png
ADDED
Git LFS Details
tr2d2-pep/README.md
ADDED
# TR2-D2 For Multi-Objective Therapeutic Peptide Design 🧫

This part of the codebase uses TR2-D2 to fine-tune a peptide MDM to optimize multiple therapeutic properties, including binding affinity to a protein target, solubility, non-hemolysis, non-fouling, and cell membrane permeability.

The codebase is built upon [PepTune (Tang et al., 2024)](https://arxiv.org/abs/2412.17780), [MDLM (Sahoo et al., 2023)](https://github.com/kuleshov-group/mdlm), [SEPO (Zekri et al., 2025)](https://github.com/ozekri/SEPO/tree/main), and [MDNS (Zhu et al., 2025)](https://arxiv.org/abs/2508.10684).

## Environment Installation
```
conda env create -f environment.yml

conda activate tr2d2-pep
```

## Pretrained Model Weights Download

Follow the steps below to download the model weights required for this experiment, which are originally from [PepTune](https://arxiv.org/abs/2412.17780).

1. Download the PepTune pre-trained MDLM and place it in `/TR2-D2/peptides/pretrained/`: https://drive.google.com/file/d/1oXGDpKLNF0KX0ZdOcl1NZj5Czk2lSFUn/view?usp=sharing
2. Download the pre-trained binding affinity Transformer model and place it in `/TR2-D2/tr2d2-pep/scoring/functions/classifiers/`: https://drive.google.com/file/d/128shlEP_-rYAxPgZRCk_n0HBWVbOYSva/view?usp=sharing

## Finetune with TR2-D2
After downloading the pretrained checkpoints, follow the steps below to run fine-tuning:

1. Fill in the `base_path` in `scoring/scoring_functions.py` and `diffusion.py`.
2. In `finetune.sh`, set `HOME_LOC` to the base path where `TR2-D2` is located and `ENV_PATH` to the directory where your environment is installed.
3. Create `tr2d2-pep/results`, where the fine-tuning curves and generation results will be saved, and `tr2d2-pep/checkpoints` for checkpoint saving. Also, create `tr2d2-pep/logs`, where the training logs will be saved.
4. To specify a target protein, set `--prot_seq <insert amino acid sequence>` and `--prot_name <insert protein name>`. The default protein is the Transferrin receptor (TfR).

Run fine-tuning using `nohup` with the following commands:
```
chmod +x finetune.sh

nohup ./finetune.sh > finetune.log 2>&1 &
```
Evaluation will run automatically after the specified number of fine-tuning epochs (`--num_epochs`) is finished. To summarize metrics, fill in `path` and `prot_name` in `metrics.py` and run:

```
python metrics.py
```
tr2d2-pep/configs/peptune_config.yaml
ADDED
```yaml
noise:
  type: loglinear
  sigma_min: 1e-4
  sigma_max: 20
  state_dependent: True

mode: ppl_eval # train / ppl_eval / sample_eval
diffusion: absorbing_state
vocab: old_smiles # old_smiles / new_smiles / selfies / helm
backbone: roformer # peptideclm / helmgpt / dit / roformer / finetune_roformer
parameterization: subs # subs
time_conditioning: False
T: 0 # 0 (continuous time) / 1000
subs_masking: False

seed: 42

mcts:
  num_children: 50
  num_objectives: 5
  topk: 100
  mask_token: 4
  num_iter: 128
  sampling: 0 # 0 is gumbel sampling / > 0 samples children from top k probs
  invalid_penalty: 0.5
  sample_prob: 1.0
  perm: True
  dual: False
  single: False
  time_dependent: True

lr_scheduler:
  _target_: transformers.get_constant_schedule_with_warmup
  num_warmup_steps: 2500

data:
  train: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-train.csv
  valid: /home/st512/peptune/scripts/peptide-mdlm-mcts/data/finetune2/30K-val.csv
  batching: wrapping # padding / wrapping

loader:
  global_batch_size: 64
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm_cache # analytic, ddpm, ddpm_cache
  num_sequences: 100
  sampling_eps: 1e-3
  steps: 128
  seq_length: 100
  noise_removal: True
  num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  stride_length: 1
  num_strides: 1

training:
  antithetic_sampling: True
  sampling_eps: 1e-3
  focus_mask: False
  #dynamic_batching: True
  accumulator: False

eval:
  checkpoint_path:
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generation_model:

optim:
  weight_decay: 0.075
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

pepclm:
  hidden_size: 768
  cond_dim: 256
  n_heads: 20
  n_blocks: 4
  dropout: 0.5
  length: 512
  #scale_by_sigma: True

model:
  type: ddit
  hidden_size: 768
  cond_dim: 128
  length: 512
  n_blocks: 12
  n_heads: 12
  scale_by_sigma: True
  dropout: 0.1

roformer:
  hidden_size: 768
  n_layers: 8
  n_heads: 8
  max_position_embeddings: 1035

helmgpt:
  hidden_size: 256
  embd_pdrop: 0.1
  resid_pdrop: 0.1
  attn_pdrop: 0.1
  ff_dropout: 0.
  block_size: 140
  n_layer: 8
  n_heads: 8


trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 64-true
  num_sanity_val_steps: 2
  max_epochs: 100
  max_steps: 1_000_000
  log_every_n_steps: 10
  limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
  #val_check_interval: 40 #954
  check_val_every_n_epoch: 1

hydra:
  run:
    dir: ./${now:%Y.%m.%d}/
  job:
    chdir: True

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: True
  resume_ckpt_path:

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    every_n_epochs: 1
    monitor: "val/nll"
    save_top_k: 10
    mode: "min"
    dirpath:
```
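The interpolations `${div_up:...}`, `${eval:...}`, `${device_count:}`, and `${cwd:}` used in the config above are not built into OmegaConf; in MDLM-style training scripts they are registered as custom resolvers by the entry point before the config is composed. A minimal sketch of such a registration, assuming the resolver names match those used in `peptune_config.yaml` (the repository's actual entry point may register them differently):

```python
# Sketch of the custom OmegaConf resolvers assumed by peptune_config.yaml;
# the repository's entry point may register them differently.
import os

import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("cwd", os.getcwd)                         # ${cwd:}
OmegaConf.register_new_resolver("device_count", torch.cuda.device_count)  # ${device_count:}
OmegaConf.register_new_resolver("eval", eval)                             # ${eval:"..."}
OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)  # ceiling division
```

For example, with `global_batch_size: 64` on a single node with 4 GPUs, `div_up` yields a per-device `batch_size` of 16 and an `accumulate_grad_batches` of 1.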
tr2d2-pep/diffusion.py
ADDED
| 1 |
+
import numpy as np
|
| 2 |
+
import sys
|
| 3 |
+
import itertools
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import math
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
import random as rd
|
| 11 |
+
import lightning as L
|
| 12 |
+
import torchmetrics
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import gc
|
| 15 |
+
import utils.utils as utils
|
| 16 |
+
|
| 17 |
+
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 18 |
+
import noise_schedule
|
| 19 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 20 |
+
import roformer as roformer
|
| 21 |
+
from utils.app import PeptideAnalyzer
|
| 22 |
+
import pandas as pd
|
| 23 |
+
|
| 24 |
+
base_path = '/path/to/your/home'
|
| 25 |
+
|
| 26 |
+
def _sample_categorical(categorical_probs):
|
| 27 |
+
gumbel_norm = (
|
| 28 |
+
1e-10
|
| 29 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 30 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
|
| 31 |
+
|
| 32 |
+
def _sample_categorical_gradient(categorical_probs, temp = 1.0):
|
| 33 |
+
gumbel_norm = (
|
| 34 |
+
1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 35 |
+
output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
|
| 36 |
+
return output
|
| 37 |
+
|
| 38 |
+
def _unsqueeze(x, reference):
|
| 39 |
+
return x.view(
|
| 40 |
+
* x.shape,
|
| 41 |
+
* ((1,) * (len(reference.shape) - len(x.shape))))
|
| 42 |
+
|
| 43 |
+
def sample_batched_categorical(categorical_probs, batch_size):
|
| 44 |
+
"""
|
| 45 |
+
Generates `batch_size` distinct sequences sampled from categorical probabilities
|
| 46 |
+
using the Gumbel distribution to ensure randomness while following probabilities
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
|
| 50 |
+
representing categorical probabilities
|
| 51 |
+
batch_size (int): number of distinct sequences to sample
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
torch.Tensor: tensor of shape (m, sequence_length), where each row is a
|
| 55 |
+
distinct sequence of sampled category indices.
|
| 56 |
+
"""
|
| 57 |
+
_, sequence_length, vocab_size = categorical_probs.shape
|
| 58 |
+
|
| 59 |
+
# add Gumbel noise and sample m sequences
|
| 60 |
+
gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
|
| 61 |
+
noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
|
| 62 |
+
|
| 63 |
+
# select the highest score (most likely category after Gumbel noise)
|
| 64 |
+
sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long) # shape: (m, sequence_length)
|
| 65 |
+
|
| 66 |
+
return sampled_sequences
|
| 67 |
+
|
| 68 |
+
def sample_batched_top_k(categorical_probs, batch_size, k):
|
| 69 |
+
"""
|
| 70 |
+
Generates `batch_size` sequences sampled from the top-k probabilities of each token
|
| 71 |
+
using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
|
| 75 |
+
representing categorical probabilities.
|
| 76 |
+
batch_size (int): Number of sequences to sample.
|
| 77 |
+
k (int): Number of top probabilities to consider for sampling.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
|
| 81 |
+
sampled sequence of category indices.
|
| 82 |
+
"""
|
| 83 |
+
_, sequence_length, vocab_length = categorical_probs.shape
|
| 84 |
+
|
| 85 |
+
# Add Gumbel noise to the log probabilities
|
| 86 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
|
| 87 |
+
noisy_scores = torch.log(categorical_probs[None, :, :]) + gumbel_noise # Shape: (m, sequence_length, vocab_length)
|
| 88 |
+
|
| 89 |
+
# Get the top-k categories based on noisy scores
|
| 90 |
+
top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k)
|
| 91 |
+
|
| 92 |
+
# Convert top-k scores back to probabilities and normalize
|
| 93 |
+
top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k)
|
| 94 |
+
|
| 95 |
+
# Sample randomly from the top-k probabilities
|
| 96 |
+
sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
|
| 97 |
+
sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
|
| 98 |
+
|
| 99 |
+
# Map sampled indices back to the original vocabulary indices
|
| 100 |
+
sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device).to(dtype=torch.long)
|
| 101 |
+
|
| 102 |
+
return sampled_sequences
|
| 103 |
+
|
| 104 |
+
@dataclass
|
| 105 |
+
class Loss:
|
| 106 |
+
loss: torch.FloatTensor
|
| 107 |
+
nlls: torch.FloatTensor
|
| 108 |
+
attn_mask: torch.FloatTensor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class BPD(NLL):
|
| 116 |
+
def compute(self) -> Tensor:
|
| 117 |
+
"""Computes the bits per dimension.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
bpd
|
| 121 |
+
"""
|
| 122 |
+
return self.mean_value / self.weight / math.log(2)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Perplexity(NLL):
|
| 126 |
+
def compute(self) -> Tensor:
|
| 127 |
+
"""Computes the Perplexity.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Perplexity
|
| 131 |
+
"""
|
| 132 |
+
return torch.exp(self.mean_value / self.weight)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Diffusion(L.LightningModule):
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
config,
|
| 139 |
+
tokenizer = None,
|
| 140 |
+
mode="finetune",
|
| 141 |
+
device=None,
|
| 142 |
+
):
|
| 143 |
+
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.config = config
|
| 146 |
+
#self.save_hyperparameters()
|
| 147 |
+
|
| 148 |
+
# PeptideCLM tokenizer
|
| 149 |
+
if tokenizer is None:
|
| 150 |
+
self.tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
|
| 151 |
+
f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')
|
| 152 |
+
else:
|
| 153 |
+
self.tokenizer = tokenizer
|
| 154 |
+
|
| 155 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 156 |
+
self.mask_index = self.tokenizer.mask_token_id
|
| 157 |
+
self.sampler = self.config.sampling.predictor
|
| 158 |
+
self.analyzer = PeptideAnalyzer()
|
| 159 |
+
|
| 160 |
+
# backbone LM PeptideCLM model
|
| 161 |
+
self.backbone = roformer.Roformer(self.config, self.tokenizer, device=device)
|
| 162 |
+
if mode == "finetune":
|
| 163 |
+
self.backbone.freeze_model()
|
| 164 |
+
self.backbone.unfreeze_n_layers(n=8)
|
| 165 |
+
elif mode == "eval":
|
| 166 |
+
self.backbone.freeze_model()
|
| 167 |
+
self.backbone.requires_grad_(False)
|
| 168 |
+
self.backbone.eval()
|
| 169 |
+
elif mode == "train":
|
| 170 |
+
self.backbone.requires_grad_(True)
|
| 171 |
+
self.backbone.train()
|
| 172 |
+
|
| 173 |
+
self.neg_infinity = -1000000.0
|
| 174 |
+
self.T = config.T
|
| 175 |
+
# noise schedule for non-peptide bond tokens (default to log-linear)
|
| 176 |
+
self.noise = noise_schedule.get_noise(config)
|
| 177 |
+
|
| 178 |
+
# noise schedule for peptide bonds (log-polynomial)
|
| 179 |
+
self.bond_noise = noise_schedule.LogPolyNoise()
|
| 180 |
+
self.time_conditioning = self.config.time_conditioning
|
| 181 |
+
self.fast_forward_epochs = None
|
| 182 |
+
self.fast_forward_batches = None
|
| 183 |
+
|
| 184 |
+
self.gen_ppl_eval_model_name_or_path = self.config.eval.gen_ppl_eval_model_name_or_path
|
| 185 |
+
self.gen_ppl_metric = Perplexity()
|
| 186 |
+
|
| 187 |
+
self.lr = self.config.optim.lr
|
| 188 |
+
self.sampling_eps = self.config.training.sampling_eps
|
| 189 |
+
|
| 190 |
+
metrics = torchmetrics.MetricCollection({
|
| 191 |
+
'nll': NLL(),
|
| 192 |
+
'bpd': BPD(),
|
| 193 |
+
'ppl': Perplexity(),
|
| 194 |
+
})
|
| 195 |
+
metrics.set_dtype(torch.float64)
|
| 196 |
+
self.train_metrics = metrics.clone(prefix='trainer/')
|
| 197 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
| 198 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
| 199 |
+
|
| 200 |
+
### FOR THE EXPANSION AND ROLLOUT STEP ###
|
| 201 |
+
def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
|
| 202 |
+
num_steps = args.total_num_steps
|
| 203 |
+
B = args.batch_size
|
| 204 |
+
x_rollout = self.sample_prior(
|
| 205 |
+
B, args.seq_length).to(self.device)
|
| 206 |
+
|
| 207 |
+
log_rnd = torch.zeros(args.batch_size, device=self.device)
|
| 208 |
+
|
| 209 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 210 |
+
dt = (1 - eps) / num_steps
|
| 211 |
+
|
| 212 |
+
for i in range(num_steps):
|
| 213 |
+
t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
|
| 214 |
+
|
| 215 |
+
log_p, x_next, log_policy_step, log_pretrained_step = \
|
| 216 |
+
self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)
|
| 217 |
+
|
| 218 |
+
log_rnd += log_pretrained_step - log_policy_step
|
| 219 |
+
|
| 220 |
+
x_rollout = x_next
|
| 221 |
+
|
| 222 |
+
# if mask token remains, fully unmask
|
| 223 |
+
mask_positions = (x_rollout == self.mask_index) # (B, L) bool
|
| 224 |
+
|
| 225 |
+
# does **any** mask remain in any sequence
|
| 226 |
+
any_mask_global = mask_positions.any().item() # true if mask remains
|
| 227 |
+
if any_mask_global:
|
| 228 |
+
log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
|
| 229 |
+
|
| 230 |
+
x_rollout = x_next
|
| 231 |
+
|
| 232 |
+
childSequences = self.tokenizer.batch_decode(x_rollout)
|
| 233 |
+
|
| 234 |
+
# change rewards for peptides
|
| 235 |
+
valid_x_final = []
|
| 236 |
+
validSequences = []
|
| 237 |
+
valid_log_rnd = []
|
| 238 |
+
|
| 239 |
+
for i in range(B):
|
| 240 |
+
# string sequence
|
| 241 |
+
childSeq = childSequences[i]
|
| 242 |
+
|
| 243 |
+
# check if the peptide is valid
|
| 244 |
+
if self.analyzer.is_peptide(childSeq):
|
| 245 |
+
valid_x_final.append(x_rollout[i])
|
| 246 |
+
validSequences.append(childSeq)
|
| 247 |
+
valid_log_rnd.append(log_rnd[i])
|
| 248 |
+
|
| 249 |
+
# compute multi-objective rewards
|
| 250 |
+
score_vectors = reward_model(input_seqs=validSequences)
|
| 251 |
+
scalar_rewards = np.sum(score_vectors, axis=-1)
|
| 252 |
+
scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=self.device)
|
| 253 |
+
|
| 254 |
+
print(f"scalar reward dim{len(scalar_rewards)}")
|
| 255 |
+
valid_log_rnd = torch.stack(valid_log_rnd, dim=0)
|
| 256 |
+
|
| 257 |
+
log_rnd = valid_log_rnd + (scalar_rewards / args.alpha) # scale down by alpha
|
| 258 |
+
valid_x_final = torch.stack(valid_x_final, dim=0)
|
| 259 |
+
|
| 260 |
+
return valid_x_final, log_rnd, scalar_rewards
|
| 261 |
+
|
| 262 |
+
def sample_finetuned(self, args, reward_model, batch_size=None, dataframe=False, eps=1e-5):
|
| 263 |
+
torch.cuda.empty_cache()
|
| 264 |
+
self.backbone.eval()
|
| 265 |
+
self.noise.eval()
|
| 266 |
+
print(f"device:{self.device}")
|
| 267 |
+
|
| 268 |
+
if batch_size is None:
|
| 269 |
+
batch_size = args.batch_size
|
| 270 |
+
|
| 271 |
+
num_steps = args.total_num_steps
|
| 272 |
+
x_rollout = self.sample_prior(
|
| 273 |
+
batch_size,
|
| 274 |
+
args.seq_length).to(self.device, dtype=torch.long)
|
| 275 |
+
|
| 276 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 277 |
+
dt = torch.tensor((1 - eps) / num_steps, device=self.device)
|
| 278 |
+
|
| 279 |
+
for i in range(num_steps):
|
| 280 |
+
t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
|
| 281 |
+
|
| 282 |
+
log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
|
| 283 |
+
|
| 284 |
+
x_rollout = x_next
|
| 285 |
+
x_rollout = x_rollout.to(self.device)
|
| 286 |
+
|
| 287 |
+
# if mask token remains, fully unmask
|
| 288 |
+
mask_positions = (x_rollout == self.mask_index) # (B, L) bool
|
| 289 |
+
|
| 290 |
+
# does **any** mask remain in any sequence
|
| 291 |
+
any_mask_global = mask_positions.any().item() # true if mask remains
|
| 292 |
+
if any_mask_global:
|
| 293 |
+
log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
|
| 294 |
+
|
| 295 |
+
x_rollout = x_next
|
| 296 |
+
x_rollout = x_rollout.to(self.device)
|
| 297 |
+
|
| 298 |
+
childSequences = self.tokenizer.batch_decode(x_rollout)
|
| 299 |
+
valid_x_final = []
|
| 300 |
+
validSequences = []
|
| 301 |
+
|
| 302 |
+
for idx, seq in enumerate(childSequences):
|
| 303 |
+
if self.analyzer.is_peptide(seq):
|
| 304 |
+
valid_x_final.append(x_rollout[idx])
|
| 305 |
+
validSequences.append(seq)
|
| 306 |
+
|
| 307 |
+
valid_fraction = len(validSequences) / batch_size
|
| 308 |
+
|
| 309 |
+
if (len(validSequences) != 0):
|
| 310 |
+
# add scores to log
|
| 311 |
+
score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives)
|
| 312 |
+
average_scores = score_vectors.T
|
| 313 |
+
|
| 314 |
+
affinity = average_scores[0]
|
| 315 |
+
sol = average_scores[1]
|
| 316 |
+
hemo = average_scores[2]
|
| 317 |
+
nf = average_scores[3]
|
| 318 |
+
permeability = average_scores[4]
|
| 319 |
+
|
| 320 |
+
else:
|
| 321 |
+
zeros = [0.0]
|
| 322 |
+
|
| 323 |
+
affinity = zeros
|
| 324 |
+
sol = zeros
|
| 325 |
+
hemo = zeros
|
| 326 |
+
nf = zeros
|
| 327 |
+
permeability = zeros
|
| 328 |
+
|
| 329 |
+
if dataframe:
|
| 330 |
+
df = pd.DataFrame({
|
| 331 |
+
"Peptide Sequence": validSequences,
|
| 332 |
+
"Binding Affinity": affinity if len(validSequences) else [0.0],
|
| 333 |
+
"Solubility": sol if len(validSequences) else [0.0],
|
| 334 |
+
"Hemolysis": hemo if len(validSequences) else [0.0],
|
| 335 |
+
"Nonfouling": nf if len(validSequences) else [0.0],
|
| 336 |
+
"Permeability": permeability if len(validSequences) else [0.0],
|
| 337 |
+
})
|
| 338 |
+
return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction, df
|
| 339 |
+
|
| 340 |
+
return x_rollout, affinity, sol, hemo, nf, permeability, valid_fraction
|
| 341 |
+
|
| 342 |
+
def compute_log_policy(self, token_array, x_next, t, dt, attn_mask=None):
|
| 343 |
+
torch.cuda.empty_cache()
|
| 344 |
+
self.backbone.eval()
|
| 345 |
+
self.noise.eval()
|
| 346 |
+
|
| 347 |
+
sigma_t, _ = self.noise(t)
|
| 348 |
+
|
| 349 |
+
if token_array.ndim == 1:
|
| 350 |
+
token_array = token_array.unsqueeze(0)
|
| 351 |
+
|
| 352 |
+
if x_next.ndim == 1:
|
| 353 |
+
x_next = x_next.unsqueeze(0)
|
| 354 |
+
|
| 355 |
+
if t.ndim > 1:
|
| 356 |
+
t = t.squeeze(-1)
|
| 357 |
+
assert t.ndim == 1
|
| 358 |
+
|
| 359 |
+
change_prob_t = t[:, None, None]
|
| 360 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 361 |
+
|
| 362 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 363 |
+
|
| 364 |
+
if attn_mask is None:
|
| 365 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 366 |
+
|
| 367 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 368 |
+
p_x0 = log_p.exp()
|
| 369 |
+
|
| 370 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 371 |
+
|
| 372 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 373 |
+
|
| 374 |
+
# zero-masking probability
|
| 375 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 376 |
+
|
| 377 |
+
copy_flag = (token_array != self.mask_index)
|
| 378 |
+
|
| 379 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 380 |
+
changed_mask = (~copy_flag)
|
| 381 |
+
|
| 382 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 383 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
|
| 384 |
+
|
| 385 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
|
| 386 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 387 |
+
|
| 388 |
+
# returns:
|
| 389 |
+
# log_policy_step (B, ) log probability x_next tokens under policy
|
| 390 |
+
if log_policy_step.ndim == 1:
|
| 391 |
+
log_policy_step = log_policy_step.squeeze(0)
|
| 392 |
+
|
| 393 |
+
return log_policy_step
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def single_reverse_step(self, token_array, t, dt, p_x0=None, attn_mask=None):
|
| 397 |
+
torch.cuda.empty_cache()
|
| 398 |
+
dev = self.device
|
| 399 |
+
self.backbone.to(dev).eval()
|
| 400 |
+
self.noise.eval()
|
| 401 |
+
|
| 402 |
+
t = t.to(dev)
|
| 403 |
+
dt = torch.as_tensor(dt, device=dev, dtype=t.dtype)
|
| 404 |
+
assert self.config.noise.type == 'loglinear'
|
| 405 |
+
sigma_t, _ = self.noise(t)
|
| 406 |
+
sigma_t = sigma_t.to(dev)
|
| 407 |
+
|
| 408 |
+
if t.ndim > 1:
|
| 409 |
+
t = t.squeeze(-1)
|
| 410 |
+
assert t.ndim == 1
|
| 411 |
+
|
| 412 |
+
change_prob_t = t[:, None, None]
|
| 413 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 414 |
+
|
| 415 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 416 |
+
|
| 417 |
+
if attn_mask is None:
|
| 418 |
+
attn_mask = torch.ones_like(token_array, device=dev, dtype=torch.long)
|
| 419 |
+
else:
|
| 420 |
+
attn_mask = attn_mask.to(dev)
|
| 421 |
+
|
| 422 |
+
if p_x0 is None:
|
| 423 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 424 |
+
p_x0 = log_p.exp()
|
| 425 |
+
else:
|
| 426 |
+
# ensure provided p_x0 is on dev
|
| 427 |
+
log_p = None
|
| 428 |
+
p_x0 = p_x0.to(dev)
|
| 429 |
+
|
| 430 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 431 |
+
|
| 432 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 433 |
+
|
| 434 |
+
# zero-masking probability
|
| 435 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 436 |
+
|
| 437 |
+
x_changed = _sample_categorical(q_xs)
|
| 438 |
+
if x_changed.device != dev or x_changed.dtype != token_array.dtype:
|
| 439 |
+
x_changed = x_changed.to(dev, dtype=token_array.dtype)
|
| 440 |
+
|
| 441 |
+
copy_flag = (token_array != self.mask_index)
|
| 442 |
+
|
| 443 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 444 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 445 |
+
|
| 446 |
+
# returns:
|
| 447 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 448 |
+
# x_next (B, L) next sequences
|
| 449 |
+
return log_p, x_next
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def single_noise_removal(self, token_array, t, dt, p_x0=None, attn_mask=None):
|
| 453 |
+
torch.cuda.empty_cache()
|
| 454 |
+
self.backbone.eval()
|
| 455 |
+
self.noise.eval()
|
| 456 |
+
|
| 457 |
+
assert self.config.noise.type == 'loglinear'
|
| 458 |
+
sigma_t, _ = self.noise(t)
|
| 459 |
+
|
| 460 |
+
if t.ndim > 1:
|
| 461 |
+
t = t.squeeze(-1)
|
| 462 |
+
assert t.ndim == 1
|
| 463 |
+
|
| 464 |
+
change_prob_t = t[:, None, None]
|
| 465 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 466 |
+
|
| 467 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 468 |
+
|
| 469 |
+
if attn_mask is None:
|
| 470 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 471 |
+
|
| 472 |
+
if p_x0 is None:
|
| 473 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 474 |
+
p_x0 = log_p.exp()
|
| 475 |
+
|
| 476 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 477 |
+
|
| 478 |
+
# changed for noise removal
|
| 479 |
+
p_x0 = p_x0.clone()
|
| 480 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 481 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 482 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 483 |
+
|
| 484 |
+
x_changed = _sample_categorical(q_xs)
|
| 485 |
+
|
| 486 |
+
copy_flag = (token_array != self.mask_index)
|
| 487 |
+
|
| 488 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 489 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 490 |
+
|
| 491 |
+
# returns:
|
| 492 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 493 |
+
# x_next (B, L) next sequences
|
| 494 |
+
return log_p, x_next
|
| 495 |
+
|
| 496 |
+
def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
|
| 497 |
+
torch.cuda.empty_cache()
|
| 498 |
+
self.backbone.eval()
|
| 499 |
+
self.noise.eval()
|
| 500 |
+
assert self.config.noise.type == 'loglinear'
|
| 501 |
+
sigma_t, _ = self.noise(t)
|
| 502 |
+
|
| 503 |
+
if t.ndim > 1:
|
| 504 |
+
t = t.squeeze(-1)
|
| 505 |
+
assert t.ndim == 1
|
| 506 |
+
|
| 507 |
+
change_prob_t = t[:, None, None]
|
| 508 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 509 |
+
|
| 510 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 511 |
+
|
| 512 |
+
if attn_mask is None:
|
| 513 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 514 |
+
|
| 515 |
+
if p_x0 is None:
|
| 516 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 517 |
+
p_x0 = log_p.exp()
|
| 518 |
+
|
| 519 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 520 |
+
|
| 521 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 522 |
+
|
| 523 |
+
# zero-masking probability
|
| 524 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 525 |
+
|
| 526 |
+
x_changed = _sample_categorical(q_xs)
|
| 527 |
+
|
| 528 |
+
copy_flag = (token_array != self.mask_index)
|
| 529 |
+
|
| 530 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 531 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 532 |
+
|
| 533 |
+
# compute the log-probability under pretrained model at each step
|
| 534 |
+
with torch.no_grad():
|
| 535 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 536 |
+
log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 537 |
+
|
| 538 |
+
# log-prob of the *sampled token* at each position
|
| 539 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 540 |
+
|
| 541 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 542 |
+
|
| 543 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 544 |
+
changed_mask = (~copy_flag)
|
| 545 |
+
# mask of tokens that were unmasked in this step
|
| 546 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 547 |
+
|
| 548 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 549 |
+
|
| 550 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 551 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 552 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 553 |
+
|
| 554 |
+
# returns:
|
| 555 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 556 |
+
# x_next (B, L) next sequences
|
| 557 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 558 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 559 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 560 |
+
|
| 561 |
+
def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None, attn_mask=None):
|
| 562 |
+
torch.cuda.empty_cache()
|
| 563 |
+
self.backbone.eval()
|
| 564 |
+
self.noise.eval()
|
| 565 |
+
|
| 566 |
+
assert self.config.noise.type == 'loglinear'
|
| 567 |
+
sigma_t, _ = self.noise(t)
|
| 568 |
+
|
| 569 |
+
if t.ndim > 1:
|
| 570 |
+
t = t.squeeze(-1)
|
| 571 |
+
assert t.ndim == 1
|
| 572 |
+
|
| 573 |
+
change_prob_t = t[:, None, None]
|
| 574 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 575 |
+
|
| 576 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 577 |
+
|
| 578 |
+
if attn_mask is None:
|
| 579 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 580 |
+
|
| 581 |
+
if p_x0 is None:
|
| 582 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 583 |
+
p_x0 = log_p.exp()
|
| 584 |
+
|
| 585 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 586 |
+
|
| 587 |
+
# changed for noise removal
|
| 588 |
+
p_x0 = p_x0.clone()
|
| 589 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 590 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 591 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 592 |
+
|
| 593 |
+
x_changed = _sample_categorical(q_xs)
|
| 594 |
+
|
| 595 |
+
copy_flag = (token_array != self.mask_index)
|
| 596 |
+
|
| 597 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 598 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 599 |
+
|
| 600 |
+
# compute the log-probability under pretrained model at each step
|
| 601 |
+
with torch.no_grad():
|
| 602 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 603 |
+
log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 604 |
+
|
| 605 |
+
# log-prob of the *sampled token* at each position
|
| 606 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 607 |
+
|
| 608 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 609 |
+
|
| 610 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 611 |
+
changed_mask = (~copy_flag)
|
| 612 |
+
# mask of tokens that were unmasked in this step
|
| 613 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 614 |
+
|
| 615 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 616 |
+
|
| 617 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 618 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 619 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 620 |
+
|
| 621 |
+
# returns:
|
| 622 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 623 |
+
# x_next (B, L) next sequences
|
| 624 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 625 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 626 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 627 |
+
|
| 628 |
+
# first step in expansion
|
| 629 |
+
def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None, attn_mask=None):
|
| 630 |
+
torch.cuda.empty_cache()
|
| 631 |
+
self.backbone.eval()
|
| 632 |
+
self.noise.eval()
|
| 633 |
+
|
| 634 |
+
assert self.config.noise.type == 'loglinear'
|
| 635 |
+
sigma_t, _ = self.noise(t)
|
| 636 |
+
|
| 637 |
+
if t.ndim > 1:
|
| 638 |
+
t = t.squeeze(-1)
|
| 639 |
+
assert t.ndim == 1
|
| 640 |
+
|
| 641 |
+
change_prob_t = t[:, None, None]
|
| 642 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 643 |
+
|
| 644 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 645 |
+
|
| 646 |
+
if token_array.dim() == 1:
|
| 647 |
+
token_array = token_array.unsqueeze(0)
|
| 648 |
+
|
| 649 |
+
# expand to match (num_children, L)
|
| 650 |
+
if attn_mask is None:
|
| 651 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 652 |
+
|
| 653 |
+
token_array = token_array.to(self.device)
|
| 654 |
+
sigma_t = sigma_t.to(self.device)
|
| 655 |
+
|
| 656 |
+
if p_x0 is None:
|
| 657 |
+
log_p = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 658 |
+
p_x0 = log_p.exp()
|
| 659 |
+
|
| 660 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 661 |
+
|
| 662 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 663 |
+
|
| 664 |
+
# zero-masking probability
|
| 665 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 666 |
+
|
| 667 |
+
# repeat the parent token along the first dimension which will be unmasked into distinct sequences
|
| 668 |
+
token_array_expanded = token_array.repeat(batch_size, 1)
|
| 669 |
+
|
| 670 |
+
if self.config.mcts.sampling == 0:
|
| 671 |
+
x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
|
| 672 |
+
else:
|
| 673 |
+
x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
|
| 674 |
+
|
| 675 |
+
copy_flag = (token_array_expanded != self.mask_index)
|
| 676 |
+
|
| 677 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 678 |
+
x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
# compute the log-probability under pretrained model at each step
|
| 682 |
+
with torch.no_grad():
|
| 683 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 684 |
+
log_pre = pretrained.forward(token_array, attn_mask=attn_mask, sigma=sigma_t)
|
| 685 |
+
|
| 686 |
+
# expand to match the shape of x_children
|
| 687 |
+
log_pre = log_pre.repeat(batch_size, 1, 1)
|
| 688 |
+
|
| 689 |
+
# log-prob of the *sampled token* at each position
|
| 690 |
+
log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 691 |
+
|
| 692 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 693 |
+
|
| 694 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 695 |
+
changed_mask = (~copy_flag)
|
| 696 |
+
# mask of tokens that were unmasked in this step
|
| 697 |
+
unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)
|
| 698 |
+
|
| 699 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 700 |
+
|
| 701 |
+
# compute the per-child log-probability under the pretrained model
|
| 702 |
+
log_p = log_p.repeat(batch_size, 1, 1)
|
| 703 |
+
log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # (B, L) probability of each chosen token
|
| 704 |
+
#print(log_policy_token)
|
| 705 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 706 |
+
|
| 707 |
+
# returns:
|
| 708 |
+
# log_p (B, L, D) log probabilties of each token under the policy model
|
| 709 |
+
# x_children (B, L) child sequences
|
| 710 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 711 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 712 |
+
return log_p, x_children, log_policy_step, log_pretrained_step
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def compute_invalid_loss(self, logits, k=None, temp=None):
|
| 716 |
+
"""
|
| 717 |
+
Penalizes logits that produce invalid sequences using the `is_peptide` function,
|
| 718 |
+
scaling penalties inversely with token probabilities.
|
| 719 |
+
|
| 720 |
+
Args:
|
| 721 |
+
logits: Tensor of shape [batch_size, seq_len, vocab_size].
|
| 722 |
+
k: Number of samples for Gumbel-Rao.
|
| 723 |
+
temp: Temperature for softmax.
|
| 724 |
+
|
| 725 |
+
Returns:
|
| 726 |
+
loss: A scalar tensor representing the total loss for invalid sequences.
|
| 727 |
+
"""
|
| 728 |
+
|
| 729 |
+
#samples = self.gumbel_rao(logits, k=k, temp=temp) # (batch_size, seq_len, vocab_size)
|
| 730 |
+
|
| 731 |
+
# Convert logits to sequences using the tokenizer
|
| 732 |
+
batch_token_ids = logits.argmax(dim=-1).to(self.device) # (batch_size, seq_len)
|
| 733 |
+
sampled_sequences = self.tokenizer.batch_decode(batch_token_ids)
|
| 734 |
+
|
| 735 |
+
# Check validity of each sampled sequence (not differentiable)
|
| 736 |
+
penalties = torch.tensor(
|
| 737 |
+
[1 if not self.analyzer.is_peptide(seq) else 0 for seq in sampled_sequences],
|
| 738 |
+
dtype=torch.float32,
|
| 739 |
+
device=self.device
|
| 740 |
+
)
|
| 741 |
+
#print(penalties)
|
| 742 |
+
|
| 743 |
+
# Compute probabilities for each token (batch_size, seq_length)
|
| 744 |
+
sampled_probs = torch.softmax(logits, dim=-1).gather(dim=-1, index=batch_token_ids.unsqueeze(-1)).squeeze(-1).to(self.device)
|
| 745 |
+
|
| 746 |
+
# scale penalties by softmax probability of sampled tokens
|
| 747 |
+
scaled_penalty = penalties[:, None] * sampled_probs # (batch_size, seq_length)
|
| 748 |
+
|
| 749 |
+
return scaled_penalty.to(self.device)
|
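A toy illustration of how the penalty above scales with the model's confidence; the hard-coded validity vector below stands in for the `is_peptide` check:

```python
import torch

logits = torch.randn(3, 6, 10)                                          # (batch, seq_len, vocab)
token_ids = logits.argmax(dim=-1)                                       # greedy decode
probs = torch.softmax(logits, dim=-1)
sampled_probs = probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)   # (batch, seq_len)

penalties = torch.tensor([0.0, 1.0, 0.0])                               # pretend sequence 1 is invalid
scaled_penalty = penalties[:, None] * sampled_probs                     # confident invalid tokens cost more
print(scaled_penalty.sum(dim=-1))
```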
| 750 |
+
|
| 751 |
+
### DIFFUSION LOSS ###
|
| 752 |
+
|
| 753 |
+
def sample_t(self, n, device):
|
| 754 |
+
"""
|
| 755 |
+
Sample random time steps for batch training
|
| 756 |
+
"""
|
| 757 |
+
# sample values uniformly at random from [0, 1)
|
| 758 |
+
eps_t = torch.rand(n, device=device)
|
| 759 |
+
# antithetic sampling: reduce variance by pairing each sample with a complementary sample
|
| 760 |
+
if self.config.training.antithetic_sampling:
|
| 761 |
+
# compute interval between sampled time steps
|
| 762 |
+
offset = torch.arange(n, device=device) / n
|
| 763 |
+
# ensure that each eps value is evenly spaced between [0, 1)
|
| 764 |
+
eps_t = ((eps_t / n) + offset) % 1
|
| 765 |
+
|
| 766 |
+
# ensures values are not exactly 0 or 1
|
| 767 |
+
t = (1 - self.config.training.sampling_eps) * eps_t + self.config.training.sampling_eps
|
| 768 |
+
|
| 769 |
+
return t
|
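A minimal sketch of the stratified (antithetic) time sampling above, assuming n = 4 samples and an illustrative sampling_eps:

```python
import torch

n, sampling_eps = 4, 1e-3
eps_t = torch.rand(n)
offset = torch.arange(n) / n
eps_t = ((eps_t / n) + offset) % 1                  # one sample per bin [0, .25), [.25, .5), ...
t = (1 - sampling_eps) * eps_t + sampling_eps       # keep t strictly inside (0, 1)
print(t)
```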
| 770 |
+
|
| 771 |
+
"""def mask_samples(self, x0, mask_prob):
|
| 772 |
+
|
| 773 |
+
# generate array of values in range [0, 1] uniformly at random
|
| 774 |
+
# will be used to determine which tokens are masked
|
| 775 |
+
mask_indices = torch.rand(* x0.shape, device=x0.device) # (batch_size, L)
|
| 776 |
+
|
| 777 |
+
# select tokens to mask if the random value in mask_indices is less than mask_prob
|
| 778 |
+
# this will mask approximately the fraction of tokens indicated by mask_prob
|
| 779 |
+
zt = torch.where(mask_indices < mask_prob, self.mask_index, x0)
|
| 780 |
+
|
| 781 |
+
return zt"""
|
| 782 |
+
|
| 783 |
+
def q_xt(self, x, mask_prob):
|
| 784 |
+
"""Computes the noisy sample xt.
|
| 785 |
+
|
| 786 |
+
Args:
|
| 787 |
+
x: int torch.Tensor with shape (batch_size,
|
| 788 |
+
diffusion_model_input_length), input.
|
| 789 |
+
mask_prob: float torch.Tensor with shape (batch_size, 1), per-token masking probability.
|
| 790 |
+
"""
|
| 791 |
+
|
| 792 |
+
actual_seq_length = (x != 0).sum(dim=-1, keepdim=True)
|
| 793 |
+
#print(actual_seq_length)
|
| 794 |
+
|
| 795 |
+
max_mask_length = (actual_seq_length * 0.75).long()
|
| 796 |
+
|
| 797 |
+
mask_indices = torch.rand(*x.shape, device=x.device) < mask_prob
|
| 798 |
+
|
| 799 |
+
restricted_move_indices = torch.zeros_like(mask_indices, dtype=torch.bool)
|
| 800 |
+
|
| 801 |
+
for i in range(x.shape[0]):
|
| 802 |
+
true_positions = torch.where(mask_indices[i])[0]
|
| 803 |
+
if len(true_positions) > max_mask_length[i]:
|
| 804 |
+
selected_positions = true_positions[:max_mask_length[i].item()]
|
| 805 |
+
restricted_move_indices[i, selected_positions] = True
|
| 806 |
+
else:
|
| 807 |
+
restricted_move_indices[i] = mask_indices[i]
|
| 808 |
+
|
| 809 |
+
xt = torch.where(restricted_move_indices, self.tokenizer.mask_token_id, x)
|
| 810 |
+
|
| 811 |
+
return xt
|
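A usage sketch of `q_xt`, assuming `model` is a loaded `Diffusion` instance whose tokenizer uses 0 for padding; at most 75% of each row's non-pad tokens are masked regardless of `mask_prob`:

```python
import torch

x = torch.tensor([[12, 7, 33, 5, 0, 0, 0, 0]])        # one sequence with 4 real tokens
mask_prob = torch.full(x.shape, 0.9, dtype=torch.float)  # aggressive corruption level
xt = model.q_xt(x.to(model.device), mask_prob.to(model.device))
print(xt)                                              # at most 3 of the 4 real tokens are masked
```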
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def sample_prior(self, *batch_dims):
|
| 815 |
+
"""
|
| 816 |
+
Returns array of fully masked sequences with same shape as input
|
| 817 |
+
"""
|
| 818 |
+
return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
### COMPUTING LOSS ###
|
| 822 |
+
|
| 823 |
+
def compute_diffusion_loss(self, model_output, xt, x0, t):
|
| 824 |
+
"""
|
| 825 |
+
Computes diffusion loss term in ELBO
|
| 826 |
+
(evaluates how accurately the model predicts the token probabilities at each time step)
|
| 827 |
+
|
| 828 |
+
Inputs:
|
| 829 |
+
- model_output: [batch size, sequence length, vocab size] array of logits for each token at each sequence position
|
| 830 |
+
- xt: corrupted version of original input x0 at timestep t
|
| 831 |
+
- x0: original input sequence
|
| 832 |
+
- t: timestep
|
| 833 |
+
"""
|
| 834 |
+
# compute interval between each timestep
|
| 835 |
+
dt = 1 / self.T
|
| 836 |
+
|
| 837 |
+
# compute vectorized alpha scaling terms for the logits at timestep s and t
|
| 838 |
+
alpha_t = 1 - t + torch.zeros_like(x0)
|
| 839 |
+
# s = t - dt
|
| 840 |
+
alpha_s = 1 - (t - dt) + torch.zeros_like(x0)
|
| 841 |
+
|
| 842 |
+
# gather vector of log-probabilities for each token in x0
|
| 843 |
+
# log<x_theta, x>
|
| 844 |
+
log_x_theta_at_x0 = torch.gather(model_output, -1, x0[:, :, None]) # shape (B, L, vocab_size)
|
| 845 |
+
# gather log-probabilities for assigning a masked token at each position in the sequence at time t
|
| 846 |
+
# log<x_theta, m>
|
| 847 |
+
log_x_theta_at_m = model_output[:, :, self.mask_index]
|
| 848 |
+
# obtain non-log probability of assigning a masked token
|
| 849 |
+
# <xt, m>
|
| 850 |
+
x_theta_at_m = log_x_theta_at_m.exp()
|
| 851 |
+
|
| 852 |
+
# first term of diffusion loss
|
| 853 |
+
term_1_coef = dt / t
|
| 854 |
+
term_1_log_numerator = torch.log((alpha_t * x_theta_at_m) / t + 1)
|
| 855 |
+
term_1_log_denom = log_x_theta_at_x0
|
| 856 |
+
|
| 857 |
+
# second term of diffusion loss
|
| 858 |
+
term_2_coef = 1 - (dt / t)
|
| 859 |
+
term_2_log_numerator = term_1_log_numerator
|
| 860 |
+
term_2_log_denom = torch.log((alpha_s * x_theta_at_m) / (t - dt) + 1)
|
| 861 |
+
|
| 862 |
+
L_vb_masked = (term_1_coef * (term_1_log_numerator - term_1_log_denom) +
|
| 863 |
+
term_2_coef * (term_2_log_numerator - term_2_log_denom))
|
| 864 |
+
|
| 865 |
+
# multiply by <zt, m> term
|
| 866 |
+
L_vb = L_vb_masked * (xt == self.mask_index)
|
| 867 |
+
|
| 868 |
+
# scale by T and return
|
| 869 |
+
return self.T * L_vb
|
| 870 |
+
|
| 871 |
+
def _forward_pass_diffusion(self, x0, attn_mask, bond_mask=None, mask=None):
|
| 872 |
+
"""
|
| 873 |
+
Training reverse diffusion model x_theta to reconstruct samples x0
|
| 874 |
+
|
| 875 |
+
bond_mask: (batch, seq_length)
|
| 876 |
+
"""
|
| 877 |
+
# randomly sample time steps to start the denoising process for each x0 in batch
|
| 878 |
+
t = self.sample_t(x0.shape[0], self.device)
|
| 879 |
+
|
| 880 |
+
# if we are training the intermediate transition blocks
|
| 881 |
+
if self.T > 0:
|
| 882 |
+
# scale by total timesteps T and cast to integer
|
| 883 |
+
t = (t * self.T).to(torch.int)
|
| 884 |
+
# scale down by T to get a multiple of 1/T
|
| 885 |
+
t = t / self.T
|
| 886 |
+
# add 1/T to ensure no 0 values
|
| 887 |
+
t += (1 / self.T)
|
| 888 |
+
|
| 889 |
+
# get noise and rate of noise at timestep t
|
| 890 |
+
# sigma = -log(1-t); dsigma = 1 / (1-t)
|
| 891 |
+
sigma, dsigma = self.noise(t)
|
| 892 |
+
time_conditioning = sigma[:, None]
|
| 893 |
+
|
| 894 |
+
# Get masking probabilities for all tokens for each batch
|
| 895 |
+
# log-linear: 1 - alpha = t
|
| 896 |
+
base_mask_prob = 1 - torch.exp(-sigma[:, None]) # (batch_size, L)
|
| 897 |
+
|
| 898 |
+
if self.config.noise.state_dependent and (bond_mask is not None):
|
| 899 |
+
# log-polynomial masking schedule: alpha = 1 - t^w
|
| 900 |
+
# bond_sigma = -log(1-t^w) for w = 3 (default)
|
| 901 |
+
# bond_dsigma = wt^(w-1) / (1-t^w)
|
| 902 |
+
bond_sigma, bond_dsigma = self.bond_noise(t) # scalar
|
| 903 |
+
# expand dimensions for broadcasting to (B, L)
|
| 904 |
+
bond_sigma = bond_sigma[:, None]
|
| 905 |
+
bond_dsigma = bond_dsigma[:, None]
|
| 906 |
+
sigma = sigma[:, None]
|
| 907 |
+
dsigma = dsigma[:, None]
|
| 908 |
+
|
| 909 |
+
# compute masking probability for peptide bonds 1 - bond_alpha = t^w
|
| 910 |
+
bond_mask_prob = 1 - torch.exp(-bond_sigma).to(self.device)
|
| 911 |
+
# piece together (B, L) tensor with modified masking prob at peptide-bond locations
|
| 912 |
+
mask_prob = torch.where(bond_mask == 1, bond_mask_prob, base_mask_prob).to(self.device)
|
| 913 |
+
#print(mask_prob)
|
| 914 |
+
dsigma = torch.where(bond_mask == 1, bond_dsigma, dsigma).to(self.device)
|
| 915 |
+
sigma = torch.where(bond_mask == 1, bond_sigma, sigma).to(self.device)
|
| 916 |
+
else:
|
| 917 |
+
mask_prob = base_mask_prob.to(self.device)
|
| 918 |
+
|
| 919 |
+
# get masked samples at different timesteps
|
| 920 |
+
if mask is None:
|
| 921 |
+
zt = self.q_xt(x0, mask_prob).to(self.device)
|
| 922 |
+
else:
|
| 923 |
+
zt = x0.where(mask==1, torch.full_like(x0, self.mask_index)).to(self.device)
|
| 924 |
+
|
| 925 |
+
model_output = self.forward(zt, attn_mask=attn_mask.to(self.device), sigma=time_conditioning).to(self.device)
|
| 926 |
+
|
| 927 |
+
# debugging
|
| 928 |
+
assert not torch.isnan(model_output).any()
|
| 929 |
+
assert model_output.is_cuda
|
| 930 |
+
utils.print_nans(model_output, 'model_output')
|
| 931 |
+
|
| 932 |
+
# compute invalid loss
|
| 933 |
+
invalid_loss = self.compute_invalid_loss(logits=model_output).to(self.device) # (B, L)
|
| 934 |
+
#print(invalid_loss)
|
| 935 |
+
|
| 936 |
+
if self.T > 0:
|
| 937 |
+
# compute diffusion loss
|
| 938 |
+
diffusion_loss = self.compute_diffusion_loss(model_output, zt, x0, t)
|
| 939 |
+
return diffusion_loss
|
| 940 |
+
|
| 941 |
+
# compute loss for the final step that converts from z0 to x0
|
| 942 |
+
# -log(p_theta)
|
| 943 |
+
# get (batch_size, L) array of log-probabilities
|
| 944 |
+
log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1).to(self.device) # (B, L)
|
| 945 |
+
|
| 946 |
+
if self.config.noise.state_dependent and (bond_mask is not None):
|
| 947 |
+
return (-log_p_theta * (dsigma / torch.expm1(sigma)) + invalid_loss).to(self.device)
|
| 948 |
+
else:
|
| 949 |
+
return ((-log_p_theta * (dsigma / torch.expm1(sigma))[:, None]) + invalid_loss).to(self.device)
|
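The two schedules referenced in the comments above can be sanity-checked in isolation; a short sketch of the log-linear and log-polynomial (w = 3) masking probabilities:

```python
import torch

t = torch.linspace(0.1, 0.9, 5)
w = 3.0

sigma = -torch.log1p(-t)                      # log-linear: 1 - alpha = t
bond_sigma = -torch.log1p(-t**w)              # log-polynomial: 1 - alpha = t^w

mask_prob = 1 - torch.exp(-sigma)             # equals t
bond_mask_prob = 1 - torch.exp(-bond_sigma)   # equals t**w, so bond positions are masked less aggressively at the same t
print(mask_prob)
print(bond_mask_prob)
```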
| 950 |
+
|
| 951 |
+
def _loss(self, x0, attn_mask, bond_mask=None, mask=None):
|
| 952 |
+
loss = self._forward_pass_diffusion(x0, attn_mask, bond_mask, mask)
|
| 953 |
+
|
| 954 |
+
# negative log loss
|
| 955 |
+
nlls = loss * attn_mask
|
| 956 |
+
|
| 957 |
+
# count number of tokens
|
| 958 |
+
num_tokens = attn_mask.sum()
|
| 959 |
+
|
| 960 |
+
# compute batch loss
|
| 961 |
+
batch_nll = nlls.sum()
|
| 962 |
+
# compute per token loss
|
| 963 |
+
token_nll = batch_nll / num_tokens
|
| 964 |
+
# return losses
|
| 965 |
+
return Loss(loss = token_nll.to(self.device), nlls = nlls.to(self.device), attn_mask = attn_mask.to(self.device))
|
| 966 |
+
|
| 967 |
+
def _compute_loss(self, batch, prefix, bond_mask=None):
|
| 968 |
+
|
| 969 |
+
attn_mask = batch['attention_mask'].to(self.device)
|
| 970 |
+
|
| 971 |
+
if 'mask' in batch:
|
| 972 |
+
mask = batch['mask'].to(self.device)
|
| 973 |
+
else:
|
| 974 |
+
mask = None
|
| 975 |
+
|
| 976 |
+
if 'bond_mask' in batch:
|
| 977 |
+
bond_mask = batch['bond_mask'].to(self.device)
|
| 978 |
+
else:
|
| 979 |
+
bond_mask = None
|
| 980 |
+
|
| 981 |
+
losses = self._loss(batch['input_ids'].to(self.device), attn_mask, bond_mask, mask)
|
| 982 |
+
loss = losses.loss
|
| 983 |
+
|
| 984 |
+
if prefix == 'train':
|
| 985 |
+
self.train_metrics.update(
|
| 986 |
+
losses.nlls.to(self.device),
|
| 987 |
+
losses.attn_mask.to(self.device)
|
| 988 |
+
)
|
| 989 |
+
metrics = self.train_metrics
|
| 990 |
+
elif prefix == 'val':
|
| 991 |
+
self.valid_metrics.update(
|
| 992 |
+
losses.nlls.to(self.device),
|
| 993 |
+
losses.attn_mask.to(self.device)
|
| 994 |
+
)
|
| 995 |
+
metrics = self.valid_metrics
|
| 996 |
+
elif prefix == 'test':
|
| 997 |
+
self.test_metrics.update(losses.nlls, losses.attn_mask)
|
| 998 |
+
metrics = self.test_metrics
|
| 999 |
+
else:
|
| 1000 |
+
raise ValueError(f'Invalid prefix: {prefix}')
|
| 1001 |
+
|
| 1002 |
+
self.log_dict(metrics,
|
| 1003 |
+
on_step=False,
|
| 1004 |
+
on_epoch=True,
|
| 1005 |
+
sync_dist=True)
|
| 1006 |
+
|
| 1007 |
+
return loss
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
### SAMPLING ###
|
| 1011 |
+
|
| 1012 |
+
def generate_from_masked(self, num_samples=None, seq_length=None, sample_steps=128, eps=1e-5):
|
| 1013 |
+
# get number of timesteps
|
| 1014 |
+
if sample_steps is None:
|
| 1015 |
+
sample_steps = self.config.sampling.steps
|
| 1016 |
+
|
| 1017 |
+
if seq_length is None:
|
| 1018 |
+
seq_length = self.config.sampling.seq_length
|
| 1019 |
+
|
| 1020 |
+
# sample fully masked sequences
|
| 1021 |
+
z = self.sample_prior(num_samples, seq_length).to(self.device)
|
| 1022 |
+
|
| 1023 |
+
# create vector of sample_steps timesteps
|
| 1024 |
+
timesteps = torch.linspace(1, eps, sample_steps + 1, device=self.device)
|
| 1025 |
+
|
| 1026 |
+
# compute interval between timesteps
|
| 1027 |
+
dt = (1 - eps) / sample_steps
|
| 1028 |
+
|
| 1029 |
+
for i in range(sample_steps):
|
| 1030 |
+
t = timesteps[i] * torch.ones(z.shape[0], 1, device=self.device)
|
| 1031 |
+
|
| 1032 |
+
z = self.single_reverse_step(z, t, dt)
|
| 1033 |
+
|
| 1034 |
+
return z
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
### SAMPLING STEP ###
|
| 1038 |
+
"""
|
| 1039 |
+
def single_reverse_step(self, zt, t, dt, attn_mask=None):
|
| 1040 |
+
# get sigma values that determine masking prob
|
| 1041 |
+
sigma_t, _ = self.noise(t)
|
| 1042 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1043 |
+
|
| 1044 |
+
# reshape sigmas
|
| 1045 |
+
if sigma_t.ndim > 1:
|
| 1046 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1047 |
+
if sigma_s.ndim > 1:
|
| 1048 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1049 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1050 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1051 |
+
|
| 1052 |
+
# compute masking probabilities for each timestep
|
| 1053 |
+
change_prob_t = 1 - torch.exp(-sigma_t)
|
| 1054 |
+
change_prob_s = 1 - torch.exp(-sigma_s)
|
| 1055 |
+
|
| 1056 |
+
# expand dimensions
|
| 1057 |
+
change_prob_t = change_prob_t[:, None, None]
|
| 1058 |
+
change_prob_s = change_prob_s[:, None, None]
|
| 1059 |
+
|
| 1060 |
+
# get prediction model that outputs token probabilities
|
| 1061 |
+
log_p_x0 = self.forward(zt, attn_mask=attn_mask, sigma=sigma_t)
|
| 1062 |
+
|
| 1063 |
+
# check dimensions match
|
| 1064 |
+
assert change_prob_t.ndim == log_p_x0.ndim
|
| 1065 |
+
|
| 1066 |
+
# compute reverse diffusion probability of being unmasked at timestep s
|
| 1067 |
+
# (sigma_s - sigma_t)*x_theta
|
| 1068 |
+
q_zs = log_p_x0.exp() * (change_prob_t - change_prob_s)
|
| 1069 |
+
|
| 1070 |
+
# compute reverse diffusion probability of remaining masked at timestep s
|
| 1071 |
+
# (1 - sigma_s)*m
|
| 1072 |
+
q_zs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 1073 |
+
|
| 1074 |
+
# sample sequence at timestep s from categorical distribution of q_zs
|
| 1075 |
+
z_changed = _sample_categorical(q_zs)
|
| 1076 |
+
|
| 1077 |
+
copy_flag = (zt != self.mask_index).to(zt.dtype)
|
| 1078 |
+
return (copy_flag * zt) + ((1 - copy_flag) * z_changed)"""
|
| 1079 |
+
|
| 1080 |
+
def cached_reverse_step(self, x, t, dt, p_x0=None, attn_mask=None):
|
| 1081 |
+
assert self.config.noise.type == 'loglinear'
|
| 1082 |
+
sigma_t, _ = self.noise(t)
|
| 1083 |
+
|
| 1084 |
+
if t.ndim > 1:
|
| 1085 |
+
t = t.squeeze(-1)
|
| 1086 |
+
assert t.ndim == 1
|
| 1087 |
+
|
| 1088 |
+
change_prob_t = t[:, None, None]
|
| 1089 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 1090 |
+
|
| 1091 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 1092 |
+
|
| 1093 |
+
if p_x0 is None:
|
| 1094 |
+
p_x0 = self.forward(x, attn_mask=attn_mask, sigma=sigma_t).exp()
|
| 1095 |
+
|
| 1096 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 1097 |
+
|
| 1098 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 1099 |
+
|
| 1100 |
+
# zero-masking probability
|
| 1101 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 1102 |
+
|
| 1103 |
+
x_changed = _sample_categorical(q_xs)
|
| 1104 |
+
|
| 1105 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1106 |
+
|
| 1107 |
+
return p_x0, copy_flag * x + (1 - copy_flag) * x_changed
|
| 1108 |
+
|
| 1109 |
+
# first step in expansion
|
| 1110 |
+
def batch_cached_reverse_step(self, token_array, t, dt, batch_size, p_x0=None, attn_mask=None):
|
| 1111 |
+
"""
|
| 1112 |
+
Generates batch_size different samples from the same starting point for the
|
| 1113 |
+
first expansion step of MCTS
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
assert self.config.noise.type == 'loglinear'
|
| 1117 |
+
sigma_t, _ = self.noise(t)
|
| 1118 |
+
|
| 1119 |
+
if t.ndim > 1:
|
| 1120 |
+
t = t.squeeze(-1)
|
| 1121 |
+
assert t.ndim == 1
|
| 1122 |
+
|
| 1123 |
+
change_prob_t = t[:, None, None]
|
| 1124 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 1125 |
+
|
| 1126 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 1127 |
+
|
| 1128 |
+
if token_array.dim() == 1:
|
| 1129 |
+
token_array = token_array.unsqueeze(0)
|
| 1130 |
+
#token_array = token_array.repeat(batch_size, 1)
|
| 1131 |
+
|
| 1132 |
+
attn_mask = torch.ones_like(token_array).to(self.device)
|
| 1133 |
+
|
| 1134 |
+
if p_x0 is None:
|
| 1135 |
+
p_x0 = self.forward(token_array, attn_mask=attn_mask, sigma=sigma_t).exp()
|
| 1136 |
+
|
| 1137 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 1138 |
+
|
| 1139 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 1140 |
+
|
| 1141 |
+
# zero-masking probability
|
| 1142 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 1143 |
+
|
| 1144 |
+
# repeat the parent token along the first dimension which will be unmasked into distinct sequences
|
| 1145 |
+
token_array = token_array.repeat(batch_size, 1)
|
| 1146 |
+
|
| 1147 |
+
if self.config.mcts.sampling == 0:
|
| 1148 |
+
x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
|
| 1149 |
+
else:
|
| 1150 |
+
x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
|
| 1151 |
+
|
| 1152 |
+
copy_flag = (token_array != self.mask_index).to(token_array.dtype)
|
| 1153 |
+
|
| 1154 |
+
return p_x0, copy_flag * token_array + (1 - copy_flag) * x_changed
|
| 1155 |
+
|
| 1156 |
+
def _process_sigma(self, sigma):
|
| 1157 |
+
if sigma.ndim > 1:
|
| 1158 |
+
sigma = sigma.squeeze(-1)
|
| 1159 |
+
if not self.time_conditioning:
|
| 1160 |
+
sigma = torch.zeros_like(sigma)
|
| 1161 |
+
assert sigma.ndim == 1, sigma.shape
|
| 1162 |
+
return sigma
|
| 1163 |
+
|
| 1164 |
+
def forward(self, zt, attn_mask, sigma):
|
| 1165 |
+
"""
|
| 1166 |
+
Predicts the token log-probabilities from zt at time t with noise schedule sigma
|
| 1167 |
+
"""
|
| 1168 |
+
sigma = self._process_sigma(sigma)
|
| 1169 |
+
|
| 1170 |
+
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 1171 |
+
logits = self.backbone(zt, attn_mask).to(self.device)
|
| 1172 |
+
|
| 1173 |
+
return self.subs_parameterization(logits, zt)
|
| 1174 |
+
|
| 1175 |
+
def subs_parameterization(self, logits, zt):
|
| 1176 |
+
"""
|
| 1177 |
+
Updates reverse diffusion logits based on SUBS parameterization:
|
| 1178 |
+
- zero masking probabilities: -infinity probability of being masked during reverse diffusion
|
| 1179 |
+
- carry-over unmasking: unmasked input tokens remain unchanged during reverse diffusion
|
| 1180 |
+
|
| 1181 |
+
Args:
|
| 1182 |
+
logits: vector of token probabilities for unmasking masked tokens
|
| 1183 |
+
zt: partially unmasked sequence at current timestep
|
| 1184 |
+
"""
|
| 1185 |
+
logits[:, :, self.mask_index] += self.neg_infinity # [sequence index, current token, next token]
|
| 1186 |
+
|
| 1187 |
+
|
| 1188 |
+
logits = (logits - torch.logsumexp(logits, dim=-1, keepdim=True)).to(self.device)
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
unmasked_indices = (zt != self.mask_index).to(self.device) # shape: [batch_size, seq_length]
|
| 1192 |
+
batch_idx, seq_idx = torch.where(unmasked_indices) # Get explicit indices
|
| 1193 |
+
batch_idx = batch_idx.to(self.device)
|
| 1194 |
+
seq_idx = seq_idx.to(self.device)
|
| 1195 |
+
tokens = zt[batch_idx, seq_idx].to(self.device) # Get the tokens at those positions
|
| 1196 |
+
|
| 1197 |
+
#assert logits.is_contiguous(), "logits tensor is not contiguous"
|
| 1198 |
+
#assert unmasked_indices.shape == zt.shape, "same shape"
|
| 1199 |
+
#assert not torch.isnan(logits).any(), "NaN values found in logits"
|
| 1200 |
+
#assert tokens.max() < logits.shape[-1], "token indices out of bounds"
|
| 1201 |
+
#assert batch_idx.max() < logits.shape[0], "batch index out of bounds"
|
| 1202 |
+
#assert seq_idx.max() < logits.shape[1], "seq index out of bounds"
|
| 1203 |
+
#assert batch_idx.device == seq_idx.device == logits.device == tokens.device, "device inconsistent"
|
| 1204 |
+
|
| 1205 |
+
logits[unmasked_indices] = self.neg_infinity # Set everything to -inf first
|
| 1206 |
+
logits[unmasked_indices, zt[unmasked_indices]] = 0 # Set only the specific token positions to 0
|
| 1207 |
+
# return logits with SUBS parameterization
|
| 1208 |
+
return logits.to(self.device)
|
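A toy demonstration of the two SUBS constraints enforced above (zero masking probability and carry-over unmasking), with an illustrative 4-token vocabulary where id 3 is the mask:

```python
import torch

B, L, V, MASK = 1, 3, 4, 3
logits = torch.randn(B, L, V)
zt = torch.tensor([[3, 1, 3]])                       # middle position already unmasked to token 1

logits[:, :, MASK] = float('-inf')                   # the mask token can never be predicted
log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

unmasked = zt != MASK                                # carry-over: unmasked tokens stay put
log_probs[unmasked] = float('-inf')
log_probs[unmasked, zt[unmasked]] = 0.0
print(log_probs.exp())                               # position 1 is a one-hot on token 1
```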
| 1209 |
+
|
| 1210 |
+
"""SAMPLING"""
|
| 1211 |
+
@torch.no_grad()
|
| 1212 |
+
def _sample(self, num_steps=None, eps=1e-5, x_input=None):
|
| 1213 |
+
"""
|
| 1214 |
+
Generate samples
|
| 1215 |
+
"""
|
| 1216 |
+
batch_size_per_gpu = self.config.eval.perplexity_batch_size
|
| 1217 |
+
|
| 1218 |
+
if num_steps is None:
|
| 1219 |
+
num_steps = self.config.sampling.steps
|
| 1220 |
+
|
| 1221 |
+
if x_input is not None:
|
| 1222 |
+
x = x_input['input_ids'].to(self.device)
|
| 1223 |
+
attn_mask = x_input['attention_mask'].to(self.device)
|
| 1224 |
+
else:
|
| 1225 |
+
x = self.sample_prior(batch_size_per_gpu, self.config.model.length).to(self.device)
|
| 1226 |
+
attn_mask = torch.ones_like(x).to(self.device)
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
timesteps = torch.linspace(1, eps, num_steps+1, device=self.device)
|
| 1230 |
+
dt = (1 - eps) / num_steps
|
| 1231 |
+
p_x0_cache = None
|
| 1232 |
+
generation_history = [] # used to track which tokens are unmasked
|
| 1233 |
+
|
| 1234 |
+
for i in range(num_steps):
|
| 1235 |
+
t = timesteps[i] * torch.ones(x.shape[0], 1, device = self.device)
|
| 1236 |
+
if self.sampler == 'ddpm':
|
| 1237 |
+
x = self.single_reverse_step(x, t, dt).to(self.device)
|
| 1238 |
+
|
| 1239 |
+
elif self.sampler == 'ddpm_cache':
|
| 1240 |
+
p_x0_cache, x_next = self.cached_reverse_step(x, t, dt, p_x0=p_x0_cache, attn_mask=attn_mask)
|
| 1241 |
+
if (not torch.allclose(x_next, x) or self.time_conditioning):
|
| 1242 |
+
# Disable caching
|
| 1243 |
+
p_x0_cache = None
|
| 1244 |
+
x = x_next.to(self.device)
|
| 1245 |
+
#print(self.tokenizer.decode(x.squeeze()))
|
| 1246 |
+
else:
|
| 1247 |
+
x = self._analytic_update(x, t, dt, attn_mask).to(self.device)
|
| 1248 |
+
|
| 1249 |
+
if self.config.sampling.noise_removal:
|
| 1250 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
|
| 1251 |
+
if self.sampler == 'analytic':
|
| 1252 |
+
x = self._denoiser_update(x, t).to(self.device)
|
| 1253 |
+
else:
|
| 1254 |
+
time_conditioning = self.noise(t)[0].to(self.device)
|
| 1255 |
+
x = self.forward(x, attn_mask=attn_mask, sigma=time_conditioning).argmax(dim=-1).to(self.device)
|
| 1256 |
+
#print(self.tokenizer.decode(x.squeeze()))
|
| 1257 |
+
return x.to(self.device)
|
| 1258 |
+
|
| 1259 |
+
|
| 1260 |
+
def restore_model_and_sample(self, num_steps, eps=1e-5):
|
| 1261 |
+
"""Generate samples from the model."""
|
| 1262 |
+
self.backbone.eval()
|
| 1263 |
+
self.noise.eval()
|
| 1264 |
+
samples = self._sample(num_steps=num_steps, eps=eps)
|
| 1265 |
+
self.backbone.train()
|
| 1266 |
+
self.noise.train()
|
| 1267 |
+
return samples
|
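A usage sketch, assuming `model` is a `Diffusion` instance restored from a checkpoint and configured with the `ddpm_cache` sampler:

```python
samples = model.restore_model_and_sample(num_steps=128)   # (batch, seq_length) token ids
print(model.tokenizer.batch_decode(samples))
```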
| 1268 |
+
|
| 1269 |
+
def get_score(self, zt, sigma, attn_mask=None):
|
| 1270 |
+
|
| 1271 |
+
# score(x, t) = p_t(y) / p_t(x)
|
| 1272 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
| 1273 |
+
|
| 1274 |
+
# case 1: x = masked
|
| 1275 |
+
# (i) y = unmasked
|
| 1276 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
| 1277 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1278 |
+
# (ii) y = masked
|
| 1279 |
+
# log score(x, t) = 0
|
| 1280 |
+
|
| 1281 |
+
# case 2: x = unmasked
|
| 1282 |
+
# (i) y != masked, y != x
|
| 1283 |
+
# log score(x_i, t) = - inf
|
| 1284 |
+
# (ii) y = x
|
| 1285 |
+
# log score(x_i, t) = 0
|
| 1286 |
+
# (iii) y = masked token
|
| 1287 |
+
# log score(x_i, t) = - log k
|
| 1288 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1289 |
+
|
| 1290 |
+
model_output = self.forward(zt, attn_mask=attn_mask, sigma=sigma)
|
| 1291 |
+
|
| 1292 |
+
log_k = -torch.log(torch.expm1(sigma)).squeeze(-1)
|
| 1293 |
+
assert log_k.ndim == 1
|
| 1294 |
+
|
| 1295 |
+
masked_score = model_output + log_k[:, None, None]
|
| 1296 |
+
masked_score[:, :, self.mask_index] = 0
|
| 1297 |
+
|
| 1298 |
+
unmasked_score = self.neg_infinity * torch.ones_like(model_output)
|
| 1299 |
+
unmasked_score = torch.scatter(
|
| 1300 |
+
unmasked_score, -1,
|
| 1301 |
+
zt[..., None],
|
| 1302 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
| 1303 |
+
|
| 1304 |
+
unmasked_score[:, :, self.mask_index] = - (log_k[:, None] * torch.ones_like(zt))
|
| 1305 |
+
|
| 1306 |
+
masked_indices = (zt == self.mask_index).to(model_output.dtype)[:, :, None]
|
| 1307 |
+
|
| 1308 |
+
model_output = (masked_score * masked_indices + unmasked_score * (1 - masked_indices))
|
| 1309 |
+
|
| 1310 |
+
return model_output.exp()
|
| 1311 |
+
|
| 1312 |
+
def _staggered_score(self, score, dsigma):
|
| 1313 |
+
score = score.clone()
|
| 1314 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 1315 |
+
score *= dsigma.exp()[:, None]
|
| 1316 |
+
score[..., self.mask_index] += extra_const
|
| 1317 |
+
return score
|
| 1318 |
+
|
| 1319 |
+
def _analytic_update(self, x, t, step_size, attn_mask=None):
|
| 1320 |
+
curr_sigma, _ = self.noise(t)
|
| 1321 |
+
next_sigma, _ = self.noise(t - step_size)
|
| 1322 |
+
dsigma = curr_sigma - next_sigma
|
| 1323 |
+
score = self.get_score(x, curr_sigma, attn_mask=attn_mask)
|
| 1324 |
+
stag_score = self._staggered_score(score, dsigma)
|
| 1325 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
| 1326 |
+
return _sample_categorical(probs)
|
| 1327 |
+
|
| 1328 |
+
def _denoiser_update(self, x, t):
|
| 1329 |
+
sigma, _ = self.noise(t)
|
| 1330 |
+
score = self.get_score(x, sigma)
|
| 1331 |
+
stag_score = self._staggered_score(score, sigma)
|
| 1332 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
| 1333 |
+
probs[..., self.mask_index] = 0
|
| 1334 |
+
samples = _sample_categorical(probs)
|
| 1335 |
+
return samples
|
| 1336 |
+
|
| 1337 |
+
def _transp_transition(self, i, sigma):
|
| 1338 |
+
sigma = unsqueeze(sigma, reference=i[..., None])
|
| 1339 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
| 1340 |
+
i, num_classes=self.vocab_size)
|
| 1341 |
+
edge += torch.where(i == self.mask_index,
|
| 1342 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
| 1343 |
+
0)[..., None]
|
| 1344 |
+
return edge
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
"""TRAINING from https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py"""
|
| 1348 |
+
|
| 1349 |
+
def on_train_epoch_start(self):
|
| 1350 |
+
torch.cuda.empty_cache()
|
| 1351 |
+
self.backbone.train()
|
| 1352 |
+
self.noise.train()
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
def training_step(self, batch, batch_idx):
|
| 1356 |
+
# Initialize throughput calculation
|
| 1357 |
+
start_time = time.time()
|
| 1358 |
+
|
| 1359 |
+
if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
|
| 1360 |
+
loss = self._compute_loss(batch, prefix='train', bond_mask=batch['bond_mask'])
|
| 1361 |
+
else:
|
| 1362 |
+
loss = self._compute_loss(batch, prefix='train')
|
| 1363 |
+
|
| 1364 |
+
self.log(name='trainer/loss',
|
| 1365 |
+
value=loss.item(),
|
| 1366 |
+
on_step=True,
|
| 1367 |
+
on_epoch=False,
|
| 1368 |
+
sync_dist=True)
|
| 1369 |
+
|
| 1370 |
+
# Calculate throughput
|
| 1371 |
+
elapsed_time = time.time() - start_time
|
| 1372 |
+
total_tokens = batch['input_ids'].numel()
|
| 1373 |
+
throughput = total_tokens / elapsed_time
|
| 1374 |
+
|
| 1375 |
+
self.log(name='trainer/throughput',
|
| 1376 |
+
value=throughput,
|
| 1377 |
+
on_step=True,
|
| 1378 |
+
on_epoch=False,
|
| 1379 |
+
sync_dist=True)
|
| 1380 |
+
|
| 1381 |
+
return loss
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
def on_load_checkpoint(self, checkpoint):
|
| 1385 |
+
self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
|
| 1386 |
+
self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']
|
| 1387 |
+
|
| 1388 |
+
### VALIDATION ###
|
| 1389 |
+
def on_validation_epoch_start(self):
|
| 1390 |
+
gc.collect()
|
| 1391 |
+
torch.cuda.empty_cache()
|
| 1392 |
+
self.backbone.eval()
|
| 1393 |
+
self.noise.eval()
|
| 1394 |
+
assert self.valid_metrics.nll.mean_value == 0
|
| 1395 |
+
assert self.valid_metrics.nll.weight == 0
|
| 1396 |
+
|
| 1397 |
+
def validation_step(self, batch, batch_idx):
|
| 1398 |
+
if self.config.vocab == 'old_smiles' or self.config.vocab == 'new_smiles':
|
| 1399 |
+
loss = self._compute_loss(batch, prefix='val', bond_mask=batch['bond_mask'])
|
| 1400 |
+
else:
|
| 1401 |
+
loss = self._compute_loss(batch, prefix='val')
|
| 1402 |
+
|
| 1403 |
+
self.log(name='trainer/val_loss',
|
| 1404 |
+
value=loss.item(),
|
| 1405 |
+
on_step=True,
|
| 1406 |
+
on_epoch=False,
|
| 1407 |
+
prog_bar=True,
|
| 1408 |
+
sync_dist=True)
|
| 1409 |
+
return loss
|
| 1410 |
+
|
| 1411 |
+
def on_validation_epoch_end(self):
|
| 1412 |
+
gc.collect()
|
| 1413 |
+
torch.cuda.empty_cache()
|
| 1414 |
+
|
| 1415 |
+
### OPTIMIZATION ###
|
| 1416 |
+
|
| 1417 |
+
def optimizer_step(self, *args, **kwargs):
|
| 1418 |
+
super().optimizer_step(*args, **kwargs)
|
| 1419 |
+
|
| 1420 |
+
gc.collect()
|
| 1421 |
+
torch.cuda.empty_cache()
|
| 1422 |
+
|
| 1423 |
+
def configure_optimizers(self):
|
| 1424 |
+
optimizer = torch.optim.AdamW(
|
| 1425 |
+
itertools.chain(self.backbone.parameters(),self.noise.parameters()),
|
| 1426 |
+
lr=self.config.optim.lr,
|
| 1427 |
+
betas=(self.config.optim.beta1, self.config.optim.beta2),
|
| 1428 |
+
eps=self.config.optim.eps,
|
| 1429 |
+
weight_decay=self.config.optim.weight_decay
|
| 1430 |
+
)
|
| 1431 |
+
|
| 1432 |
+
self.total_steps = self.config.trainer.max_steps
|
| 1433 |
+
scheduler = CosineWarmup(optimizer,
|
| 1434 |
+
warmup_steps=self.config.lr_scheduler.num_warmup_steps,
|
| 1435 |
+
total_steps=self.total_steps)
|
| 1436 |
+
|
| 1437 |
+
scheduler_dict = {
|
| 1438 |
+
'scheduler': scheduler,
|
| 1439 |
+
'interval': 'step',
|
| 1440 |
+
'frequency': 1,
|
| 1441 |
+
'monitor': 'val/loss',
|
| 1442 |
+
'name': 'trainer/lr'
|
| 1443 |
+
}
|
| 1444 |
+
|
| 1445 |
+
return [optimizer], [scheduler_dict]
|
| 1446 |
+
|
| 1447 |
+
@torch.no_grad()
|
| 1448 |
+
def compute_masked_perplexity(self, generated_ids, input_ids):
|
| 1449 |
+
"""
|
| 1450 |
+
Computes the masked pseudo-perplexity of the generated token ids at the positions that were masked in input_ids
|
| 1451 |
+
"""
|
| 1452 |
+
|
| 1453 |
+
total_nll = 0
|
| 1454 |
+
total_tokens = 0
|
| 1455 |
+
|
| 1456 |
+
input_ids = torch.tensor(input_ids).to(self.device)
|
| 1457 |
+
#print(input_ids)
|
| 1458 |
+
|
| 1459 |
+
for sequence in generated_ids:
|
| 1460 |
+
# tokenize the sequence
|
| 1461 |
+
|
| 1462 |
+
gt_ids = torch.tensor(sequence).to(self.device)
|
| 1463 |
+
#print(gt_ids)
|
| 1464 |
+
|
| 1465 |
+
sys.stdout.flush()
|
| 1466 |
+
|
| 1467 |
+
# forward pass through backbone PeptideCLM model
|
| 1468 |
+
attn_mask = torch.ones_like(input_ids).to(self.device)
|
| 1469 |
+
|
| 1470 |
+
# compute logits using backbone
|
| 1471 |
+
|
| 1472 |
+
if self.config.mode in ['train', 'ppl_eval']:
|
| 1473 |
+
outputs = self.backbone.forward(input_ids=input_ids, attn_mask=attn_mask)
|
| 1474 |
+
elif self.config.mode == 'sample_eval':
|
| 1475 |
+
outputs = self.backbone.forward(input_ids=input_ids)
|
| 1476 |
+
|
| 1477 |
+
|
| 1478 |
+
# get logits for each position in sequence across all tokens in vocab
|
| 1479 |
+
#logits = outputs[-1] # (batch_size, seq_length, vocab_size)
|
| 1480 |
+
|
| 1481 |
+
logits = outputs.view(-1, outputs.size(-1))
|
| 1482 |
+
gt_ids = gt_ids.view(-1)
|
| 1483 |
+
|
| 1484 |
+
#print(logits.shape)
|
| 1485 |
+
#print(gt_ids.shape)
|
| 1486 |
+
|
| 1487 |
+
# compute loss
|
| 1488 |
+
# shift_logits = logits[:, :-1, :].contiguous() # remove eos
|
| 1489 |
+
# shift_labels = input_ids[:, 1:].contiguous()
|
| 1490 |
+
# print(masked)
|
| 1491 |
+
|
| 1492 |
+
loss = F.cross_entropy(logits,
|
| 1493 |
+
gt_ids.where(input_ids==self.mask_index, torch.full_like(gt_ids, -100)).view(-1),
|
| 1494 |
+
reduction='sum')
|
| 1495 |
+
|
| 1496 |
+
total_nll += loss.item()
|
| 1497 |
+
# count all non-padding tokens
|
| 1498 |
+
total_tokens += input_ids.ne(self.tokenizer.pad_token_id).sum().item() # includes bos and eos tokens in the count
|
| 1499 |
+
|
| 1500 |
+
# compute pseudo-perplexity
|
| 1501 |
+
# print(total_nll, ",;,", total_tokens)
|
| 1502 |
+
pseudo_perplexity = torch.exp(torch.tensor(total_nll / total_tokens))
|
| 1503 |
+
self.gen_ppl_metric.update(pseudo_perplexity)
|
| 1504 |
+
|
| 1505 |
+
return pseudo_perplexity.item()
|
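The reported number is just the exponentiated average NLL over the scored tokens; a quick arithmetic check:

```python
import math

total_nll, total_tokens = 60.0, 40          # e.g. 1.5 nats per scored token on average
print(math.exp(total_nll / total_tokens))   # pseudo-perplexity ≈ 4.48
```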
| 1506 |
+
|
| 1507 |
+
|
| 1508 |
+
def unsqueeze(x, reference):
|
| 1509 |
+
return x.view(* x.shape, * ((1,) * (len(reference.shape) - len(x.shape))))
|
| 1510 |
+
|
| 1511 |
+
class CosineWarmup(_LRScheduler):
|
| 1512 |
+
def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
|
| 1513 |
+
self.warmup_steps = warmup_steps
|
| 1514 |
+
self.total_steps = total_steps
|
| 1515 |
+
self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
|
| 1516 |
+
super(CosineWarmup, self).__init__(optimizer, last_epoch)
|
| 1517 |
+
|
| 1518 |
+
def get_lr(self):
|
| 1519 |
+
if self.last_epoch < self.warmup_steps:
|
| 1520 |
+
return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
|
| 1521 |
+
|
| 1522 |
+
progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
| 1523 |
+
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
|
| 1524 |
+
decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
|
| 1525 |
+
|
| 1526 |
+
return [decayed_lr * base_lr for base_lr in self.base_lrs]
|
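A usage sketch of the scheduler, assuming it is stepped once per optimizer step as in `configure_optimizers`:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=3e-4)
sched = CosineWarmup(opt, warmup_steps=100, total_steps=1000, eta_ratio=0.1)

lrs = []
for _ in range(1000):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# lr ramps linearly to 3e-4 over the first 100 steps, then follows a cosine
# decay down to eta_ratio * 3e-4 by step 1000
print(lrs[99], lrs[-1])
```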
tr2d2-pep/finetune.py
ADDED
|
@@ -0,0 +1,133 @@
|
| 1 |
+
# direct reward backpropagation
|
| 2 |
+
from diffusion import Diffusion
|
| 3 |
+
from hydra import initialize, compose
|
| 4 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.stats import pearsonr
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import argparse
|
| 10 |
+
import wandb
|
| 11 |
+
import os
|
| 12 |
+
import datetime
|
| 13 |
+
from finetune_peptides import finetune
|
| 14 |
+
from peptide_mcts import MCTS
|
| 15 |
+
from utils.utils import str2bool, set_seed
|
| 16 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 20 |
+
argparser.add_argument('--base_path', type=str, default='')
|
| 21 |
+
argparser.add_argument('--learning_rate', type=float, default=1e-4)
|
| 22 |
+
argparser.add_argument('--num_epochs', type=int, default=100)
|
| 23 |
+
argparser.add_argument('--num_accum_steps', type=int, default=4)
|
| 24 |
+
argparser.add_argument('--truncate_steps', type=int, default=50)
|
| 25 |
+
argparser.add_argument("--truncate_kl", type=str2bool, default=False)
|
| 26 |
+
argparser.add_argument('--gumbel_temp', type=float, default=1.0)
|
| 27 |
+
argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
|
| 28 |
+
argparser.add_argument('--batch_size', type=int, default=32)
|
| 29 |
+
argparser.add_argument('--name', type=str, default='debug')
|
| 30 |
+
argparser.add_argument('--total_num_steps', type=int, default=128)
|
| 31 |
+
argparser.add_argument('--copy_flag_temp', type=float, default=None)
|
| 32 |
+
argparser.add_argument('--save_every_n_epochs', type=int, default=10)
|
| 33 |
+
argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
|
| 34 |
+
argparser.add_argument("--seed", type=int, default=0)
|
| 35 |
+
# new
|
| 36 |
+
argparser.add_argument('--run_name', type=str, default='peptides')
|
| 37 |
+
argparser.add_argument("--device", default="cuda:0", type=str)
|
| 38 |
+
argparser.add_argument("--save_path_dir", default="/scratch/pranamlab/sophtang/home/tr2d2/peptides/checkpoints/", type=str)
|
| 39 |
+
# mcts
|
| 40 |
+
argparser.add_argument('--num_sequences', type=int, default=10)
|
| 41 |
+
argparser.add_argument('--num_children', type=int, default=50)
|
| 42 |
+
argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
|
| 43 |
+
argparser.add_argument('--seq_length', type=int, default=200)
|
| 44 |
+
argparser.add_argument('--time_conditioning', action='store_true', default=False)
|
| 45 |
+
argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
|
| 46 |
+
argparser.add_argument('--buffer_size', type=int, default=100)
|
| 47 |
+
argparser.add_argument('--wdce_num_replicates', type=int, default=16)
|
| 48 |
+
argparser.add_argument('--noise_removal', action='store_true', default=False)
|
| 49 |
+
argparser.add_argument('--grad_clip', action='store_true', default=False)
|
| 50 |
+
argparser.add_argument('--resample_every_n_step', type=int, default=10)
|
| 51 |
+
argparser.add_argument('--exploration', type=float, default=0.1)
|
| 52 |
+
argparser.add_argument('--reset_every_n_step', type=int, default=100)
|
| 53 |
+
argparser.add_argument('--alpha', type=float, default=0.01)
|
| 54 |
+
argparser.add_argument('--scalarization', type=str, default='sum')
|
| 55 |
+
argparser.add_argument('--no_mcts', action='store_true', default=False)
|
| 56 |
+
argparser.add_argument("--centering", action='store_true', default=False)
|
| 57 |
+
|
| 58 |
+
# objectives
|
| 59 |
+
argparser.add_argument('--num_obj', type=int, default=5)
|
| 60 |
+
argparser.add_argument('--prot_seq', type=str, default=None)
|
| 61 |
+
argparser.add_argument('--prot_name', type=str, default=None)
|
| 62 |
+
|
| 63 |
+
args = argparser.parse_args()
|
| 64 |
+
print(args)
|
| 65 |
+
|
| 66 |
+
# pretrained model path
|
| 67 |
+
ckpt_path = f'{args.base_path}/TR2-D2/tr2d2-pep/pretrained/peptune-pretrained.ckpt'
|
| 68 |
+
|
| 69 |
+
# reinitialize Hydra
|
| 70 |
+
GlobalHydra.instance().clear()
|
| 71 |
+
|
| 72 |
+
# Initialize Hydra and compose the configuration
|
| 73 |
+
initialize(config_path="configs", job_name="load_model")
|
| 74 |
+
cfg = compose(config_name="peptune_config.yaml")
|
| 75 |
+
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 76 |
+
|
| 77 |
+
# proteins
|
| 78 |
+
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
|
| 79 |
+
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
|
| 80 |
+
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
|
| 81 |
+
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
|
| 82 |
+
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
|
| 83 |
+
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
|
| 84 |
+
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
|
| 85 |
+
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
|
| 86 |
+
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
|
| 87 |
+
|
| 88 |
+
if args.prot_seq is not None:
|
| 89 |
+
prot = args.prot_seq
|
| 90 |
+
prot_name = args.prot_name
|
| 91 |
+
filename = args.prot_name
|
| 92 |
+
else:
|
| 93 |
+
prot = tfr
|
| 94 |
+
prot_name = "tfr"
|
| 95 |
+
filename = "tfr"
|
| 96 |
+
|
| 97 |
+
if args.no_mcts:
|
| 98 |
+
args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_no-mcts'
|
| 99 |
+
else:
|
| 100 |
+
args.run_name = f'{prot_name}_resample{args.resample_every_n_step}_buffer{args.buffer_size}_numiter{args.num_iter}_children{args.num_children}_{curr_time}'
|
| 101 |
+
|
| 102 |
+
args.save_path = os.path.join(args.save_path_dir, args.run_name)
|
| 103 |
+
os.makedirs(args.save_path, exist_ok=True)
|
| 104 |
+
# wandb init
|
| 105 |
+
wandb.init(project='tree-multi', name=args.run_name, config=args, dir=args.save_path)
|
| 106 |
+
|
| 107 |
+
log_path = os.path.join(args.save_path, 'log.txt')
|
| 108 |
+
|
| 109 |
+
set_seed(args.seed, use_cuda=True)
|
| 110 |
+
|
| 111 |
+
# Initialize the model
|
| 112 |
+
policy_model = Diffusion.load_from_checkpoint(ckpt_path,
|
| 113 |
+
config=cfg,
|
| 114 |
+
mode="train",
|
| 115 |
+
device=args.device,
|
| 116 |
+
map_location=args.device)
|
| 117 |
+
pretrained = Diffusion.load_from_checkpoint(ckpt_path,
|
| 118 |
+
config=cfg,
|
| 119 |
+
mode="eval",
|
| 120 |
+
device=args.device,
|
| 121 |
+
map_location=args.device)
|
| 122 |
+
|
| 123 |
+
# define mcts
|
| 124 |
+
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
|
| 125 |
+
|
| 126 |
+
mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
|
| 127 |
+
|
| 128 |
+
if args.no_mcts:
|
| 129 |
+
reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=args.device)
|
| 130 |
+
finetune(args, cfg, policy_model, reward_model=reward_model, mcts=None, pretrained=pretrained, filename=filename, prot_name=prot_name)
|
| 131 |
+
else:
|
| 132 |
+
mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
|
| 133 |
+
finetune(args, cfg, policy_model, reward_model=mcts.rewardFunc, mcts=mcts, pretrained=None, filename=filename, prot_name=prot_name)
|
tr2d2-pep/finetune.sh
ADDED
|
@@ -0,0 +1,37 @@
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
HOME_LOC=/path/to/your/home
|
| 4 |
+
ENV_PATH=/path/to/your/env
|
| 5 |
+
SCRIPT_LOC=$HOME_LOC/TR2-D2/tr2d2-pep
|
| 6 |
+
LOG_LOC=$HOME_LOC/TR2-D2/tr2d2-pep/logs
|
| 7 |
+
DATE=$(date +%m_%d)
|
| 8 |
+
SPECIAL_PREFIX='tr2d2-finetune-tfr'
|
| 9 |
+
# set 3 have skip connection
|
| 10 |
+
PYTHON_EXECUTABLE=$ENV_PATH/bin/python
|
| 11 |
+
|
| 12 |
+
# ===================================================================
|
| 13 |
+
|
| 14 |
+
source "$(conda info --base)/etc/profile.d/conda.sh"
|
| 15 |
+
conda activate $ENV_PATH
|
| 16 |
+
|
| 17 |
+
# ===================================================================
|
| 18 |
+
|
| 19 |
+
$PYTHON_EXECUTABLE $SCRIPT_LOC/finetune.py \
|
| 20 |
+
--base_path $HOME_LOC \
|
| 21 |
+
--device "cuda:6" \
|
| 22 |
+
--noise_removal \
|
| 23 |
+
--wdce_num_replicates 16 \
|
| 24 |
+
--buffer_size 20 \
|
| 25 |
+
--seq_length 200 \
|
| 26 |
+
--num_children 50 \
|
| 27 |
+
--total_num_steps 128 \
|
| 28 |
+
--num_iter 10 \
|
| 29 |
+
--resample_every_n_step 10 \
|
| 30 |
+
--num_epochs 1000 \
|
| 31 |
+
--exploration 0.1 \
|
| 32 |
+
--save_every_n_epochs 50 \
|
| 33 |
+
--reset_every_n_step 1 \
|
| 34 |
+
--alpha 0.1 \
|
| 35 |
+
--grad_clip > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
|
| 36 |
+
|
| 37 |
+
conda deactivate
|
tr2d2-pep/finetune_peptides.py
ADDED
|
@@ -0,0 +1,193 @@
|
| 1 |
+
# direct reward backpropagation
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import wandb
|
| 5 |
+
import os
|
| 6 |
+
from finetune_utils import loss_wdce
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from plotting import plot_data_with_distribution_seaborn, plot_data
|
| 10 |
+
|
| 11 |
+
def finetune(args, cfg, policy_model, reward_model, mcts=None, pretrained=None, filename=None, prot_name=None, eps=1e-5):
|
| 12 |
+
"""
|
| 13 |
+
Finetuning with WDCE loss
|
| 14 |
+
"""
|
| 15 |
+
base_path = args.base_path
|
| 16 |
+
dt = (1 - eps) / args.total_num_steps
|
| 17 |
+
|
| 18 |
+
if args.no_mcts:
|
| 19 |
+
assert pretrained is not None, "a pretrained model is required when MCTS is disabled"
|
| 20 |
+
else:
|
| 21 |
+
assert mcts is not None, "an MCTS instance is required when MCTS is enabled"
|
| 22 |
+
|
| 23 |
+
# set model to train mode
|
| 24 |
+
policy_model.train()
|
| 25 |
+
torch.set_grad_enabled(True)
|
| 26 |
+
optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)
|
| 27 |
+
|
| 28 |
+
# record metrics
|
| 29 |
+
batch_losses = []
|
| 30 |
+
#batch_rewards = []
|
| 31 |
+
|
| 32 |
+
# initialize the final seqs and log_rnd of the trajectories that generated those seqs
|
| 33 |
+
x_saved, log_rnd_saved, final_rewards_saved = None, None, None
|
| 34 |
+
|
| 35 |
+
valid_fraction_log = []
|
| 36 |
+
affinity_log = []
|
| 37 |
+
sol_log = []
|
| 38 |
+
hemo_log = []
|
| 39 |
+
nf_log = []
|
| 40 |
+
permeability_log = []
|
| 41 |
+
|
| 42 |
+
### Fine-Tuning Loop ###
|
| 43 |
+
pbar = tqdm(range(args.num_epochs))
|
| 44 |
+
|
| 45 |
+
for epoch in pbar:
|
| 46 |
+
# store metrics
|
| 47 |
+
rewards = []
|
| 48 |
+
losses = []
|
| 49 |
+
|
| 50 |
+
policy_model.train()
|
| 51 |
+
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
if x_saved is None or epoch % args.resample_every_n_step == 0:
|
| 54 |
+
# compute final sequences and trajectory log_rnd
|
| 55 |
+
if args.no_mcts:
|
| 56 |
+
x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(args, reward_model, pretrained)
|
| 57 |
+
else:
|
| 58 |
+
# decides whether to reset tree
|
| 59 |
+
if (epoch) % args.reset_every_n_step == 0:
|
| 60 |
+
x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=True)
|
| 61 |
+
else:
|
| 62 |
+
x_final, log_rnd, final_rewards, _, _ = mcts.forward(resetTree=False)
|
| 63 |
+
|
| 64 |
+
# save for next iteration
|
| 65 |
+
x_saved, log_rnd_saved, final_rewards_saved = x_final, log_rnd, final_rewards
|
| 66 |
+
else:
|
| 67 |
+
x_final, log_rnd, final_rewards = x_saved, log_rnd_saved, final_rewards_saved
|
| 68 |
+
|
| 69 |
+
# compute wdce loss
|
| 70 |
+
loss = loss_wdce(policy_model, log_rnd, x_final, num_replicates=args.wdce_num_replicates, centering=args.centering)
|
| 71 |
+
|
| 72 |
+
# gradient descent
|
| 73 |
+
loss.backward()
|
| 74 |
+
|
| 75 |
+
# optimizer
|
| 76 |
+
if args.grad_clip:
|
| 77 |
+
torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
|
| 78 |
+
|
| 79 |
+
optim.step()
|
| 80 |
+
optim.zero_grad()
|
| 81 |
+
|
| 82 |
+
pbar.set_postfix(loss=loss.item())
|
| 83 |
+
|
| 84 |
+
# sample an eval batch with the updated policy to evaluate rewards
|
| 85 |
+
x_eval, affinity, sol, hemo, nf, permeability, valid_fraction = policy_model.sample_finetuned(args, reward_model, batch_size=50, dataframe=False)
|
| 86 |
+
|
| 87 |
+
# append to log
|
| 88 |
+
affinity_log.append(affinity)
|
| 89 |
+
sol_log.append(sol)
|
| 90 |
+
hemo_log.append(hemo)
|
| 91 |
+
nf_log.append(nf)
|
| 92 |
+
permeability_log.append(permeability)
|
| 93 |
+
valid_fraction_log.append(valid_fraction)
|
| 94 |
+
|
| 95 |
+
batch_losses.append(loss.cpu().detach().numpy())
|
| 96 |
+
|
| 97 |
+
losses.append(loss.cpu().detach().numpy())
|
| 98 |
+
losses = np.array(losses)
|
| 99 |
+
|
| 100 |
+
if args.no_mcts:
|
| 101 |
+
mean_reward_search = final_rewards.mean().item()
|
| 102 |
+
min_reward_search = final_rewards.min().item()
|
| 103 |
+
max_reward_search = final_rewards.max().item()
|
| 104 |
+
median_reward_search = final_rewards.median().item()
|
| 105 |
+
else:
|
| 106 |
+
mean_reward_search = np.mean(final_rewards)
|
| 107 |
+
min_reward_search = np.min(final_rewards)
|
| 108 |
+
max_reward_search = np.max(final_rewards)
|
| 109 |
+
median_reward_search = np.median(final_rewards)
|
| 110 |
+
|
| 111 |
+
print("epoch %d"%epoch, "affinity %f"%np.mean(affinity), "sol %f"%np.mean(sol), "hemo %f"%np.mean(hemo), "nf %f"%np.mean(nf), "permeability %f"%np.mean(permeability), "mean loss %f"%np.mean(losses))
|
| 112 |
+
|
| 113 |
+
wandb.log({"epoch": epoch, "affinity": np.mean(affinity), "sol": np.mean(sol), "hemo": np.mean(hemo), "nf": np.mean(nf), "permeability": np.mean(permeability),
|
| 114 |
+
"mean_loss": np.mean(losses),
|
| 115 |
+
"mean_reward_search": mean_reward_search, "min_reward_search": min_reward_search,
|
| 116 |
+
"max_reward_search": max_reward_search, "median_reward_search": median_reward_search})
|
| 117 |
+
|
| 118 |
+
if (epoch+1) % args.save_every_n_epochs == 0:
|
| 119 |
+
model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
|
| 120 |
+
torch.save(policy_model.state_dict(), model_path)
|
| 121 |
+
print(f"model saved at epoch {epoch}")
|
| 122 |
+
|
| 123 |
+
### End of Fine-Tuning Loop ###
|
| 124 |
+
|
| 125 |
+
wandb.finish()
|
| 126 |
+
|
| 127 |
+
# save logs and plot
|
| 128 |
+
plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
|
| 129 |
+
os.makedirs(plot_path, exist_ok=True)
|
| 130 |
+
output_log_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/log_{filename}.csv'
|
| 131 |
+
save_logs_to_file(valid_fraction_log, affinity_log,
|
| 132 |
+
sol_log, hemo_log, nf_log,
|
| 133 |
+
permeability_log, output_log_path)
|
| 134 |
+
|
| 135 |
+
plot_data(valid_fraction_log,
|
| 136 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/valid_{filename}.png')
|
| 137 |
+
|
| 138 |
+
plot_data_with_distribution_seaborn(log1=affinity_log,
|
| 139 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/binding_{filename}.png',
|
| 140 |
+
label1=f"Average Binding Affinity to {prot_name}",
|
| 141 |
+
title=f"Average Binding Affinity to {prot_name} Over Iterations")
|
| 142 |
+
|
| 143 |
+
plot_data_with_distribution_seaborn(log1=sol_log,
|
| 144 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/sol_{filename}.png',
|
| 145 |
+
label1="Average Solubility Score",
|
| 146 |
+
title="Average Solubility Score Over Iterations")
|
| 147 |
+
plot_data_with_distribution_seaborn(log1=hemo_log,
|
| 148 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/hemo_{filename}.png',
|
| 149 |
+
label1="Average Hemolysis Score",
|
| 150 |
+
title="Average Hemolysis Score Over Iterations")
|
| 151 |
+
plot_data_with_distribution_seaborn(log1=nf_log,
|
| 152 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/nf_{filename}.png',
|
| 153 |
+
label1="Average Nonfouling Score",
|
| 154 |
+
title="Average Nonfouling Score Over Iterations")
|
| 155 |
+
plot_data_with_distribution_seaborn(log1=permeability_log,
|
| 156 |
+
save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/perm_{filename}.png',
|
| 157 |
+
label1="Average Permeability Score",
|
| 158 |
+
title="Average Permeability Score Over Iterations")
|
| 159 |
+
|
| 160 |
+
x_eval, affinity, sol, hemo, nf, permeability, valid_fraction, df = policy_model.sample_finetuned(args, reward_model, batch_size=200, dataframe=True)
|
| 161 |
+
df.to_csv(f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/{prot_name}_generation_results.csv', index=False)
|
| 162 |
+
|
| 163 |
+
return batch_losses
|
| 164 |
+
|
| 165 |
+
def save_logs_to_file(valid_fraction_log, affinity_log,
|
| 166 |
+
sol_log, hemo_log, nf_log,
|
| 167 |
+
permeability_log, output_path):
|
| 168 |
+
"""
|
| 169 |
+
Saves the iteration logs (valid fraction, binding affinity, solubility, hemolysis, nonfouling, and permeability) to a CSV file.
|
| 170 |
+
|
| 171 |
+
Parameters:
|
| 172 |
+
valid_fraction_log (list): Log of valid fractions over iterations.
|
| 173 |
+
affinity_log (list): Log of binding affinity over iterations.
|
| 174 |
+
sol_log, hemo_log, nf_log, permeability_log (lists): Logs of the solubility, hemolysis, nonfouling, and membrane permeability scores over iterations.
|
| 175 |
+
output_path (str): Path to save the log CSV file.
|
| 176 |
+
"""
|
| 177 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 178 |
+
|
| 179 |
+
# Combine logs into a DataFrame
|
| 180 |
+
log_data = {
|
| 181 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 182 |
+
"Valid Fraction": valid_fraction_log,
|
| 183 |
+
"Binding Affinity": affinity_log,
|
| 184 |
+
"Solubility": sol_log,
|
| 185 |
+
"Hemolysis": hemo_log,
|
| 186 |
+
"Nonfouling": nf_log,
|
| 187 |
+
"Permeability": permeability_log
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
df = pd.DataFrame(log_data)
|
| 191 |
+
|
| 192 |
+
# Save to CSV
|
| 193 |
+
df.to_csv(output_path, index=False)
|
tr2d2-pep/finetune_utils.py
ADDED
|
@@ -0,0 +1,138 @@
|
| 1 |
+
import random
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 4 |
+
from utils.utils import sample_categorical_logits
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
def to_one_hot(x_idx, num_classes=4):
|
| 11 |
+
oh = F.one_hot(x_idx.long(), num_classes=num_classes)
|
| 12 |
+
return oh.float()
|
| 13 |
+
|
| 14 |
+
def rnd(model, reward_model, batch_size, scale=1, device='cuda:0'):
|
| 15 |
+
r"""
|
| 16 |
+
Run random order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory
|
| 17 |
+
reward_model: r(X)
|
| 18 |
+
|
| 19 |
+
return:
|
| 20 |
+
- x: the final samples, [B, D]
|
| 21 |
+
- log_rnd: the log RND along this trajectory, [B]
|
| 22 |
+
"""
|
| 23 |
+
if hasattr(model, 'module'):
|
| 24 |
+
model = model.module
|
| 25 |
+
|
| 26 |
+
x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
|
| 27 |
+
batch_arange = torch.arange(batch_size, device=device)
|
| 28 |
+
jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
|
| 29 |
+
# jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
|
| 30 |
+
# jump_times: Unif[0,1] in increasing order
|
| 31 |
+
# jump_pos: random permutation of range(D)
|
| 32 |
+
log_rnd = torch.zeros(batch_size, device=device) # [B]
|
| 33 |
+
for d in range(model.length-1, -1, -1):
|
| 34 |
+
# jump at time jump_times[:, d] at position jump_pos[:, d]
|
| 35 |
+
logits = model(x)[:, :, :-1] # [B, D, N-1]
|
| 36 |
+
update = sample_categorical_logits(
|
| 37 |
+
logits[batch_arange, jump_pos[:, d]]) # [B]
|
| 38 |
+
if torch.is_grad_enabled(): # avoid issues with in-place operations
|
| 39 |
+
x = x.clone()
|
| 40 |
+
x[batch_arange, jump_pos[:, d]] = update
|
| 41 |
+
log_rnd += -np.log(model.vocab_size-1) - logits[batch_arange, jump_pos[:, d], update]
|
| 42 |
+
log_rnd += scale * reward_model(x) # [B]
|
| 43 |
+
return x, log_rnd
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@torch.no_grad()
|
| 47 |
+
def sampling(model, batch_size, rounds=1, device='cuda:0'):
|
| 48 |
+
"""Any order autoregressive sampling"""
|
| 49 |
+
if hasattr(model, 'module'):
|
| 50 |
+
model = model.module
|
| 51 |
+
batch_arange = torch.arange(batch_size, device=device)
|
| 52 |
+
all_samples = []
|
| 53 |
+
for _ in tqdm(range(rounds), leave=False):
|
| 54 |
+
x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
|
| 55 |
+
jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
|
| 56 |
+
# jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
|
| 57 |
+
# jump_times: Unif[0,1] in increasing order
|
| 58 |
+
# jump_pos: random permutation of range(D)
|
| 59 |
+
for d in tqdm(range(model.length-1, -1, -1), leave=False):
|
| 60 |
+
# jump at time jump_times[:, d] at position jump_pos[:, d]
|
| 61 |
+
logits = model.logits(x)[:, :, :-1] # [B, D, N-1], not log-softmaxed but fine
|
| 62 |
+
update = sample_categorical_logits(
|
| 63 |
+
logits[batch_arange, jump_pos[:, d]]) # [B]
|
| 64 |
+
x[batch_arange, jump_pos[:, d]] = update
|
| 65 |
+
all_samples.append(x)
|
| 66 |
+
return torch.cat(all_samples) # (rounds * B, L)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def loss_ce(log_rnd):
|
| 70 |
+
"""Cross entropy loss KL(P^*||P^u)"""
|
| 71 |
+
weights = log_rnd.detach().softmax(dim=-1)
|
| 72 |
+
return (log_rnd * weights).sum()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def loss_lv(log_rnd):
|
| 76 |
+
r"""Log variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}"""
|
| 77 |
+
return log_rnd.var()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def loss_re_rf(log_rnd, const=0):
|
| 81 |
+
r"""Relative entropy loss KL(P^u||P^*) with REINFORCE trick"""
|
| 82 |
+
return (-log_rnd * (-log_rnd.detach() + const)).mean()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def loss_wdce(policy_model, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
|
| 86 |
+
r"""
|
| 87 |
+
Weighted denoising cross entropy loss
|
| 88 |
+
X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
|
| 89 |
+
|
| 90 |
+
log_rnd: [B]; x: [B, L] (no mask)
|
| 91 |
+
num_replicates: R, number of replicates of each row in x
|
| 92 |
+
weight_func: w(lambda) for each sample, 1/lambda by default
|
| 93 |
+
"""
|
| 94 |
+
mask_index = policy_model.mask_index
|
| 95 |
+
if hasattr(policy_model, 'module'):
|
| 96 |
+
policy_model = policy_model.module
|
| 97 |
+
|
| 98 |
+
batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
|
| 99 |
+
|
| 100 |
+
batch_weights = log_rnd.detach_().softmax(dim=-1) # [B]; detaches log_rnd in place
|
| 101 |
+
if centering:
|
| 102 |
+
batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
|
| 103 |
+
|
| 104 |
+
batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)
|
| 105 |
+
|
| 106 |
+
lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
|
| 107 |
+
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
|
| 108 |
+
|
| 109 |
+
masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None] # [B*R, D]
|
| 110 |
+
perturbed_batch = torch.where(masked_index, mask_index, batch)
|
| 111 |
+
|
| 112 |
+
# add time conditioning
|
| 113 |
+
t = lamda
|
| 114 |
+
sigma_t = -torch.log1p(-(1 - eps) * t)
|
| 115 |
+
attn_mask = torch.ones_like(perturbed_batch).to(policy_model.device)
|
| 116 |
+
|
| 117 |
+
# compute logits
|
| 118 |
+
logits = policy_model(perturbed_batch, attn_mask=attn_mask, sigma=sigma_t)
|
| 119 |
+
losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype) # [B*R, D]
|
| 120 |
+
losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
|
| 121 |
+
index=batch[masked_index][..., None]).squeeze(-1)
|
| 122 |
+
return - (losses.sum(dim=-1) * lamda_weights * batch_weights).mean()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def loss_dce(model, x, weight_func=lambda l: 1/l):
|
| 126 |
+
r"""
|
| 127 |
+
Denoising cross entropy loss, x [B, D] are ground truth samples
|
| 128 |
+
weight_func: w(lambda) for each sample, 1/lambda by default
|
| 129 |
+
"""
|
| 130 |
+
lamda = torch.rand(x.shape[0], device=x.device) # [B]
|
| 131 |
+
lamda_weights = weight_func(lamda).clamp(max=1e5) # [B]
|
| 132 |
+
masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None] # [B, D]
|
| 133 |
+
perturbed_batch = torch.where(masked_index, model.vocab_size-1, x)
|
| 134 |
+
logits = model(perturbed_batch)
|
| 135 |
+
losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype) # [B, D]
|
| 136 |
+
losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
|
| 137 |
+
index=x[masked_index][..., None]).squeeze(-1)
|
| 138 |
+
return - (losses.sum(dim=-1) * lamda_weights).mean()
|
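
As a quick illustration of the loss helpers above, the snippet below (a sketch, assuming `finetune_utils` is importable from the working directory) evaluates the cross-entropy, log-variance, and REINFORCE-style losses on a toy batch of log Radon-Nikodym derivatives; each takes a `[B]` tensor and returns a scalar.

```python
import torch
from finetune_utils import loss_ce, loss_lv, loss_re_rf  # assumes this module is on PYTHONPATH

# Toy log dP*/dP^u values, one per sampled trajectory.
log_rnd = torch.tensor([0.2, -1.3, 0.7, 0.0], requires_grad=True)

print(loss_ce(log_rnd).item())     # softmax-weighted KL(P*||P^u) surrogate
print(loss_lv(log_rnd).item())     # log-variance loss
print(loss_re_rf(log_rnd).item())  # KL(P^u||P*) with the REINFORCE trick
```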
tr2d2-pep/generate_mcts.py
ADDED
|
@@ -0,0 +1,192 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
from diffusion import Diffusion
|
| 8 |
+
import hydra
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import os
|
| 12 |
+
import seaborn as sns
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
import argparse
|
| 16 |
+
# direct reward backpropagation
|
| 17 |
+
from diffusion import Diffusion
|
| 18 |
+
from hydra import initialize, compose
|
| 19 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import argparse
|
| 23 |
+
import os
|
| 24 |
+
import datetime
|
| 25 |
+
from utils.utils import str2bool, set_seed
|
| 26 |
+
|
| 27 |
+
# for peptides
|
| 28 |
+
from utils.app import PeptideAnalyzer
|
| 29 |
+
from peptide_mcts import MCTS
|
| 30 |
+
|
| 31 |
+
@torch.no_grad()
|
| 32 |
+
def generate_mcts(args, cfg, policy_model, pretrained, prot=None, prot_name=None, filename=None):
|
| 33 |
+
|
| 34 |
+
score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
|
| 35 |
+
|
| 36 |
+
mcts = MCTS(args, cfg, policy_model, pretrained, score_func_names, prot_seqs=[prot])
|
| 37 |
+
|
| 38 |
+
final_x, log_rnd, final_rewards, score_vectors, sequences = mcts.forward()
|
| 39 |
+
|
| 40 |
+
return final_x, log_rnd, final_rewards, score_vectors, sequences
|
| 41 |
+
|
| 42 |
+
def save_logs_to_file(reward_log, logrnd_log,
|
| 43 |
+
valid_fraction_log, affinity1_log,
|
| 44 |
+
sol_log, hemo_log, nf_log,
|
| 45 |
+
permeability_log, output_path):
|
| 46 |
+
"""
|
| 47 |
+
Saves the iteration logs (reward, log RND, valid fraction, binding affinity, solubility, hemolysis, nonfouling, and permeability) to a CSV file.
|
| 48 |
+
|
| 49 |
+
Parameters:
|
| 50 |
+
valid_fraction_log (list): Log of valid fractions over iterations.
|
| 51 |
+
affinity1_log (list): Log of binding affinity over iterations.
|
| 52 |
+
permeability_log (list): Log of membrane permeability over iterations.
|
| 53 |
+
output_path (str): Path to save the log CSV file.
|
| 54 |
+
"""
|
| 55 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 56 |
+
|
| 57 |
+
# Combine logs into a DataFrame
|
| 58 |
+
log_data = {
|
| 59 |
+
"Iteration": list(range(1, len(valid_fraction_log) + 1)),
|
| 60 |
+
"Reward": reward_log,
|
| 61 |
+
"Log RND": logrnd_log,
|
| 62 |
+
"Valid Fraction": valid_fraction_log,
|
| 63 |
+
"Binding Affinity": affinity1_log,
|
| 64 |
+
"Solubility": sol_log,
|
| 65 |
+
"Hemolysis": hemo_log,
|
| 66 |
+
"Nonfouling": nf_log,
|
| 67 |
+
"Permeability": permeability_log
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
df = pd.DataFrame(log_data)
|
| 71 |
+
|
| 72 |
+
# Save to CSV
|
| 73 |
+
df.to_csv(output_path, index=False)
|
| 74 |
+
|
| 75 |
+
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 76 |
+
argparser.add_argument('--base_path', type=str, default='')
|
| 77 |
+
argparser.add_argument('--learning_rate', type=float, default=1e-4)
|
| 78 |
+
argparser.add_argument('--num_epochs', type=int, default=1000)
|
| 79 |
+
argparser.add_argument('--num_accum_steps', type=int, default=4)
|
| 80 |
+
argparser.add_argument('--truncate_steps', type=int, default=50)
|
| 81 |
+
argparser.add_argument("--truncate_kl", type=str2bool, default=False)
|
| 82 |
+
argparser.add_argument('--gumbel_temp', type=float, default=1.0)
|
| 83 |
+
argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
|
| 84 |
+
argparser.add_argument('--batch_size', type=int, default=32)
|
| 85 |
+
argparser.add_argument('--name', type=str, default='debug')
|
| 86 |
+
argparser.add_argument('--total_num_steps', type=int, default=128)
|
| 87 |
+
argparser.add_argument('--copy_flag_temp', type=float, default=None)
|
| 88 |
+
argparser.add_argument('--save_every_n_epochs', type=int, default=50)
|
| 89 |
+
argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
|
| 90 |
+
argparser.add_argument("--seed", type=int, default=0)
|
| 91 |
+
# new
|
| 92 |
+
argparser.add_argument('--run_name', type=str, default='drakes')
|
| 93 |
+
argparser.add_argument("--device", default="cuda", type=str)
|
| 94 |
+
|
| 95 |
+
# mcts
|
| 96 |
+
argparser.add_argument('--num_sequences', type=int, default=100)
|
| 97 |
+
argparser.add_argument('--num_children', type=int, default=20)
|
| 98 |
+
argparser.add_argument('--num_iter', type=int, default=100) # iterations of mcts
|
| 99 |
+
argparser.add_argument('--seq_length', type=int, default=200)
|
| 100 |
+
argparser.add_argument('--time_conditioning', action='store_true', default=False)
|
| 101 |
+
argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
|
| 102 |
+
argparser.add_argument('--buffer_size', type=int, default=100)
|
| 103 |
+
argparser.add_argument('--wdce_num_replicates', type=int, default=16)
|
| 104 |
+
argparser.add_argument('--noise_removal', action='store_true', default=False)
|
| 105 |
+
argparser.add_argument('--exploration', type=float, default=0.1)
|
| 106 |
+
argparser.add_argument('--reset_every_n_step', type=int, default=100)
|
| 107 |
+
argparser.add_argument('--alpha', type=float, default=0.01)
|
| 108 |
+
argparser.add_argument('--scalarization', type=str, default='sum')
|
| 109 |
+
argparser.add_argument('--no_mcts', action='store_true', default=False)
|
| 110 |
+
argparser.add_argument("--centering", action='store_true', default=False)
|
| 111 |
+
argparser.add_argument('--num_obj', type=int, default=5)
|
| 112 |
+
|
| 113 |
+
argparser.add_argument('--prot_seq', type=str, default=None)
|
| 114 |
+
argparser.add_argument('--prot_name', type=str, default=None)
|
| 115 |
+
|
| 116 |
+
args = argparser.parse_args()
|
| 117 |
+
print(args)
|
| 118 |
+
|
| 119 |
+
# pretrained model path
|
| 120 |
+
ckpt_path = f'{args.base_path}/TR2-D2/tr2d2-pep/pretrained/peptune-pretrained.ckpt'
|
| 121 |
+
|
| 122 |
+
# reinitialize Hydra
|
| 123 |
+
GlobalHydra.instance().clear()
|
| 124 |
+
|
| 125 |
+
# Initialize Hydra and compose the configuration
|
| 126 |
+
initialize(config_path="configs", job_name="load_model")
|
| 127 |
+
cfg = compose(config_name="config.yaml")
|
| 128 |
+
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 129 |
+
|
| 130 |
+
set_seed(args.seed, use_cuda=True)
|
| 131 |
+
|
| 132 |
+
# proteins
|
| 133 |
+
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
|
| 134 |
+
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
|
| 135 |
+
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
|
| 136 |
+
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
|
| 137 |
+
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
|
| 138 |
+
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
|
| 139 |
+
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
|
| 140 |
+
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
|
| 141 |
+
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'
|
| 142 |
+
|
| 143 |
+
if args.prot_seq is not None:
|
| 144 |
+
prot = args.prot_seq
|
| 145 |
+
prot_name = args.prot_name
|
| 146 |
+
filename = args.prot_name
|
| 147 |
+
else:
|
| 148 |
+
prot = tfr
|
| 149 |
+
prot_name = "tfr"
|
| 150 |
+
filename = "tfr"
|
| 151 |
+
|
| 152 |
+
# Initialize the model
|
| 153 |
+
new_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)
|
| 154 |
+
old_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)
|
| 155 |
+
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
final_x, log_rnd, final_rewards, score_vectors, sequences = generate_mcts(args, cfg, new_model, old_model,
|
| 158 |
+
prot=prot, prot_name=prot_name)
|
| 159 |
+
|
| 160 |
+
final_x = final_x.detach().to('cpu') # [B, L] integer tokens
|
| 161 |
+
log_rnd = log_rnd.detach().to('cpu').float().view(-1) # [B]
|
| 162 |
+
#final_rewards = final_rewards.detach().to('cpu').float().view(-1) # [B]
|
| 163 |
+
|
| 164 |
+
print("loaded models...")
|
| 165 |
+
analyzer = PeptideAnalyzer()
|
| 166 |
+
|
| 167 |
+
generation_results = []
|
| 168 |
+
|
| 169 |
+
for i in range(final_x.shape[0]):
|
| 170 |
+
sequence = sequences[i]
|
| 171 |
+
log_rnd_single = log_rnd[i]
|
| 172 |
+
final_reward = final_rewards[i]
|
| 173 |
+
|
| 174 |
+
aa_seq, seq_length = analyzer.analyze_structure(sequence)
|
| 175 |
+
|
| 176 |
+
scores = score_vectors[i]
|
| 177 |
+
|
| 178 |
+
binding1 = scores[0]
|
| 179 |
+
solubility = scores[1]
|
| 180 |
+
hemo = scores[2]
|
| 181 |
+
nonfouling = scores[3]
|
| 182 |
+
permeability = scores[4]
|
| 183 |
+
|
| 184 |
+
generation_results.append([sequence, aa_seq, final_reward, log_rnd_single, binding1, solubility, hemo, nonfouling, permeability])
|
| 185 |
+
print(f"length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")
|
| 186 |
+
|
| 187 |
+
sys.stdout.flush()
|
| 188 |
+
|
| 189 |
+
df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Peptide Sequence', 'Final Reward', 'Log RND', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
df.to_csv(f'{args.base_path}/TR2-D2/tr2d2-pep/plots/{prot_name}-peptune-baseline/generation_results.csv', index=False)
|
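
After generation, the per-sequence results can be inspected directly from the CSV written above. A small sketch (the path is a placeholder; the column names match the DataFrame built in this script):

```python
import pandas as pd

# Placeholder path; generate_mcts.py writes to
# {base_path}/TR2-D2/tr2d2-pep/plots/{prot_name}-peptune-baseline/generation_results.csv
df = pd.read_csv("plots/tfr-peptune-baseline/generation_results.csv")

# Rank the generated peptides by their scalarized reward.
top = df.sort_values("Final Reward", ascending=False).head(10)
print(top[["Peptide Sequence", "Binding Affinity", "Solubility", "Permeability"]])
```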
tr2d2-pep/metrics.py
ADDED
|
@@ -0,0 +1,71 @@
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from math import sqrt
|
| 3 |
+
|
| 4 |
+
def summarize_metrics(skip, csv_path: str, save_path: str | None = None) -> pd.DataFrame:
|
| 5 |
+
"""
|
| 6 |
+
Compute mean and standard deviation for all columns except the first `skip`
|
| 7 |
+
columns (assumed to be non-numeric identifiers such as 'Peptide Sequence').
|
| 8 |
+
|
| 9 |
+
Returns a DataFrame with rows = column names and columns = ['mean','std','count'].
|
| 10 |
+
Uses sample std (ddof=1). Non-numeric cells are coerced to NaN.
|
| 11 |
+
"""
|
| 12 |
+
df = pd.read_csv(csv_path)
|
| 13 |
+
vals = df.iloc[:, skip:].apply(pd.to_numeric, errors='coerce') # columns skip..end
|
| 14 |
+
stats = vals.agg(['mean', 'std', 'count']).T # shape: (num_metrics, 3)
|
| 15 |
+
if save_path:
|
| 16 |
+
stats.to_csv(save_path, index=True)
|
| 17 |
+
return stats
|
| 18 |
+
|
| 19 |
+
def summarize_list(xs, ddof = 1):
|
| 20 |
+
# Clean & coerce to float
|
| 21 |
+
vals = []
|
| 22 |
+
for x in xs:
|
| 23 |
+
if x is None or x == "":
|
| 24 |
+
continue
|
| 25 |
+
try:
|
| 26 |
+
vals.append(float(x))
|
| 27 |
+
except (TypeError, ValueError):
|
| 28 |
+
continue
|
| 29 |
+
|
| 30 |
+
n = len(vals)
|
| 31 |
+
if n == 0:
|
| 32 |
+
raise ValueError("No numeric values found.")
|
| 33 |
+
if n <= ddof:
|
| 34 |
+
raise ValueError(f"Need at least {ddof + 1} numeric values; got {n}.")
|
| 35 |
+
|
| 36 |
+
# Welford’s algorithm (one pass, stable)
|
| 37 |
+
mean = 0.0
|
| 38 |
+
M2 = 0.0
|
| 39 |
+
count = 0
|
| 40 |
+
for v in vals:
|
| 41 |
+
count += 1
|
| 42 |
+
delta = v - mean
|
| 43 |
+
mean += delta / count
|
| 44 |
+
M2 += delta * (v - mean)
|
| 45 |
+
|
| 46 |
+
var = M2 / (count - ddof)
|
| 47 |
+
std = sqrt(var)
|
| 48 |
+
|
| 49 |
+
result = {"mean": mean, "std": std, "count": count}
|
| 50 |
+
|
| 51 |
+
return result
|
| 52 |
+
|
| 53 |
+
def csv_column_to_list(path: str, column: str, *, dropna: bool = True):
|
| 54 |
+
df = pd.read_csv(path)
|
| 55 |
+
if column not in df.columns:
|
| 56 |
+
raise KeyError(f"Column '{column}' not found. Available: {list(df.columns)}")
|
| 57 |
+
s = df[column]
|
| 58 |
+
if dropna:
|
| 59 |
+
s = s.dropna()
|
| 60 |
+
return s.tolist()
|
| 61 |
+
|
| 62 |
+
def main():
|
| 63 |
+
path = ""
|
| 64 |
+
prot_name = ""
|
| 65 |
+
stats = summarize_metrics(skip=1, csv_path=f"{path}/{prot_name}_generation_results.csv",
|
| 66 |
+
save_path=f"{path}/results_summary.csv")
|
| 67 |
+
|
| 68 |
+
print(stats)
|
| 69 |
+
|
| 70 |
+
if __name__ == '__main__':
|
| 71 |
+
main()
|
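
A short usage sketch for the helpers above (paths are placeholders; `skip` controls how many leading identifier columns are excluded before averaging):

```python
# Summarize every numeric column of a generation-results CSV.
stats = summarize_metrics(skip=1,
                          csv_path="results/example/tfr_generation_results.csv",
                          save_path="results/example/results_summary.csv")
print(stats)

# Or summarize a single column extracted as a list.
rewards = csv_column_to_list("results/example/tfr_generation_results.csv", "Final Reward")
print(summarize_list(rewards))  # {'mean': ..., 'std': ..., 'count': ...}
```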
tr2d2-pep/noise_schedule.py
ADDED
|
@@ -0,0 +1,150 @@
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
torch._C._jit_set_profiling_mode(False)
|
| 7 |
+
torch._C._jit_set_profiling_executor(False)
|
| 8 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
| 9 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 10 |
+
|
| 11 |
+
def get_noise(config, dtype=torch.float32):
|
| 12 |
+
if config.noise.type == 'geometric':
|
| 13 |
+
return GeometricNoise(config.noise.sigma_min, config.noise.sigma_max)
|
| 14 |
+
elif config.noise.type == 'loglinear':
|
| 15 |
+
return LogLinearNoise()
|
| 16 |
+
elif config.noise.type == 'cosine':
|
| 17 |
+
return CosineNoise()
|
| 18 |
+
elif config.noise.type == 'cosinesqr':
|
| 19 |
+
return CosineSqrNoise()
|
| 20 |
+
elif config.noise.type == 'linear':
|
| 21 |
+
return Linear(config.noise.sigma_min, config.noise.sigma_max, dtype)
|
| 22 |
+
else:
|
| 23 |
+
raise ValueError(f'{config.noise.type} is not a valid noise')
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def binary_discretization(z):
|
| 27 |
+
z_hard = torch.sign(z)
|
| 28 |
+
z_soft = z / torch.norm(z, dim=-1, keepdim=True)
|
| 29 |
+
return z_soft + (z_hard - z_soft).detach()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Noise(abc.ABC, nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Baseline forward method to get the total + rate of noise at a timestep
|
| 35 |
+
"""
|
| 36 |
+
def forward(self, t):
|
| 37 |
+
# Assume time goes from 0 to 1
|
| 38 |
+
return self.total_noise(t), self.rate_noise(t)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class CosineNoise(Noise):
|
| 42 |
+
def __init__(self, eps=1e-3):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.eps = eps
|
| 45 |
+
|
| 46 |
+
def rate_noise(self, t):
|
| 47 |
+
cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
|
| 48 |
+
sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
|
| 49 |
+
scale = torch.pi / 2
|
| 50 |
+
return scale * sin / (cos + self.eps)
|
| 51 |
+
|
| 52 |
+
def total_noise(self, t):
|
| 53 |
+
cos = torch.cos(t * torch.pi / 2)
|
| 54 |
+
return - torch.log(self.eps + (1 - self.eps) * cos)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CosineSqrNoise(Noise):
|
| 58 |
+
def __init__(self, eps=1e-3):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.eps = eps
|
| 61 |
+
|
| 62 |
+
def rate_noise(self, t):
|
| 63 |
+
cos = (1 - self.eps) * (
|
| 64 |
+
torch.cos(t * torch.pi / 2) ** 2)
|
| 65 |
+
sin = (1 - self.eps) * torch.sin(t * torch.pi)
|
| 66 |
+
scale = torch.pi / 2
|
| 67 |
+
return scale * sin / (cos + self.eps)
|
| 68 |
+
|
| 69 |
+
def total_noise(self, t):
|
| 70 |
+
cos = torch.cos(t * torch.pi / 2) ** 2
|
| 71 |
+
return - torch.log(self.eps + (1 - self.eps) * cos)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Linear(Noise):
|
| 75 |
+
def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
|
| 78 |
+
self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
|
| 79 |
+
|
| 80 |
+
def rate_noise(self, t):
|
| 81 |
+
return self.sigma_max - self.sigma_min
|
| 82 |
+
|
| 83 |
+
def total_noise(self, t):
|
| 84 |
+
return self.sigma_min + t * (self.sigma_max - self.sigma_min)
|
| 85 |
+
|
| 86 |
+
def importance_sampling_transformation(self, t):
|
| 87 |
+
f_T = torch.log1p(- torch.exp(- self.sigma_max))
|
| 88 |
+
f_0 = torch.log1p(- torch.exp(- self.sigma_min))
|
| 89 |
+
sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
|
| 90 |
+
return (sigma_t - self.sigma_min) / (
|
| 91 |
+
self.sigma_max - self.sigma_min)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class GeometricNoise(Noise):
|
| 95 |
+
def __init__(self, sigma_min=1e-3, sigma_max=1):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
|
| 98 |
+
|
| 99 |
+
def rate_noise(self, t):
|
| 100 |
+
return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
|
| 101 |
+
self.sigmas[1].log() - self.sigmas[0].log())
|
| 102 |
+
|
| 103 |
+
def total_noise(self, t):
|
| 104 |
+
return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class LogLinearNoise(Noise):
|
| 108 |
+
"""Log Linear noise schedule.
|
| 109 |
+
|
| 110 |
+
Built such that 1 - 1/e^(n(t)) interpolates between 0 and
|
| 111 |
+
~1 when t varies from 0 to 1. Total noise is
|
| 112 |
+
-log(1 - (1 - eps) * t), so the sigma will be
|
| 113 |
+
(1 - eps) * t.
|
| 114 |
+
"""
|
| 115 |
+
def __init__(self, eps=1e-3):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.eps = eps
|
| 118 |
+
self.sigma_max = self.total_noise(torch.tensor(1.0))
|
| 119 |
+
self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
|
| 120 |
+
|
| 121 |
+
def rate_noise(self, t):
|
| 122 |
+
return (1 - self.eps) / (1 - (1 - self.eps) * t)
|
| 123 |
+
|
| 124 |
+
def total_noise(self, t):
|
| 125 |
+
return -torch.log1p(-(1 - self.eps) * t)
|
| 126 |
+
|
| 127 |
+
def importance_sampling_transformation(self, t):
|
| 128 |
+
f_T = torch.log1p(- torch.exp(- self.sigma_max))
|
| 129 |
+
f_0 = torch.log1p(- torch.exp(- self.sigma_min))
|
| 130 |
+
sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
|
| 131 |
+
t = - torch.expm1(- sigma_t) / (1 - self.eps)
|
| 132 |
+
return t
|
| 133 |
+
|
| 134 |
+
class LogPolyNoise(Noise):
|
| 135 |
+
"""
|
| 136 |
+
Log Polynomial noise schedule for slower masking of peptide bond tokens
|
| 137 |
+
"""
|
| 138 |
+
def __init__(self, eps=1e-3):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.eps = eps
|
| 141 |
+
self.sigma_max = self.total_noise(torch.tensor(1.0))
|
| 142 |
+
self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
|
| 143 |
+
|
| 144 |
+
def rate_noise(self, t):
|
| 145 |
+
# approximate derivative of -log1p(-(1 - eps) * t**3), i.e. w = 3
|
| 146 |
+
return ((3 * (t**2)) - self.eps) / (1 - (1 - self.eps) * (t**3))
|
| 147 |
+
|
| 148 |
+
def total_noise(self, t):
|
| 149 |
+
# -log1p(-(1 - eps) * t**3), i.e. -log(1 - t^w) with w = 3
|
| 150 |
+
return -torch.log1p(-(1 - self.eps) * (t**3))
|
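
To make the schedules concrete, here is a small sketch (assuming this module is importable) that evaluates the log-linear schedule at a few timesteps; `1 - exp(-sigma(t))` is the masking fraction the docstring refers to.

```python
import torch
from noise_schedule import LogLinearNoise  # assumes this module is on PYTHONPATH

noise = LogLinearNoise(eps=1e-3)
t = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

sigma, dsigma = noise(t)            # total noise sigma(t) and its rate at each t
mask_prob = 1 - torch.exp(-sigma)   # expected fraction of masked tokens, (1 - eps) * t
print(sigma, dsigma, mask_prob)
```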
tr2d2-pep/peptide_mcts.py
ADDED
|
@@ -0,0 +1,648 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random as rd
|
| 6 |
+
from utils.app import PeptideAnalyzer
|
| 7 |
+
from utils.timer import StepTimer
|
| 8 |
+
from scoring.scoring_functions import ScoringFunctions
|
| 9 |
+
|
| 10 |
+
import noise_schedule
|
| 11 |
+
|
| 12 |
+
### for peptide multi-objective ###
|
| 13 |
+
def dominates(a, b):
|
| 14 |
+
a = np.asarray(a); b = np.asarray(b)
|
| 15 |
+
return np.all(a >= b) and np.any(a > b)
|
| 16 |
+
|
| 17 |
+
def dominated_by(a, b):
|
| 18 |
+
return dominates(b, a)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def updateParetoFront(paretoFront, node, scoreVector, totalSize=None, eps=1e-12):
|
| 22 |
+
"""
|
| 23 |
+
Maintain a non-dominated set (Pareto front) of (node -> scoreVector).
|
| 24 |
+
|
| 25 |
+
- Accept 'node' iff it is NOT dominated by any node in the set.
|
| 26 |
+
- Remove any nodes that ARE dominated by 'node'.
|
| 27 |
+
- Skip insertion if an equal point already exists (within eps).
|
| 28 |
+
- If totalSize is given and the archive exceeds it, drop the item
|
| 29 |
+
with the smallest sum(scoreVector) as a simple tie-breaker.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
paretoFront (dict): {node: scoreVector}
|
| 33 |
+
node: candidate node (used as dict key)
|
| 34 |
+
scoreVector (array-like): candidate scores (to be maximized)
|
| 35 |
+
totalSize (int|None): optional max size for the archive
|
| 36 |
+
eps (float): tolerance for equality/inequality checks
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
dict: updated paretoFront
|
| 40 |
+
"""
|
| 41 |
+
s = np.asarray(scoreVector, dtype=float)
|
| 42 |
+
|
| 43 |
+
def dominates(a, b):
|
| 44 |
+
# a >= b in all coords and > in at least one (with tolerance)
|
| 45 |
+
return np.all(a >= b - eps) and np.any(a > b + eps)
|
| 46 |
+
|
| 47 |
+
def equal(a, b):
|
| 48 |
+
return np.all(np.abs(a - b) <= eps)
|
| 49 |
+
|
| 50 |
+
# reject if candidate is dominated by any node already in the set
|
| 51 |
+
for v in paretoFront.values():
|
| 52 |
+
v = np.asarray(v, dtype=float)
|
| 53 |
+
if dominates(v, s):
|
| 54 |
+
return paretoFront # no change
|
| 55 |
+
|
| 56 |
+
# remove any nodes dominated by candidate node
|
| 57 |
+
survivors = {}
|
| 58 |
+
#has_equal = False
|
| 59 |
+
for k, v in paretoFront.items():
|
| 60 |
+
v_arr = np.asarray(v, dtype=float)
|
| 61 |
+
if dominates(s, v_arr):
|
| 62 |
+
continue # drop dominated incumbent
|
| 63 |
+
"""if equal(s, v_arr):
|
| 64 |
+
has_equal = True # skip duplicate insertion later"""
|
| 65 |
+
survivors[k] = v_arr
|
| 66 |
+
|
| 67 |
+
# if an equal point exists, keep survivors as-is (no duplicate)
|
| 68 |
+
"""if has_equal:
|
| 69 |
+
return survivors"""
|
| 70 |
+
|
| 71 |
+
# insert node
|
| 72 |
+
survivors[node] = s
|
| 73 |
+
|
| 74 |
+
# delete nodes if larger than total size
|
| 75 |
+
if totalSize is not None and totalSize > 0 and len(survivors) > totalSize:
|
| 76 |
+
# remove the item with the smallest sum(scoreVector)
|
| 77 |
+
keys = list(survivors.keys())
|
| 78 |
+
sums = np.array([np.sum(np.asarray(survivors[k], dtype=float)) for k in keys])
|
| 79 |
+
drop_idx = int(np.argmin(sums))
|
| 80 |
+
del survivors[keys[drop_idx]]
|
| 81 |
+
|
| 82 |
+
return survivors
|
| 83 |
+
|
| 84 |
+
### BEGINNING OF NODE CLASS ###
|
| 85 |
+
|
| 86 |
+
class Node:
|
| 87 |
+
"""
|
| 88 |
+
Node class: partially unmasked sequence
|
| 89 |
+
- parentNode: Node object at previous time step
|
| 90 |
+
- childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
|
| 91 |
+
- totalReward: vector of cumulative rewards for all K objectives
|
| 92 |
+
- visits: number of times the node has been visited by an iteration
|
| 93 |
+
- path: array of partially unmasked SMILES strings leading to the node from the completely masked root node
|
| 94 |
+
- timestep: the time step where the sequence was sampled
|
| 95 |
+
"""
|
| 96 |
+
def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
|
| 97 |
+
self.args = args
|
| 98 |
+
self.parentNode = parentNode
|
| 99 |
+
# fixed child node list creation
|
| 100 |
+
self.childNodes = [] if childNodes is None else childNodes
|
| 101 |
+
|
| 102 |
+
self.log_rnd = log_rnd # stores the log_rnd up to that step
|
| 103 |
+
|
| 104 |
+
#self.log_p0 = 0 # stores the log probability of the unmasking step from the previous iteration
|
| 105 |
+
self.log_policy_step = log_policy_step # stores the log probability of the unmasking step under the current policy
|
| 106 |
+
self.log_pretrained_step = log_pretrained_step
|
| 107 |
+
|
| 108 |
+
# initialize total rewards to the reward of the roll out unmasked sequence
|
| 109 |
+
if totalReward is not None:
|
| 110 |
+
self.totalReward = totalReward # potential reward of the node based on generated children
|
| 111 |
+
else:
|
| 112 |
+
self.totalReward = np.zeros(self.args.num_obj)
|
| 113 |
+
|
| 114 |
+
# set initial visits to 1
|
| 115 |
+
self.visits = 1
|
| 116 |
+
|
| 117 |
+
# set timestep (value between 0 and num_steps)
|
| 118 |
+
self.timestep = timestep
|
| 119 |
+
|
| 120 |
+
# dict with 'seqs' as token array and 'attention_mask'
|
| 121 |
+
self.tokens = tokens
|
| 122 |
+
|
| 123 |
+
def selectNode(self):
|
| 124 |
+
"""
|
| 125 |
+
Selects a node to move to among the children nodes based on select score
|
| 126 |
+
"""
|
| 127 |
+
# extract the status of the current node
|
| 128 |
+
nodeStatus = self.getExpandStatus()
|
| 129 |
+
|
| 130 |
+
# if the node is a legal non-leaf node
|
| 131 |
+
if (nodeStatus == 3):
|
| 132 |
+
# initialize array that will store select score vectors of each child node
|
| 133 |
+
|
| 134 |
+
paretoFront = {}
|
| 135 |
+
|
| 136 |
+
for childNode in self.childNodes:
|
| 137 |
+
childStatus = childNode.getExpandStatus()
|
| 138 |
+
# only append child if it is legal leaf node (expandable) or legal non-leaf node
|
| 139 |
+
if childStatus == 2 or childStatus == 3:
|
| 140 |
+
selectScore = childNode.calcSelectScore()
|
| 141 |
+
paretoFront = updateParetoFront(paretoFront, childNode, selectScore)
|
| 142 |
+
|
| 143 |
+
selected = rd.choice(list(paretoFront.keys()))
|
| 144 |
+
|
| 145 |
+
# return selected child node and status
|
| 146 |
+
return selected, selected.getExpandStatus()
|
| 147 |
+
|
| 148 |
+
# if node is not valid non-leaf node
|
| 149 |
+
return self, nodeStatus
|
| 150 |
+
|
| 151 |
+
def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
|
| 152 |
+
""""
|
| 153 |
+
Adds a child node:
|
| 154 |
+
log_rnd: log_rnd of the path up to the added child node
|
| 155 |
+
log_policy_step: scalar value of the log-prob of sampling the step under the policy
|
| 156 |
+
log_pretrained_step: scalar value of the log-prob of sampling the step under the pretrained model
|
| 157 |
+
"""
|
| 158 |
+
child = Node(args=self.args,
|
| 159 |
+
tokens=tokens,
|
| 160 |
+
log_rnd = log_rnd,
|
| 161 |
+
log_policy_step=log_policy_step,
|
| 162 |
+
log_pretrained_step=log_pretrained_step,
|
| 163 |
+
parentNode=self,
|
| 164 |
+
childNodes=[],
|
| 165 |
+
totalReward=totalReward,
|
| 166 |
+
timestep=self.timestep+1)
|
| 167 |
+
|
| 168 |
+
self.childNodes.append(child)
|
| 169 |
+
return child
|
| 170 |
+
|
| 171 |
+
def update_logrnd(self, log_policy_step, log_rnd):
|
| 172 |
+
self.log_policy_step = log_policy_step
|
| 173 |
+
self.log_rnd = log_rnd
|
| 174 |
+
|
| 175 |
+
def updateNode(self, rewards):
|
| 176 |
+
"""
|
| 177 |
+
Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
|
| 178 |
+
Increments the number of visits to the node.
|
| 179 |
+
"""
|
| 180 |
+
self.visits += 1
|
| 181 |
+
|
| 182 |
+
self.totalReward += rewards # length-num_obj reward vector
|
| 183 |
+
|
| 184 |
+
def calcSelectScore(self):
|
| 185 |
+
"""
|
| 186 |
+
Calculates the select score for the node from the cumulative rewards vector and number of visits.
|
| 187 |
+
- c: determines the degree of exploration
|
| 188 |
+
- minSelectScore: determines the
|
| 189 |
+
"""
|
| 190 |
+
scaling = 0.1 # scaling of the second term in the select score
|
| 191 |
+
|
| 192 |
+
# K-dimensional vector of normalized rewards for each objective
|
| 193 |
+
normRewards = self.totalReward / self.visits
|
| 194 |
+
|
| 195 |
+
# scales the cumulative reward by the sampling probability
|
| 196 |
+
|
| 197 |
+
return normRewards + (scaling * self.log_policy_step.detach().cpu().item() * np.sqrt(self.parentNode.visits) / self.visits)
|
| 198 |
+
|
| 199 |
+
def getExpandStatus(self):
|
| 200 |
+
"""
|
| 201 |
+
Returns an integer indicating whether the node is a:
|
| 202 |
+
1. terminal node (sequence is fully unmasked)
|
| 203 |
+
2. legal leaf node (partially unmasked sequence that can be expanded)
|
| 204 |
+
3. legal non-leaf node (already expanded sequence with M child nodes)
|
| 205 |
+
"""
|
| 206 |
+
if self.timestep == self.args.total_num_steps:
|
| 207 |
+
return 1
|
| 208 |
+
elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
|
| 209 |
+
return 2
|
| 210 |
+
return 3
|
| 211 |
+
|
| 212 |
+
### END OF NODE CLASS ###
|
| 213 |
+
|
| 214 |
+
### BEGINNING OF MCTS CLASS ###
|
| 215 |
+
|
| 216 |
+
class MCTS:
|
| 217 |
+
def __init__(self, args, config, policy_model, pretrained, score_func_names=[], prot_seqs=None, rootNode=None):
|
| 218 |
+
self.timer = StepTimer(policy_model.device)
|
| 219 |
+
|
| 220 |
+
self.device = policy_model.device
|
| 221 |
+
|
| 222 |
+
self.args = args
|
| 223 |
+
self.config = config
|
| 224 |
+
self.noise = noise_schedule.get_noise(config)
|
| 225 |
+
self.time_conditioning = args.time_conditioning
|
| 226 |
+
|
| 227 |
+
self.num_obj = len(score_func_names)
|
| 228 |
+
|
| 229 |
+
self.mask_index = policy_model.mask_index
|
| 230 |
+
masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
|
| 231 |
+
masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
|
| 232 |
+
if rootNode is None:
|
| 233 |
+
self.rootNode = Node(self.args, tokens = masked_tokens,
|
| 234 |
+
log_rnd=torch.zeros((), device=self.device),
|
| 235 |
+
log_policy_step=torch.zeros((), device=self.device),
|
| 236 |
+
log_pretrained_step=torch.zeros((), device=self.device),
|
| 237 |
+
totalReward=np.zeros(self.num_obj), timestep=0)
|
| 238 |
+
else:
|
| 239 |
+
self.rootNode = rootNode # stores the root node of the tree
|
| 240 |
+
|
| 241 |
+
# dictionary:
|
| 242 |
+
# "seq": final unmasked sequence
|
| 243 |
+
# "traj": list of (N_steps, L)
|
| 244 |
+
# "reward": reward of the trajectory
|
| 245 |
+
self.buffer = [] # List[Dict[str, Any]]
|
| 246 |
+
|
| 247 |
+
self.buffer_size = args.buffer_size
|
| 248 |
+
|
| 249 |
+
self.num_steps = args.total_num_steps
|
| 250 |
+
#self.num_sequences = args.num_sequences
|
| 251 |
+
|
| 252 |
+
# pretrained model
|
| 253 |
+
self.pretrained = pretrained
|
| 254 |
+
|
| 255 |
+
# the policy model that we want to finetune
|
| 256 |
+
self.policy_model = policy_model
|
| 257 |
+
#self.tokenizer = policy_model.tokenizer
|
| 258 |
+
self.device = policy_model.device
|
| 259 |
+
|
| 260 |
+
self.sequence_length = args.seq_length
|
| 261 |
+
|
| 262 |
+
self.num_iter = args.num_iter
|
| 263 |
+
|
| 264 |
+
self.num_children = args.num_children
|
| 265 |
+
|
| 266 |
+
# score functions
|
| 267 |
+
|
| 268 |
+
self.rewardFunc = ScoringFunctions(score_func_names, prot_seqs, device=args.device)
|
| 269 |
+
|
| 270 |
+
self.iter_num = 0
|
| 271 |
+
|
| 272 |
+
self.reward_log = [] # stores scalarized total rewards
|
| 273 |
+
self.logrnd_log = []
|
| 274 |
+
# stores each objective
|
| 275 |
+
self.valid_fraction_log = []
|
| 276 |
+
self.affinity1_log = []
|
| 277 |
+
self.affinity2_log = []
|
| 278 |
+
self.permeability_log = []
|
| 279 |
+
self.sol_log = []
|
| 280 |
+
self.hemo_log = []
|
| 281 |
+
self.nf_log = []
|
| 282 |
+
|
| 283 |
+
self.policy_model.eval()
|
| 284 |
+
self.pretrained.eval()
|
| 285 |
+
|
| 286 |
+
# for peptides
|
| 287 |
+
self.analyzer = PeptideAnalyzer()
|
| 288 |
+
self.tokenizer = policy_model.tokenizer
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def reset(self, resetTree):
|
| 292 |
+
self.iter_num = 0
|
| 293 |
+
self.buffer = []
|
| 294 |
+
self.reward_log = []
|
| 295 |
+
self.logrnd_log = []
|
| 296 |
+
|
| 297 |
+
# reset logs for each objective
|
| 298 |
+
self.valid_fraction_log = []
|
| 299 |
+
self.affinity1_log = []
|
| 300 |
+
self.affinity2_log = []
|
| 301 |
+
self.permeability_log = []
|
| 302 |
+
self.sol_log = []
|
| 303 |
+
self.hemo_log = []
|
| 304 |
+
self.nf_log = []
|
| 305 |
+
|
| 306 |
+
# add option to continue with the same tree
|
| 307 |
+
if resetTree:
|
| 308 |
+
masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
|
| 309 |
+
masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
|
| 310 |
+
self.rootNode = Node(self.args, tokens = masked_tokens,
|
| 311 |
+
log_rnd=torch.zeros((), device=self.device),
|
| 312 |
+
log_policy_step=torch.zeros((), device=self.device),
|
| 313 |
+
log_pretrained_step=torch.zeros((), device=self.device),
|
| 314 |
+
totalReward=np.zeros(self.num_obj), timestep=0)
|
| 315 |
+
|
| 316 |
+
def forward(self, resetTree=False):
|
| 317 |
+
|
| 318 |
+
self.reset(resetTree)
|
| 319 |
+
|
| 320 |
+
while (self.iter_num < self.num_iter):
|
| 321 |
+
self.iter_num += 1
|
| 322 |
+
|
| 323 |
+
# traverse the tree from the root node until a leaf node
|
| 324 |
+
with self.timer.section("select"):
|
| 325 |
+
leafNode, _ = self.select(self.rootNode)
|
| 326 |
+
|
| 327 |
+
# expand leaf node into num_children partially unmasked sequences at the next timestep
|
| 328 |
+
with self.timer.section("expand"):
|
| 329 |
+
self.expand(leafNode)
|
| 330 |
+
|
| 331 |
+
final_x, log_rnd, final_rewards, score_vectors, sequences = self.consolidateBuffer()
|
| 332 |
+
# return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )
|
| 333 |
+
|
| 334 |
+
rows = self.timer.summary()
|
| 335 |
+
print("\n=== Timing summary (by total time) ===")
|
| 336 |
+
for name, cnt, total, mean, p50, p95 in rows:
|
| 337 |
+
print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
|
| 338 |
+
f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")
|
| 339 |
+
|
| 340 |
+
return final_x, log_rnd, final_rewards, score_vectors, sequences
|
| 341 |
+
|
| 342 |
+
# new updateBuffer
|
| 343 |
+
def _debug_buffer_decision(self, sv, reason, extra=None):
|
| 344 |
+
if extra is None: extra = {}
|
| 345 |
+
print(f"[BUFFER] reason={reason} sv={np.round(sv,4)} "
|
| 346 |
+
f"buf_len={len(self.buffer)} extra={extra}")
|
| 347 |
+
|
| 348 |
+
def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences):
|
| 349 |
+
B = x_final.shape[0]
|
| 350 |
+
traj_log_rnds, scalar_rewards = [], []
|
| 351 |
+
|
| 352 |
+
for i in range(B):
|
| 353 |
+
sv = np.asarray(score_vectors[i], dtype=float)
|
| 354 |
+
|
| 355 |
+
# determine how to scalarize the multi-objective rewards
|
| 356 |
+
if self.args.scalarization == "normalized":
|
| 357 |
+
pass
|
| 358 |
+
elif self.args.scalarization == "weighted":
|
| 359 |
+
pass
|
| 360 |
+
else:
|
| 361 |
+
scalar_reward = float(np.sum(sv))
|
| 362 |
+
|
| 363 |
+
            traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) # scale down by alpha

            item = {
                "x_final": x_final[i].clone(), # clone?
                "log_rnd": traj_log_rnd.clone(),
                "final_reward": scalar_reward,
                "score_vector": sv.copy(),
                "seq": childSequences[i],
            }

            # Drop if dominated by any existing
            if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer):
                # for debugging
                self._debug_buffer_decision(sv, "rejected_dominated")
                continue

            # Remove any existing that this candidate dominates
            keep = []
            for bi in self.buffer:
                if not dominates(sv, bi["score_vector"]):
                    keep.append(bi)
            self.buffer = keep

            # Insert with capacity rule
            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
            else:
                # tie-breaker: replace the worst by a simple heuristic (min sum)
                worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer]))
                self.buffer[worst_i] = item

            # for debugging
            self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)})

            traj_log_rnds.append(traj_log_rnd)
            scalar_rewards.append(scalar_reward)

        traj_log_rnds = torch.stack(traj_log_rnds, dim=0) if traj_log_rnds else torch.empty(0)
        scalar_rewards = np.asarray(scalar_rewards, dtype=float)
        return traj_log_rnds, scalar_rewards

    def consolidateBuffer(self):
        """
        returns x_final, log_rnd, and final_rewards in tensors
        """
        x_final = []
        log_rnd = []
        final_rewards = []
        score_vectors = []
        sequences = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])
            score_vectors.append(item["score_vector"])
            sequences.append(item["seq"])

        x_final = torch.stack(x_final, dim=0) # (B, L)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (B)
        final_rewards = np.stack(final_rewards, axis=0).astype(np.float32)
        score_vectors = np.stack(score_vectors, axis=0).astype(np.float32)

        return x_final, log_rnd, final_rewards, score_vectors, sequences


    def isPathEnd(self, path, maxDepth):
        """
        Checks if the node is completely unmasked (i.e. end of path)
        or if the path is at the max depth
        """
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """
        Traverse the tree from the root node until reaching a legal leaf node
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            currNode, nodeStatus = currNode.selectNode()

            if currNode.parentNode is not None:
                # compute new log_policy
                child_tokens = currNode.tokens['seqs'].to(self.device)
                attn_mask = currNode.tokens['attention_mask'].to(self.device)
                parent = currNode.parentNode
                parent_tokens = parent.tokens['seqs'].to(self.device)
                t = torch.ones(1, device=self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
                                                                                       child_tokens,
                                                                                       t=t, dt=dt)
                updated_log_rnd += updated_log_policy_step

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd) # update log_rnd

            if nodeStatus != 3:
                return currNode, nodeStatus

    def expand(self, parentNode, eps=1e-5):
        """
        Sample unmasking steps from the pre-trained MDLM;
        adds num_children partially unmasked sequences to the children of the parentNode
        """

        num_children = self.num_children
        # initialize child rewards that will be added to total rewards

        # compute number of rollout steps
        # if parentNode.timestep = self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of parent node to 0
        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        # initialize x and attn_mask
        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd # stores the log_rnd up to parent node

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device=self.device)

        # sample M child sequences and compute their log probabilities
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                _, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        # compute weight of the step (num_children, 1)
        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children

        traj_log_rnd = child_log_rnd # initialize log_rnd for entire rolled out trajectory

        # rollout under the policy and compute the log ratio at each step
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device=self.device)

                with torch.no_grad():
                    _, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                # add the rollout step
                traj_log_rnd += log_pretrained_step - log_policy_step

                x_rollout = x_next

        # if mask token remains, fully unmask
        mask_positions = (x_rollout == self.mask_index) # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item() # true if mask remains
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step

            x_rollout = x_next

        # stores the string sequences for reward evaluation
        with self.timer.section("expand.decode"):
            childSequences = self.tokenizer.batch_decode(x_rollout)

        ## FOR PEPTIDES ONLY ##
        valid_x_children = []
        valid_x_final = []
        validSequences = []
        valid_traj_log_rnd = []

        with self.timer.section("expand.filter_is_peptide"):
            for i in range(num_children):
                # string sequence
                childSeq = childSequences[i]

                # check if the peptide is valid
                if self.analyzer.is_peptide(childSeq):
                    valid_x_children.append(x_children[i])
                    valid_x_final.append(x_rollout[i])
                    validSequences.append(childSeq)
                    valid_traj_log_rnd.append(traj_log_rnd[i])
                else:
                    childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
                    parentNode.addChildNode(tokens=childTokens,
                                            log_rnd=child_log_rnd[i],
                                            log_policy_step=child_log_policy_step[i],
                                            log_pretrained_step=child_log_pretrained_step[i],
                                            totalReward=np.zeros(self.num_obj))

        del traj_log_rnd

        if (len(validSequences) != 0):
            # add scores to log
            with self.timer.section("expand.scoring_functions"):
                score_vectors = self.rewardFunc(input_seqs=validSequences) # (num_children, num_objectives)

            average_scores = score_vectors.T

            self.affinity1_log.append(average_scores[0])
            self.sol_log.append(average_scores[1])
            self.hemo_log.append(average_scores[2])
            self.nf_log.append(average_scores[3])
            self.permeability_log.append(average_scores[4])

        else:
            # set the values added to log as 0s if there are no valid sequences
            self.affinity1_log.append(np.zeros((self.num_obj, self.num_children)))
            self.sol_log.append(np.zeros((self.num_obj, self.num_children)))
            self.hemo_log.append(np.zeros((self.num_obj, self.num_children)))
            self.nf_log.append(np.zeros((self.num_obj, self.num_children)))
            self.permeability_log.append(np.zeros((self.num_obj, self.num_children)))

        # convert to tensor
        if len(valid_x_final) == 0:
            # log and bail out gracefully for this expansion
            self.valid_fraction_log.append(0.0)
            return

        valid_x_final = torch.stack(valid_x_final, dim=0)
        valid_traj_log_rnd = torch.stack(valid_traj_log_rnd, dim=0)
        # update buffer and get rewards
        with self.timer.section("expand.update_buffer"):
            traj_log_rnds, scalar_rewards = self.updateBuffer(valid_x_final, valid_traj_log_rnd, score_vectors, childSequences)

        allChildReward = np.zeros_like(score_vectors[0])

        for i in range(len(score_vectors)):
            reward = score_vectors[i]

            # add to all child reward vector for backprop
            allChildReward += reward # (num_objectives,)

            # create node for sequence and add to the children node of parent
            childTokens = {'seqs': valid_x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[i],
                                    log_policy_step=child_log_policy_step[i],
                                    log_pretrained_step=child_log_pretrained_step[i],
                                    totalReward=reward)

        ### END OF FOR PEPTIDES ONLY ###

        valid_fraction = len(validSequences) / num_children
        self.valid_fraction_log.append(valid_fraction)

        # debugging
        print(f"[EXPAND] iter={self.iter_num} parent_t={parentNode.timestep} "
              f"num_children={num_children} valid={len(validSequences)} any_mask={any_mask_global}")
        if score_vectors is not None:
            print(f"[SCORES] min={np.min(score_vectors,0)} max={np.max(score_vectors,0)} "
                  f"nan_any={np.isnan(score_vectors).any()}")
        # end debugging

        self.reward_log.append(scalar_rewards)
        self.logrnd_log.append(traj_log_rnds.detach().cpu().numpy())

        allChildReward = allChildReward / len(validSequences) # normalize by number of valid children
        # backpropagate all child rewards
        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)


    def backprop(self, node, allChildReward):
        # backpropagate rewards through the path leading to the leaf node from the root
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
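Note: `updateBuffer` above calls `dominates` and `dominated_by`, which are defined earlier in `peptide_mcts.py` and are not part of this hunk. As a reference only, a minimal sketch of Pareto-dominance checks consistent with how they are used here (assuming every objective is maximized) could look like:

import numpy as np

def dominates(a, b):
    # a Pareto-dominates b: no worse on every objective, strictly better on at least one
    a, b = np.asarray(a), np.asarray(b)
    return bool(np.all(a >= b) and np.any(a > b))

def dominated_by(a, b):
    # a is dominated by b
    return dominates(b, a)

This is a sketch of the standard definition, not the repository's exact implementation.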
tr2d2-pep/plotting.py
ADDED
|
@@ -0,0 +1,148 @@
#!/usr/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import os


def plot_data_with_distribution_seaborn(log1, log2=None,
                                        save_path=None,
                                        label1=None,
                                        label2=None,
                                        title=None):
    """
    Plots one or two datasets with the average values and distributions over iterations using Seaborn.

    Parameters:
        log1 (list of lists): The first list of scores (each element is a list of scores for an iteration).
        log2 (list of lists, optional): The second list of scores (each element is a list of scores for an iteration). Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset. Defaults to None.
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot. Defaults to None.
    """
    # Prepare data for log1
    data1 = pd.DataFrame({
        "Iteration": np.repeat(range(1, len(log1) + 1), [len(scores) for scores in log1]),
        label1: [score for scores in log1 for score in scores],
        "Dataset": label1,
        "Style": "Log1"
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": np.repeat(range(1, len(log2) + 1), [len(scores) for scores in log2]),
            label2: [score for scores in log2 for score in scores],
            "Dataset": label2,
            "Style": "Log2"
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    palette = {
        label1: "#8181ED",  # Default color for log1
        label2: "#D577FF"   # Default color for log2 (if provided)
    }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.relplot(
        data=data,
        kind="line",
        x="Iteration",
        y=label1,
        hue="Dataset",
        style="Style",
        markers=True,
        dashes=True,
        ci="sd",  # Show standard deviation
        height=5,
        aspect=1.5,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel(label1)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()

def plot_data(log1, log2=None,
              save_path=None,
              label1="Log 1",
              label2=None,
              title="Fraction of Valid Peptides Over Iterations",
              palette=None):
    """
    Plots one or two datasets with their mean values over iterations.

    Parameters:
        log1 (list): The first list of mean values for each iteration.
        log2 (list, optional): The second list of mean values for each iteration. Defaults to None.
        save_path (str): Path to save the plot. Defaults to None.
        label1 (str): Label for the first dataset. Defaults to "Log 1".
        label2 (str, optional): Label for the second dataset. Defaults to None.
        title (str): Title of the plot. Defaults to "Fraction of Valid Peptides Over Iterations".
        palette (dict, optional): A dictionary defining custom colors for datasets. Defaults to None.
    """
    # Prepare data for log1
    data1 = pd.DataFrame({
        "Iteration": range(1, len(log1) + 1),
        "Fraction of Valid Peptides": log1,
        "Dataset": label1
    })

    # Prepare data for log2 if provided
    if log2 is not None:
        data2 = pd.DataFrame({
            "Iteration": range(1, len(log2) + 1),
            "Fraction of Valid Peptides": log2,
            "Dataset": label2
        })
        data = pd.concat([data1, data2], ignore_index=True)
    else:
        data = data1

    palette = {
        label1: "#8181ED",  # Default color for log1
        label2: "#D577FF"   # Default color for log2 (if provided)
    }

    # Set Seaborn theme
    sns.set_theme()
    sns.set_context("paper")

    # Create the plot
    sns.lineplot(
        data=data,
        x="Iteration",
        y="Fraction of Valid Peptides",
        hue="Dataset",
        style="Dataset",
        markers=True,
        dashes=False,
        palette=palette
    )

    # Titles and labels
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel("Fraction of Valid Peptides")

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    plt.show()
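A minimal usage sketch, assuming `mcts` is the tree-search object from `peptide_mcts.py` whose `valid_fraction_log` and `reward_log` are populated during `expand` (file names here are illustrative):

from plotting import plot_data, plot_data_with_distribution_seaborn

# one mean value per iteration
plot_data(mcts.valid_fraction_log, save_path="valid_fraction.png",
          label1="TR2-D2", title="Fraction of Valid Peptides Over Iterations")

# one list of scores per iteration
plot_data_with_distribution_seaborn(mcts.reward_log, save_path="rewards.png",
                                    label1="Scalar Reward",
                                    title="Rewards Over Iterations")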
tr2d2-pep/roformer.py
ADDED
|
@@ -0,0 +1,74 @@
from transformers import RoFormerConfig, RoFormerForMaskedLM
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch

class Roformer(nn.Module):
    def __init__(self, config, tokenizer, device=None):
        super(Roformer, self).__init__()

        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

        roformer_config = RoFormerConfig(
            vocab_size=self.tokenizer.vocab_size,
            embedding_size=config.roformer.hidden_size,
            hidden_size=config.roformer.hidden_size,
            num_hidden_layers=config.roformer.n_layers,
            num_attention_heads=config.roformer.n_heads,
            intermediate_size=config.roformer.hidden_size * 4,
            max_position_embeddings=config.roformer.max_position_embeddings,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            pad_token_id=0,
            rotary_value=False
        )

        self.model = RoFormerForMaskedLM(roformer_config).to(self.device)

    def freeze_model(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_all_layers(self):
        for param in self.model.parameters():
            param.requires_grad = True

    def unfreeze_n_layers(self, n):
        # assumes the 8-layer encoder configuration used in this repo
        num_layers = 8

        for i, layer in enumerate(self.model.roformer.encoder.layer):
            # finetune final n layers
            if i >= num_layers - n:
                # unfreeze query weights
                for module in layer.attention.self.query.modules():
                    for param in module.parameters():
                        param.requires_grad = True
                # unfreeze key weights
                for module in layer.attention.self.key.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def forward(self, input_ids, attn_mask):

        input_ids = input_ids.to(self.device)
        attn_mask = attn_mask.to(self.device)

        # get logits embeddings
        logits = self.model(input_ids=input_ids, attention_mask=attn_mask)
        # return logits
        #print(logits.logits)
        return logits.logits

    def save_model(self, save_dir):
        self.model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_model(cls, save_dir, config, tokenizer):
        roformer = cls(config, tokenizer)
        roformer.model = RoFormerForMaskedLM.from_pretrained(save_dir)
        return roformer
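A construction sketch, run from inside `tr2d2-pep/`. The `config` argument only needs a `roformer` namespace with the fields read above; the values and the stub tokenizer below are illustrative assumptions (the real settings live in `configs/peptune_config.yaml` and the repo's own tokenizer under `tr2d2-pep/tokenizer/`):

from types import SimpleNamespace
from roformer import Roformer

# illustrative hyperparameters, not the repo's defaults
config = SimpleNamespace(roformer=SimpleNamespace(
    hidden_size=768, n_layers=8, n_heads=12, max_position_embeddings=512))

# any object exposing .vocab_size suffices for construction; swap in the real tokenizer
tokenizer = SimpleNamespace(vocab_size=587)

model = Roformer(config, tokenizer)
model.freeze_model()
model.unfreeze_n_layers(2)  # only query/key weights of the last 2 encoder layers stay trainable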
tr2d2-pep/run_mcts.sh
ADDED
|
@@ -0,0 +1,29 @@
#!/bin/bash

HOME_LOC=/path/to/your/home
ENV_PATH=/path/to/your/env
SCRIPT_LOC=$HOME_LOC/tr2d2/peptides
LOG_LOC=$HOME_LOC/tr2d2/peptides/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='tfr-peptune-baseline'
# set 3 have skip connection
PYTHON_EXECUTABLE=$ENV_PATH/bin/python

# ===================================================================

source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate $ENV_PATH

$PYTHON_EXECUTABLE $SCRIPT_LOC/generate_mcts.py \
    --base_path $HOME_LOC \
    --device "cuda:0" \
    --noise_removal \
    --run_name "tfr-peptune-baseline" \
    --num_children 50 \
    --num_iter 100 \
    --buffer_size 100 \
    --seq_length 200 \
    --total_num_steps 128 \
    --exploration 0.1 > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1

conda deactivate
tr2d2-pep/scoring/functions/binding.py
ADDED
|
@@ -0,0 +1,178 @@
| 1 |
+
import sys
|
| 2 |
+
import os, torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import esm
|
| 8 |
+
from transformers import AutoModelForMaskedLM
|
| 9 |
+
|
| 10 |
+
class ImprovedBindingPredictor(nn.Module):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
esm_dim=1280,
|
| 13 |
+
smiles_dim=768,
|
| 14 |
+
hidden_dim=512,
|
| 15 |
+
n_heads=8,
|
| 16 |
+
n_layers=3,
|
| 17 |
+
dropout=0.1):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
# Define binding thresholds
|
| 21 |
+
self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
|
| 22 |
+
self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
|
| 23 |
+
|
| 24 |
+
# Project to same dimension
|
| 25 |
+
self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
|
| 26 |
+
self.protein_projection = nn.Linear(esm_dim, hidden_dim)
|
| 27 |
+
self.protein_norm = nn.LayerNorm(hidden_dim)
|
| 28 |
+
self.smiles_norm = nn.LayerNorm(hidden_dim)
|
| 29 |
+
|
| 30 |
+
# Cross attention blocks with layer norm
|
| 31 |
+
self.cross_attention_layers = nn.ModuleList([
|
| 32 |
+
nn.ModuleDict({
|
| 33 |
+
'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
|
| 34 |
+
'norm1': nn.LayerNorm(hidden_dim),
|
| 35 |
+
'ffn': nn.Sequential(
|
| 36 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
| 37 |
+
nn.ReLU(),
|
| 38 |
+
nn.Dropout(dropout),
|
| 39 |
+
nn.Linear(hidden_dim * 4, hidden_dim)
|
| 40 |
+
),
|
| 41 |
+
'norm2': nn.LayerNorm(hidden_dim)
|
| 42 |
+
}) for _ in range(n_layers)
|
| 43 |
+
])
|
| 44 |
+
|
| 45 |
+
# Prediction heads
|
| 46 |
+
self.shared_head = nn.Sequential(
|
| 47 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 48 |
+
nn.ReLU(),
|
| 49 |
+
nn.Dropout(dropout),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# Regression head
|
| 53 |
+
self.regression_head = nn.Linear(hidden_dim, 1)
|
| 54 |
+
|
| 55 |
+
# Classification head (3 classes: tight, medium, loose binding)
|
| 56 |
+
self.classification_head = nn.Linear(hidden_dim, 3)
|
| 57 |
+
|
| 58 |
+
def get_binding_class(self, affinity):
|
| 59 |
+
"""Convert affinity values to class indices
|
| 60 |
+
0: tight binding (>= 7.5)
|
| 61 |
+
1: medium binding (6.0-7.5)
|
| 62 |
+
2: weak binding (< 6.0)
|
| 63 |
+
"""
|
| 64 |
+
if isinstance(affinity, torch.Tensor):
|
| 65 |
+
tight_mask = affinity >= self.tight_threshold
|
| 66 |
+
weak_mask = affinity < self.weak_threshold
|
| 67 |
+
medium_mask = ~(tight_mask | weak_mask)
|
| 68 |
+
|
| 69 |
+
classes = torch.zeros_like(affinity, dtype=torch.long)
|
| 70 |
+
classes[medium_mask] = 1
|
| 71 |
+
classes[weak_mask] = 2
|
| 72 |
+
return classes
|
| 73 |
+
else:
|
| 74 |
+
if affinity >= self.tight_threshold:
|
| 75 |
+
return 0 # tight binding
|
| 76 |
+
elif affinity < self.weak_threshold:
|
| 77 |
+
return 2 # weak binding
|
| 78 |
+
else:
|
| 79 |
+
return 1 # medium binding
|
| 80 |
+
|
| 81 |
+
def forward(self, protein_emb, smiles_emb):
|
| 82 |
+
protein = self.protein_norm(self.protein_projection(protein_emb))
|
| 83 |
+
smiles = self.smiles_norm(self.smiles_projection(smiles_emb))
|
| 84 |
+
|
| 85 |
+
#protein = protein.transpose(0, 1)
|
| 86 |
+
#smiles = smiles.transpose(0, 1)
|
| 87 |
+
|
| 88 |
+
# Cross attention layers
|
| 89 |
+
for layer in self.cross_attention_layers:
|
| 90 |
+
# Protein attending to SMILES
|
| 91 |
+
attended_protein = layer['attention'](
|
| 92 |
+
protein, smiles, smiles
|
| 93 |
+
)[0]
|
| 94 |
+
protein = layer['norm1'](protein + attended_protein)
|
| 95 |
+
protein = layer['norm2'](protein + layer['ffn'](protein))
|
| 96 |
+
|
| 97 |
+
# SMILES attending to protein
|
| 98 |
+
attended_smiles = layer['attention'](
|
| 99 |
+
smiles, protein, protein
|
| 100 |
+
)[0]
|
| 101 |
+
smiles = layer['norm1'](smiles + attended_smiles)
|
| 102 |
+
smiles = layer['norm2'](smiles + layer['ffn'](smiles))
|
| 103 |
+
|
| 104 |
+
# Get sequence-level representations
|
| 105 |
+
protein_pool = torch.mean(protein, dim=0)
|
| 106 |
+
smiles_pool = torch.mean(smiles, dim=0)
|
| 107 |
+
|
| 108 |
+
# Concatenate both representations
|
| 109 |
+
combined = torch.cat([protein_pool, smiles_pool], dim=-1)
|
| 110 |
+
|
| 111 |
+
# Shared features
|
| 112 |
+
shared_features = self.shared_head(combined)
|
| 113 |
+
|
| 114 |
+
regression_output = self.regression_head(shared_features)
|
| 115 |
+
classification_logits = self.classification_head(shared_features)
|
| 116 |
+
|
| 117 |
+
return regression_output, classification_logits
|
| 118 |
+
|
| 119 |
+
class BindingAffinity:
|
| 120 |
+
def __init__(self, prot_seq, tokenizer, base_path, device=None, emb_model=None):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
| 123 |
+
|
| 124 |
+
# peptide embeddings
|
| 125 |
+
if emb_model is not None:
|
| 126 |
+
self.pep_model = emb_model.to(self.device).eval()
|
| 127 |
+
else:
|
| 128 |
+
self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
|
| 129 |
+
|
| 130 |
+
self.pep_tokenizer = tokenizer
|
| 131 |
+
|
| 132 |
+
self.model = ImprovedBindingPredictor().to(self.device)
|
| 133 |
+
checkpoint = torch.load(f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/binding-affinity.pt',
|
| 134 |
+
map_location=self.device,
|
| 135 |
+
weights_only=False)
|
| 136 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 137 |
+
|
| 138 |
+
self.model.eval()
|
| 139 |
+
|
| 140 |
+
self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() # load ESM-2 model
|
| 141 |
+
self.esm_model = self.esm_model.to(self.device).eval()
|
| 142 |
+
self.prot_tokenizer = alphabet.get_batch_converter() # load esm tokenizer
|
| 143 |
+
|
| 144 |
+
data = [("target", prot_seq)]
|
| 145 |
+
# get tokenized protein
|
| 146 |
+
_, _, prot_tokens = self.prot_tokenizer(data)
|
| 147 |
+
prot_tokens = prot_tokens.to(self.device)
|
| 148 |
+
with torch.no_grad():
|
| 149 |
+
results = self.esm_model.forward(prot_tokens, repr_layers=[33]) # Example with ESM-2
|
| 150 |
+
prot_emb = results["representations"][33]
|
| 151 |
+
|
| 152 |
+
self.prot_emb = prot_emb[0].to(self.device)
|
| 153 |
+
self.prot_emb = torch.mean(self.prot_emb, dim=0, keepdim=True)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def forward(self, input_seqs):
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
scores = []
|
| 159 |
+
for seq in input_seqs:
|
| 160 |
+
pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
|
| 161 |
+
|
| 162 |
+
pep_tokens = {k: v.to(self.device) for k, v in pep_tokens.items()}
|
| 163 |
+
|
| 164 |
+
with torch.no_grad():
|
| 165 |
+
emb = self.pep_model(input_ids=pep_tokens['input_ids'],
|
| 166 |
+
attention_mask=pep_tokens['attention_mask'],
|
| 167 |
+
output_hidden_states=True)
|
| 168 |
+
|
| 169 |
+
#emb = self.pep_model(input_ids=pep_tokens['input_ids'], attention_mask=pep_tokens['attention_mask'])
|
| 170 |
+
pep_emb = emb.last_hidden_state.squeeze(0)
|
| 171 |
+
pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)
|
| 172 |
+
|
| 173 |
+
score, logits = self.model.forward(self.prot_emb, pep_emb)
|
| 174 |
+
scores.append(score.item())
|
| 175 |
+
return scores
|
| 176 |
+
|
| 177 |
+
def __call__(self, input_seqs: list):
|
| 178 |
+
return self.forward(input_seqs)
|
tr2d2-pep/scoring/functions/binding_utils.py
ADDED
|
@@ -0,0 +1,290 @@
| 1 |
+
from torch import nn
import torch.nn.functional as F  # required: F.softmax is used in the attention modules below
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def to_var(x):
|
| 6 |
+
if torch.cuda.is_available():
|
| 7 |
+
x = x.cuda()
|
| 8 |
+
return x
|
| 9 |
+
|
| 10 |
+
class MultiHeadAttentionSequence(nn.Module):
|
| 11 |
+
|
| 12 |
+
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
|
| 13 |
+
|
| 14 |
+
super().__init__()
|
| 15 |
+
|
| 16 |
+
self.n_head = n_head
|
| 17 |
+
self.d_model = d_model
|
| 18 |
+
self.d_k = d_k
|
| 19 |
+
self.d_v = d_v
|
| 20 |
+
|
| 21 |
+
self.W_Q = nn.Linear(d_model, n_head*d_k)
|
| 22 |
+
self.W_K = nn.Linear(d_model, n_head*d_k)
|
| 23 |
+
self.W_V = nn.Linear(d_model, n_head*d_v)
|
| 24 |
+
self.W_O = nn.Linear(n_head*d_v, d_model)
|
| 25 |
+
|
| 26 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 27 |
+
|
| 28 |
+
self.dropout = nn.Dropout(dropout)
|
| 29 |
+
|
| 30 |
+
def forward(self, q, k, v):
|
| 31 |
+
|
| 32 |
+
batch, len_q, _ = q.size()
|
| 33 |
+
batch, len_k, _ = k.size()
|
| 34 |
+
batch, len_v, _ = v.size()
|
| 35 |
+
|
| 36 |
+
Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
|
| 37 |
+
K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
|
| 38 |
+
V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
|
| 39 |
+
|
| 40 |
+
Q = Q.transpose(1, 2)
|
| 41 |
+
K = K.transpose(1, 2).transpose(2, 3)
|
| 42 |
+
V = V.transpose(1, 2)
|
| 43 |
+
|
| 44 |
+
attention = torch.matmul(Q, K)
|
| 45 |
+
|
| 46 |
+
attention = attention / np.sqrt(self.d_k)
|
| 47 |
+
|
| 48 |
+
attention = F.softmax(attention, dim=-1)
|
| 49 |
+
|
| 50 |
+
output = torch.matmul(attention, V)
|
| 51 |
+
|
| 52 |
+
output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
|
| 53 |
+
|
| 54 |
+
output = self.W_O(output)
|
| 55 |
+
|
| 56 |
+
output = self.dropout(output)
|
| 57 |
+
|
| 58 |
+
output = self.layer_norm(output + q)
|
| 59 |
+
|
| 60 |
+
return output, attention
|
| 61 |
+
|
| 62 |
+
class MultiHeadAttentionReciprocal(nn.Module):
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
|
| 66 |
+
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.n_head = n_head
|
| 70 |
+
self.d_model = d_model
|
| 71 |
+
self.d_k = d_k
|
| 72 |
+
self.d_v = d_v
|
| 73 |
+
|
| 74 |
+
self.W_Q = nn.Linear(d_model, n_head*d_k)
|
| 75 |
+
self.W_K = nn.Linear(d_model, n_head*d_k)
|
| 76 |
+
self.W_V = nn.Linear(d_model, n_head*d_v)
|
| 77 |
+
self.W_O = nn.Linear(n_head*d_v, d_model)
|
| 78 |
+
self.W_V_2 = nn.Linear(d_model, n_head*d_v)
|
| 79 |
+
self.W_O_2 = nn.Linear(n_head*d_v, d_model)
|
| 80 |
+
|
| 81 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 82 |
+
|
| 83 |
+
self.dropout = nn.Dropout(dropout)
|
| 84 |
+
|
| 85 |
+
self.layer_norm_2 = nn.LayerNorm(d_model)
|
| 86 |
+
|
| 87 |
+
self.dropout_2 = nn.Dropout(dropout)
|
| 88 |
+
|
| 89 |
+
def forward(self, q, k, v, v_2):
|
| 90 |
+
|
| 91 |
+
batch, len_q, _ = q.size()
|
| 92 |
+
batch, len_k, _ = k.size()
|
| 93 |
+
batch, len_v, _ = v.size()
|
| 94 |
+
batch, len_v_2, _ = v_2.size()
|
| 95 |
+
|
| 96 |
+
Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
|
| 97 |
+
K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
|
| 98 |
+
V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
|
| 99 |
+
V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
|
| 100 |
+
|
| 101 |
+
Q = Q.transpose(1, 2)
|
| 102 |
+
K = K.transpose(1, 2).transpose(2, 3)
|
| 103 |
+
V = V.transpose(1, 2)
|
| 104 |
+
V_2 = V_2.transpose(1,2)
|
| 105 |
+
|
| 106 |
+
attention = torch.matmul(Q, K)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
attention = attention /np.sqrt(self.d_k)
|
| 110 |
+
|
| 111 |
+
attention_2 = attention.transpose(-2, -1)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
attention = F.softmax(attention, dim=-1)
|
| 116 |
+
|
| 117 |
+
attention_2 = F.softmax(attention_2, dim=-1)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
output = torch.matmul(attention, V)
|
| 121 |
+
|
| 122 |
+
output_2 = torch.matmul(attention_2, V_2)
|
| 123 |
+
|
| 124 |
+
output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
|
| 125 |
+
|
| 126 |
+
output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
|
| 127 |
+
|
| 128 |
+
output = self.W_O(output)
|
| 129 |
+
|
| 130 |
+
output_2 = self.W_O_2(output_2)
|
| 131 |
+
|
| 132 |
+
output = self.dropout(output)
|
| 133 |
+
|
| 134 |
+
output = self.layer_norm(output + q)
|
| 135 |
+
|
| 136 |
+
output_2 = self.dropout(output_2)
|
| 137 |
+
|
| 138 |
+
output_2 = self.layer_norm(output_2 + k)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
return output, output_2, attention, attention_2
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class FFN(nn.Module):
|
| 145 |
+
|
| 146 |
+
def __init__(self, d_in, d_hid, dropout=0.1):
|
| 147 |
+
super().__init__()
|
| 148 |
+
|
| 149 |
+
self.layer_1 = nn.Conv1d(d_in, d_hid,1)
|
| 150 |
+
self.layer_2 = nn.Conv1d(d_hid, d_in,1)
|
| 151 |
+
self.relu = nn.ReLU()
|
| 152 |
+
self.layer_norm = nn.LayerNorm(d_in)
|
| 153 |
+
|
| 154 |
+
self.dropout = nn.Dropout(dropout)
|
| 155 |
+
|
| 156 |
+
def forward(self, x):
|
| 157 |
+
|
| 158 |
+
residual = x
|
| 159 |
+
output = self.layer_1(x.transpose(1, 2))
|
| 160 |
+
|
| 161 |
+
output = self.relu(output)
|
| 162 |
+
|
| 163 |
+
output = self.layer_2(output)
|
| 164 |
+
|
| 165 |
+
output = self.dropout(output)
|
| 166 |
+
|
| 167 |
+
output = self.layer_norm(output.transpose(1, 2)+residual)
|
| 168 |
+
|
| 169 |
+
return output
|
| 170 |
+
|
| 171 |
+
class ConvLayer(nn.Module):
|
| 172 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
|
| 173 |
+
super(ConvLayer, self).__init__()
|
| 174 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
|
| 175 |
+
self.relu = nn.ReLU()
|
| 176 |
+
|
| 177 |
+
def forward(self, x):
|
| 178 |
+
out = self.conv(x)
|
| 179 |
+
out = self.relu(out)
|
| 180 |
+
return out
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DilatedCNN(nn.Module):
|
| 184 |
+
def __init__(self, d_model, d_hidden):
|
| 185 |
+
super(DilatedCNN, self).__init__()
|
| 186 |
+
self.first_ = nn.ModuleList()
|
| 187 |
+
self.second_ = nn.ModuleList()
|
| 188 |
+
self.third_ = nn.ModuleList()
|
| 189 |
+
|
| 190 |
+
dilation_tuple = (1, 2, 3)
|
| 191 |
+
dim_in_tuple = (d_model, d_hidden, d_hidden)
|
| 192 |
+
dim_out_tuple = (d_hidden, d_hidden, d_hidden)
|
| 193 |
+
|
| 194 |
+
for i, dilation_rate in enumerate(dilation_tuple):
|
| 195 |
+
self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
|
| 196 |
+
dilation=dilation_rate))
|
| 197 |
+
|
| 198 |
+
for i, dilation_rate in enumerate(dilation_tuple):
|
| 199 |
+
self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
|
| 200 |
+
dilation=dilation_rate))
|
| 201 |
+
|
| 202 |
+
for i, dilation_rate in enumerate(dilation_tuple):
|
| 203 |
+
self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
|
| 204 |
+
dilation=dilation_rate))
|
| 205 |
+
|
| 206 |
+
def forward(self, protein_seq_enc):
|
| 207 |
+
# pdb.set_trace()
|
| 208 |
+
protein_seq_enc = protein_seq_enc.transpose(1, 2) # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
|
| 209 |
+
|
| 210 |
+
first_embedding = protein_seq_enc
|
| 211 |
+
second_embedding = protein_seq_enc
|
| 212 |
+
third_embedding = protein_seq_enc
|
| 213 |
+
|
| 214 |
+
for i in range(len(self.first_)):
|
| 215 |
+
first_embedding = self.first_[i](first_embedding)
|
| 216 |
+
|
| 217 |
+
for i in range(len(self.second_)):
|
| 218 |
+
second_embedding = self.second_[i](second_embedding)
|
| 219 |
+
|
| 220 |
+
for i in range(len(self.third_)):
|
| 221 |
+
third_embedding = self.third_[i](third_embedding)
|
| 222 |
+
|
| 223 |
+
# pdb.set_trace()
|
| 224 |
+
|
| 225 |
+
protein_seq_enc = first_embedding + second_embedding + third_embedding
|
| 226 |
+
|
| 227 |
+
return protein_seq_enc.transpose(1, 2)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class ReciprocalLayerwithCNN(nn.Module):
|
| 231 |
+
|
| 232 |
+
def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
|
| 233 |
+
super().__init__()
|
| 234 |
+
|
| 235 |
+
self.cnn = DilatedCNN(d_model, d_hidden)
|
| 236 |
+
|
| 237 |
+
self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
|
| 238 |
+
|
| 239 |
+
self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
|
| 240 |
+
|
| 241 |
+
self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
|
| 242 |
+
|
| 243 |
+
self.ffn_seq = FFN(d_hidden, d_inner)
|
| 244 |
+
|
| 245 |
+
self.ffn_protein = FFN(d_hidden, d_inner)
|
| 246 |
+
|
| 247 |
+
def forward(self, sequence_enc, protein_seq_enc):
|
| 248 |
+
# pdb.set_trace() # protein_seq_enc.shape = B * L * d_model
|
| 249 |
+
protein_seq_enc = self.cnn(protein_seq_enc)
|
| 250 |
+
prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
|
| 251 |
+
|
| 252 |
+
seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
|
| 253 |
+
|
| 254 |
+
prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
|
| 255 |
+
|
| 256 |
+
prot_enc = self.ffn_protein(prot_enc)
|
| 257 |
+
|
| 258 |
+
seq_enc = self.ffn_seq(seq_enc)
|
| 259 |
+
|
| 260 |
+
return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class ReciprocalLayer(nn.Module):
|
| 264 |
+
|
| 265 |
+
def __init__(self, d_model, d_inner, n_head, d_k, d_v):
|
| 266 |
+
|
| 267 |
+
super().__init__()
|
| 268 |
+
|
| 269 |
+
self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
|
| 270 |
+
|
| 271 |
+
self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
|
| 272 |
+
|
| 273 |
+
self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
|
| 274 |
+
|
| 275 |
+
self.ffn_seq = FFN(d_model, d_inner)
|
| 276 |
+
|
| 277 |
+
self.ffn_protein = FFN(d_model, d_inner)
|
| 278 |
+
|
| 279 |
+
def forward(self, sequence_enc, protein_seq_enc):
|
| 280 |
+
prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
|
| 281 |
+
|
| 282 |
+
seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
|
| 286 |
+
prot_enc = self.ffn_protein(prot_enc)
|
| 287 |
+
|
| 288 |
+
seq_enc = self.ffn_seq(seq_enc)
|
| 289 |
+
|
| 290 |
+
return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
|
tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7e5d8c84bdad75f7091b5b3963133d4b0ebd180ae45654618ca6c090eee0bc06
|
| 3 |
+
size 45249160
|
tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tr2d2-pep/scoring/functions/hemolysis.py
ADDED
|
@@ -0,0 +1,63 @@
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
import warnings
from rdkit import rdBase

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Hemolysis:

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/hemolysis-xgboost.json')
        self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability of it being not hemolytic
        return scores - probs

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    # NOTE: Hemolysis requires a tokenizer and base_path; supply them to run this test
    hemo = Hemolysis()
    seq = ["[te]NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    print(hemo.tokenizer.vocab_size)
    scores = hemo(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
tr2d2-pep/scoring/functions/nonfouling.py
ADDED
|
@@ -0,0 +1,66 @@
import sys
import os
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
import warnings
from rdkit import Chem, rdBase, DataStructs


rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Nonfouling:

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/nonfouling-xgboost.json')
        self.emb_model = emb_model if emb_model is not None else AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        # return the probability of the peptide being nonfouling
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    # NOTE: Nonfouling requires a tokenizer and base_path; supply them to run this test
    nf = Nonfouling()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]

    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
tr2d2-pep/scoring/functions/permeability.py
ADDED
|
@@ -0,0 +1,171 @@
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import xgboost as xgb
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import AutoModelForMaskedLM
|
| 7 |
+
import warnings
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit.Chem import Descriptors, rdMolDescriptors
|
| 10 |
+
from rdkit import Chem, rdBase, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
from typing import List
|
| 13 |
+
from transformers import AutoModelForMaskedLM
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
rdBase.DisableLog('rdApp.error')
|
| 17 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
| 18 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 20 |
+
|
| 21 |
+
def fingerprints_from_smiles(smiles: List, size=2048):
|
| 22 |
+
""" Create ECFP fingerprints of smiles, with validity check """
|
| 23 |
+
fps = []
|
| 24 |
+
valid_mask = []
|
| 25 |
+
for i, smile in enumerate(smiles):
|
| 26 |
+
mol = Chem.MolFromSmiles(smile)
|
| 27 |
+
valid_mask.append(int(mol is not None))
|
| 28 |
+
fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
|
| 29 |
+
fps.append(fp)
|
| 30 |
+
|
| 31 |
+
fps = np.concatenate(fps, axis=0)
|
| 32 |
+
return fps, valid_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
|
| 36 |
+
""" Create ECFP fingerprint of a molecule """
|
| 37 |
+
if hashed:
|
| 38 |
+
fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
|
| 39 |
+
else:
|
| 40 |
+
fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
|
| 41 |
+
fp_np = np.zeros((1,))
|
| 42 |
+
DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
|
| 43 |
+
return fp_np.reshape(1, -1)
|
| 44 |
+
|
| 45 |
+
def getMolDescriptors(mol, missingVal=0):
|
| 46 |
+
""" calculate the full list of descriptors for a molecule """
|
| 47 |
+
|
| 48 |
+
values, names = [], []
|
| 49 |
+
for nm, fn in Descriptors._descList:
|
| 50 |
+
try:
|
| 51 |
+
val = fn(mol)
|
| 52 |
+
except:
|
| 53 |
+
val = missingVal
|
| 54 |
+
values.append(val)
|
| 55 |
+
names.append(nm)
|
| 56 |
+
|
| 57 |
+
custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
|
| 58 |
+
'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
|
| 59 |
+
'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}
|
| 60 |
+
|
| 61 |
+
for nm, fn in custom_descriptors.items():
|
| 62 |
+
try:
|
| 63 |
+
val = fn(mol)
|
| 64 |
+
except:
|
| 65 |
+
val = missingVal
|
| 66 |
+
values.append(val)
|
| 67 |
+
names.append(nm)
|
| 68 |
+
return values, names
|
| 69 |
+
|
| 70 |
+
def get_pep_dps_from_smi(smi):
|
| 71 |
+
try:
|
| 72 |
+
mol = Chem.MolFromSmiles(smi)
|
| 73 |
+
except:
|
| 74 |
+
print(f"convert smi {smi} to molecule failed!")
|
| 75 |
+
mol = None
|
| 76 |
+
|
| 77 |
+
dps, _ = getMolDescriptors(mol)
|
| 78 |
+
return np.array(dps)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_pep_dps(smi_list):
|
| 82 |
+
if len(smi_list) == 0:
|
| 83 |
+
return np.zeros((0, 213))
|
| 84 |
+
return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])
|
| 85 |
+
|
| 86 |
+
def check_smi_validity(smiles: list):
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
            if mol:
                valid_smi.append(smi)
                valid_idx.append(idx)
        except Exception as e:
            # logger.debug(f'Error: {e} in smiles {smi}')
            pass
    return valid_smi, valid_idx


class Permeability:

    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/permeability-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            # use self.device so the default device is applied when device is None
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_features(self, input_seqs: list, dps=False, fps=False):
        # valid_smiles, valid_idxes = check_smi_validity(input_seqs)

        if fps:
            fingerprints = fingerprints_from_smiles(input_seqs)[0]
        else:
            fingerprints = torch.empty((len(input_seqs), 0))

        if dps:
            descriptors = get_pep_dps(input_seqs)
        else:
            descriptors = torch.empty((len(input_seqs), 0))

        embeddings = self.generate_embeddings(input_seqs)
        # logger.debug(f'X_fps.shape: {X_fps.shape}, X_dps.shape: {X_dps.shape}')

        features = np.concatenate([fingerprints, descriptors, embeddings], axis=1)

        return features

    def get_scores(self, input_seqs: list):
        scores = -10 * np.ones(len(input_seqs))
        features = self.get_features(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


def unittest():
    # NOTE: Permeability requires a tokenizer and base_path; supply them before running this smoke test.
    permeability = Permeability()
    seq = ['N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1cNc2c1cc(O)cc2)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCCNC(=N)N)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H]([C@@H](O)C(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O']
    scores = permeability(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()
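Usage sketch for the scorer above (illustrative only, not part of the committed file): the tokenizer construction mirrors scoring_functions.py further down, while base_path and the input SMILES are placeholder values.

    from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

    base_path = '/path/to/your/home'  # placeholder: directory containing the TR2-D2 checkout
    tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
                                     f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')

    permeability = Permeability(tokenizer=tokenizer, base_path=base_path)
    scores = permeability(['CC(=O)N[C@@H](C)C(=O)O'])  # one XGBoost score per input SMILES
    print(scores)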
tr2d2-pep/scoring/functions/scoring_utils.py
ADDED
@@ -0,0 +1,94 @@
import warnings
import numpy as np
from loguru import logger
from sklearn.ensemble import RandomForestRegressor
from rdkit.Chem import Descriptors, rdMolDescriptors
import joblib
from rdkit import Chem, rdBase, DataStructs
from rdkit.Chem import AllChem
from typing import List


def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """
    Create ECFP fingerprint of a molecule
    """
    if hashed:
        fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)


def fingerprints_from_smiles(smiles: List, size=2048):
    """ Create ECFP fingerprints of smiles, with validity check """
    fps = []
    valid_mask = []
    for i, smile in enumerate(smiles):
        mol = Chem.MolFromSmiles(smile)
        valid_mask.append(int(mol is not None))
        fp = fingerprints_from_mol(mol, size=size) if mol else np.zeros((1, size))
        fps.append(fp)

    fps = np.concatenate(fps, axis=0) if len(fps) > 0 else np.zeros((0, size))
    return fps, valid_mask


def getMolDescriptors(mol, missingVal=0):
    """ calculate the full list of descriptors for a molecule """

    values, names = [], []
    for nm, fn in Descriptors._descList:
        try:
            val = fn(mol)
        except Exception:
            val = missingVal
        values.append(val)
        names.append(nm)

    custom_descriptors = {'hydrogen-bond donors': rdMolDescriptors.CalcNumLipinskiHBD,
                          'hydrogen-bond acceptors': rdMolDescriptors.CalcNumLipinskiHBA,
                          'rotatable bonds': rdMolDescriptors.CalcNumRotatableBonds,}

    for nm, fn in custom_descriptors.items():
        try:
            val = fn(mol)
        except Exception:
            val = missingVal
        values.append(val)
        names.append(nm)
    return values, names


def get_pep_dps_from_smi(smi):
    try:
        mol = Chem.MolFromSmiles(smi)
    except Exception:
        print(f"convert smi {smi} to molecule failed!")
        mol = None

    dps, _ = getMolDescriptors(mol)
    return np.array(dps)


def get_pep_dps(smi_list):
    if len(smi_list) == 0:
        return np.zeros((0, 211))
    return np.array([get_pep_dps_from_smi(smi) for smi in smi_list])


def check_smi_validity(smiles: list):
    valid_smi, valid_idx = [], []
    for idx, smi in enumerate(smiles):
        try:
            mol = Chem.MolFromSmiles(smi) if smi else None
            if mol:
                valid_smi.append(smi)
                valid_idx.append(idx)
        except Exception as e:
            # logger.debug(f'Error: {e} in smiles {smi}')
            pass
    return valid_smi, valid_idx
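For reference, a small sketch of how these helpers behave (illustrative; the SMILES inputs are arbitrary and the second one is deliberately invalid):

    smiles = ['CCO', 'not_a_smiles']
    fps, valid_mask = fingerprints_from_smiles(smiles)
    print(fps.shape, valid_mask)   # (2, 2048) bit matrix; valid_mask == [1, 0], invalid rows are zero vectors

    dps = get_pep_dps(['CCO'])
    print(dps.shape)               # (1, N): N = RDKit descriptor count plus 3 Lipinski-style counts (211 assumed above)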
tr2d2-pep/scoring/functions/solubility.py
ADDED
@@ -0,0 +1,63 @@
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
import warnings
from rdkit import rdBase

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class Solubility:
    def __init__(self, tokenizer, base_path, device=None, emb_model=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.predictor = xgb.Booster(model_file=f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json')
        if emb_model is not None:
            self.emb_model = emb_model.to(self.device).eval()
        else:
            self.emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(self.device).eval()

        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt')
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores

def unittest():
    # NOTE: Solubility requires a tokenizer and base_path; supply them before running this smoke test.
    solubility = Solubility()
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)

if __name__ == '__main__':
    unittest()
tr2d2-pep/scoring/scoring_functions.py
ADDED
@@ -0,0 +1,77 @@
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from transformers import AutoModelForMaskedLM
import numpy as np
from scoring.functions.binding import BindingAffinity
from scoring.functions.permeability import Permeability
from scoring.functions.solubility import Solubility
from scoring.functions.hemolysis import Hemolysis
from scoring.functions.nonfouling import Nonfouling

base_path = '/path/to/your/home'

class ScoringFunctions:
    def __init__(self, score_func_names=None, prot_seqs=None, device=None):
        """
        Class for generating score vectors for generated sequences.

        Args:
            score_func_names: list of scoring function names to be evaluated
            score_weights: weights to scale scores (default: 1)
            target_protein: sequence of target protein binder
        """
        emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer.to(device).eval()
        tokenizer = SMILES_SPE_Tokenizer(f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_vocab.txt',
                                         f'{base_path}/TR2-D2/tr2d2-pep/tokenizer/new_splits.txt')
        prot_seqs = prot_seqs if prot_seqs is not None else []

        if score_func_names is None:
            # just do unmasking based on validity of peptide bonds
            self.score_func_names = []
        else:
            self.score_func_names = score_func_names

        # self.weights = np.array([1] * len(self.score_func_names) if score_weights is None else score_weights)

        # binding affinities
        self.target_protein = prot_seqs
        print(len(prot_seqs))

        # use self.score_func_names (never None) so the membership checks are safe
        if ('binding_affinity1' in self.score_func_names) and (len(prot_seqs) == 1):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = None
        elif ('binding_affinity1' in self.score_func_names) and ('binding_affinity2' in self.score_func_names) and (len(prot_seqs) == 2):
            binding_affinity1 = BindingAffinity(prot_seqs[0], tokenizer=tokenizer, base_path=base_path, device=device)
            binding_affinity2 = BindingAffinity(prot_seqs[1], tokenizer=tokenizer, base_path=base_path, device=device)
        else:
            binding_affinity1 = None
            binding_affinity2 = None

        permeability = Permeability(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        sol = Solubility(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        nonfouling = Nonfouling(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)
        hemo = Hemolysis(tokenizer=tokenizer, base_path=base_path, device=device, emb_model=emb_model)

        self.all_funcs = {'binding_affinity1': binding_affinity1,
                          'binding_affinity2': binding_affinity2,
                          'permeability': permeability,
                          'nonfouling': nonfouling,
                          'solubility': sol,
                          'hemolysis': hemo
                          }

    def forward(self, input_seqs):
        scores = []

        for i, score_func in enumerate(self.score_func_names):
            score = self.all_funcs[score_func](input_seqs=input_seqs)

            scores.append(score)

        # convert to numpy arrays with shape (num_sequences, num_functions)
        scores = np.float32(scores).T

        return scores

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
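Usage sketch for the combined scorer (illustrative only; base_path must first be edited as noted above, the classifier files and PeptideCLM weights must be available locally, and the peptide SMILES is a placeholder):

    funcs = ScoringFunctions(
        score_func_names=['permeability', 'solubility', 'hemolysis', 'nonfouling'],
        prot_seqs=[],      # no binding-affinity targets in this example
        device='cpu',
    )
    peptides = ['N[C@@H](C)C(=O)N[C@@H](CC(C)C)C(=O)O']
    scores = funcs(peptides)
    print(scores.shape)    # (num_sequences, num_functions) = (1, 4)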
tr2d2-pep/scoring/tokenizer/my_tokenizers.py
ADDED
@@ -0,0 +1,424 @@
| 1 |
+
import collections
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def load_vocab(vocab_file):
|
| 10 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 11 |
+
vocab = collections.OrderedDict()
|
| 12 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 13 |
+
tokens = reader.readlines()
|
| 14 |
+
for index, token in enumerate(tokens):
|
| 15 |
+
token = token.rstrip("\n")
|
| 16 |
+
vocab[token] = index
|
| 17 |
+
return vocab
|
| 18 |
+
|
| 19 |
+
class Atomwise_Tokenizer(object):
|
| 20 |
+
"""Run atom-level SMILES tokenization"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
""" Constructs a atom-level Tokenizer.
|
| 24 |
+
"""
|
| 25 |
+
# self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 26 |
+
self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 27 |
+
|
| 28 |
+
self.regex = re.compile(self.regex_pattern)
|
| 29 |
+
|
| 30 |
+
def tokenize(self, text):
|
| 31 |
+
""" Basic Tokenization of a SMILES.
|
| 32 |
+
"""
|
| 33 |
+
tokens = [token for token in self.regex.findall(text)]
|
| 34 |
+
return tokens
|
| 35 |
+
|
| 36 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
|
| 37 |
+
r"""
|
| 38 |
+
Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
|
| 39 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 40 |
+
should refer to the superclass for more information regarding methods.
|
| 41 |
+
Args:
|
| 42 |
+
vocab_file (:obj:`string`):
|
| 43 |
+
File containing the vocabulary.
|
| 44 |
+
spe_file (:obj:`string`):
|
| 45 |
+
File containing the trained SMILES Pair Encoding vocabulary.
|
| 46 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 47 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 48 |
+
token instead.
|
| 49 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 50 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 51 |
+
for sequence classification or for a text and a question for question answering.
|
| 52 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 53 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 54 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 55 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 56 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 57 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 58 |
+
special tokens.
|
| 59 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 60 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 61 |
+
modeling. This is the token which the model will try to predict.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, vocab_file, spe_file,
|
| 65 |
+
unk_token="[UNK]",
|
| 66 |
+
sep_token="[SEP]",
|
| 67 |
+
pad_token="[PAD]",
|
| 68 |
+
cls_token="[CLS]",
|
| 69 |
+
mask_token="[MASK]",
|
| 70 |
+
**kwargs):
|
| 71 |
+
if not os.path.isfile(vocab_file):
|
| 72 |
+
raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
|
| 73 |
+
if not os.path.isfile(spe_file):
|
| 74 |
+
raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
|
| 75 |
+
|
| 76 |
+
self.vocab = load_vocab(vocab_file)
|
| 77 |
+
self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
|
| 78 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 79 |
+
self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
|
| 80 |
+
|
| 81 |
+
super().__init__(
|
| 82 |
+
unk_token=unk_token,
|
| 83 |
+
sep_token=sep_token,
|
| 84 |
+
pad_token=pad_token,
|
| 85 |
+
cls_token=cls_token,
|
| 86 |
+
mask_token=mask_token,
|
| 87 |
+
**kwargs)
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def vocab_size(self):
|
| 91 |
+
return len(self.vocab)
|
| 92 |
+
|
| 93 |
+
def get_vocab(self):
|
| 94 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 95 |
+
|
| 96 |
+
def _tokenize(self, text):
|
| 97 |
+
return self.spe_tokenizer.tokenize(text).split(' ')
|
| 98 |
+
|
| 99 |
+
def _convert_token_to_id(self, token):
|
| 100 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 101 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 102 |
+
|
| 103 |
+
# changed encode and decode functions
|
| 104 |
+
def encode(self, token_array):
|
| 105 |
+
token_ids = []
|
| 106 |
+
token_ids.append(2)
|
| 107 |
+
for token in token_array:
|
| 108 |
+
id = self._convert_token_to_id(token)
|
| 109 |
+
token_ids.append(id)
|
| 110 |
+
token_ids.append(3)
|
| 111 |
+
token_ids = torch.tensor([token_ids])
|
| 112 |
+
attn_mask = torch.ones_like(token_ids)
|
| 113 |
+
return {'input_ids': token_ids, 'attention_mask': attn_mask}
|
| 114 |
+
|
| 115 |
+
def decode(self, token_ids, skip_special_tokens=True):
|
| 116 |
+
token_ids = token_ids.squeeze(0).cpu().tolist()
|
| 117 |
+
token_array = []
|
| 118 |
+
for idx in token_ids:
|
| 119 |
+
if idx == 3: # Stop decoding when token ID 3 is encountered
|
| 120 |
+
break
|
| 121 |
+
if skip_special_tokens and idx in self.all_special_ids:
|
| 122 |
+
continue
|
| 123 |
+
token = self._convert_id_to_token(idx)
|
| 124 |
+
token_array.append(token)
|
| 125 |
+
sequence = "".join(token_array)
|
| 126 |
+
return sequence
|
| 127 |
+
|
| 128 |
+
def batch_decode(self, batch_token_ids, skip_special_tokens=True):
|
| 129 |
+
sequences = []
|
| 130 |
+
for token_ids in batch_token_ids:
|
| 131 |
+
sequences.append(self.decode(token_ids))
|
| 132 |
+
return sequences
|
| 133 |
+
|
| 134 |
+
def get_token_split(self, token_ids):
|
| 135 |
+
if isinstance(token_ids, torch.Tensor):
|
| 136 |
+
token_ids = token_ids.cpu().tolist()
|
| 137 |
+
|
| 138 |
+
token_array = []
|
| 139 |
+
for seq_ids in token_ids:
|
| 140 |
+
seq_array = []
|
| 141 |
+
for id in seq_ids:
|
| 142 |
+
token = self._convert_id_to_token(id)
|
| 143 |
+
seq_array.append(token)
|
| 144 |
+
token_array.append(seq_array)
|
| 145 |
+
|
| 146 |
+
return token_array
|
| 147 |
+
|
| 148 |
+
def _convert_id_to_token(self, index):
|
| 149 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 150 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 151 |
+
|
| 152 |
+
def convert_tokens_to_string(self, tokens):
|
| 153 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 154 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 155 |
+
return out_string
|
| 156 |
+
|
| 157 |
+
def build_inputs_with_special_tokens(
|
| 158 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 159 |
+
) -> List[int]:
|
| 160 |
+
"""
|
| 161 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
| 162 |
+
by concatenating and adding special tokens.
|
| 163 |
+
A BERT sequence has the following format:
|
| 164 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 165 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 166 |
+
Args:
|
| 167 |
+
token_ids_0 (:obj:`List[int]`):
|
| 168 |
+
List of IDs to which the special tokens will be added
|
| 169 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 170 |
+
Optional second list of IDs for sequence pairs.
|
| 171 |
+
Returns:
|
| 172 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 173 |
+
"""
|
| 174 |
+
if token_ids_1 is None:
|
| 175 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 176 |
+
cls = [self.cls_token_id]
|
| 177 |
+
sep = [self.sep_token_id]
|
| 178 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 179 |
+
|
| 180 |
+
def get_special_tokens_mask(
|
| 181 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 182 |
+
) -> List[int]:
|
| 183 |
+
"""
|
| 184 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 185 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 186 |
+
Args:
|
| 187 |
+
token_ids_0 (:obj:`List[int]`):
|
| 188 |
+
List of ids.
|
| 189 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 190 |
+
Optional second list of IDs for sequence pairs.
|
| 191 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 192 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 193 |
+
Returns:
|
| 194 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
if already_has_special_tokens:
|
| 198 |
+
if token_ids_1 is not None:
|
| 199 |
+
raise ValueError(
|
| 200 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 201 |
+
"ids is already formated with special tokens for the model."
|
| 202 |
+
)
|
| 203 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 204 |
+
|
| 205 |
+
if token_ids_1 is not None:
|
| 206 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 207 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 208 |
+
|
| 209 |
+
def create_token_type_ids_from_sequences(
|
| 210 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 211 |
+
) -> List[int]:
|
| 212 |
+
"""
|
| 213 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 214 |
+
A BERT sequence pair mask has the following format:
|
| 215 |
+
::
|
| 216 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 217 |
+
| first sequence | second sequence |
|
| 218 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 219 |
+
Args:
|
| 220 |
+
token_ids_0 (:obj:`List[int]`):
|
| 221 |
+
List of ids.
|
| 222 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 223 |
+
Optional second list of IDs for sequence pairs.
|
| 224 |
+
Returns:
|
| 225 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 226 |
+
sequence(s).
|
| 227 |
+
"""
|
| 228 |
+
sep = [self.sep_token_id]
|
| 229 |
+
cls = [self.cls_token_id]
|
| 230 |
+
if token_ids_1 is None:
|
| 231 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 232 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 233 |
+
|
| 234 |
+
def save_vocabulary(self, vocab_path):
|
| 235 |
+
"""
|
| 236 |
+
Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
|
| 237 |
+
Args:
|
| 238 |
+
vocab_path (:obj:`str`):
|
| 239 |
+
The directory in which to save the vocabulary.
|
| 240 |
+
Returns:
|
| 241 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 242 |
+
"""
|
| 243 |
+
index = 0
|
| 244 |
+
vocab_file = vocab_path
|
| 245 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 246 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 247 |
+
if index != token_index:
|
| 248 |
+
index = token_index
|
| 249 |
+
writer.write(token + "\n")
|
| 250 |
+
index += 1
|
| 251 |
+
return (vocab_file,)
|
| 252 |
+
|
| 253 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
|
| 254 |
+
r"""
|
| 255 |
+
Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
|
| 256 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 257 |
+
should refer to the superclass for more information regarding methods.
|
| 258 |
+
Args:
|
| 259 |
+
vocab_file (:obj:`string`):
|
| 260 |
+
File containing the vocabulary.
|
| 261 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 262 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 263 |
+
token instead.
|
| 264 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 265 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 266 |
+
for sequence classification or for a text and a question for question answering.
|
| 267 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 268 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 269 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 270 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 271 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 272 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 273 |
+
special tokens.
|
| 274 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 275 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 276 |
+
modeling. This is the token which the model will try to predict.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(
|
| 280 |
+
self,
|
| 281 |
+
vocab_file,
|
| 282 |
+
unk_token="[UNK]",
|
| 283 |
+
sep_token="[SEP]",
|
| 284 |
+
pad_token="[PAD]",
|
| 285 |
+
cls_token="[CLS]",
|
| 286 |
+
mask_token="[MASK]",
|
| 287 |
+
**kwargs
|
| 288 |
+
):
|
| 289 |
+
super().__init__(
|
| 290 |
+
unk_token=unk_token,
|
| 291 |
+
sep_token=sep_token,
|
| 292 |
+
pad_token=pad_token,
|
| 293 |
+
cls_token=cls_token,
|
| 294 |
+
mask_token=mask_token,
|
| 295 |
+
**kwargs,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if not os.path.isfile(vocab_file):
|
| 299 |
+
raise ValueError(
|
| 300 |
+
"Can't find a vocabulary file at path '{}'.".format(vocab_file)
|
| 301 |
+
)
|
| 302 |
+
self.vocab = load_vocab(vocab_file)
|
| 303 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 304 |
+
self.tokenizer = Atomwise_Tokenizer()
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def vocab_size(self):
|
| 308 |
+
return len(self.vocab)
|
| 309 |
+
|
| 310 |
+
def get_vocab(self):
|
| 311 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _tokenize(self, text):
|
| 315 |
+
return self.tokenizer.tokenize(text)
|
| 316 |
+
|
| 317 |
+
def _convert_token_to_id(self, token):
|
| 318 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 319 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 320 |
+
|
| 321 |
+
def _convert_id_to_token(self, index):
|
| 322 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 323 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 324 |
+
|
| 325 |
+
def convert_tokens_to_string(self, tokens):
|
| 326 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 327 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 328 |
+
return out_string
|
| 329 |
+
|
| 330 |
+
def build_inputs_with_special_tokens(
|
| 331 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 332 |
+
) -> List[int]:
|
| 333 |
+
"""
|
| 334 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
| 335 |
+
by concatenating and adding special tokens.
|
| 336 |
+
A BERT sequence has the following format:
|
| 337 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 338 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 339 |
+
Args:
|
| 340 |
+
token_ids_0 (:obj:`List[int]`):
|
| 341 |
+
List of IDs to which the special tokens will be added
|
| 342 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 343 |
+
Optional second list of IDs for sequence pairs.
|
| 344 |
+
Returns:
|
| 345 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 346 |
+
"""
|
| 347 |
+
if token_ids_1 is None:
|
| 348 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 349 |
+
cls = [self.cls_token_id]
|
| 350 |
+
sep = [self.sep_token_id]
|
| 351 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 352 |
+
|
| 353 |
+
def get_special_tokens_mask(
|
| 354 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 355 |
+
) -> List[int]:
|
| 356 |
+
"""
|
| 357 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 358 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 359 |
+
Args:
|
| 360 |
+
token_ids_0 (:obj:`List[int]`):
|
| 361 |
+
List of ids.
|
| 362 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 363 |
+
Optional second list of IDs for sequence pairs.
|
| 364 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 365 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 366 |
+
Returns:
|
| 367 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
if already_has_special_tokens:
|
| 371 |
+
if token_ids_1 is not None:
|
| 372 |
+
raise ValueError(
|
| 373 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 374 |
+
"ids is already formated with special tokens for the model."
|
| 375 |
+
)
|
| 376 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 377 |
+
|
| 378 |
+
if token_ids_1 is not None:
|
| 379 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 380 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 381 |
+
|
| 382 |
+
def create_token_type_ids_from_sequences(
|
| 383 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 384 |
+
) -> List[int]:
|
| 385 |
+
"""
|
| 386 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 387 |
+
A BERT sequence pair mask has the following format:
|
| 388 |
+
::
|
| 389 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 390 |
+
| first sequence | second sequence |
|
| 391 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 392 |
+
Args:
|
| 393 |
+
token_ids_0 (:obj:`List[int]`):
|
| 394 |
+
List of ids.
|
| 395 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 396 |
+
Optional second list of IDs for sequence pairs.
|
| 397 |
+
Returns:
|
| 398 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 399 |
+
sequence(s).
|
| 400 |
+
"""
|
| 401 |
+
sep = [self.sep_token_id]
|
| 402 |
+
cls = [self.cls_token_id]
|
| 403 |
+
if token_ids_1 is None:
|
| 404 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 405 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 406 |
+
|
| 407 |
+
def save_vocabulary(self, vocab_path):
|
| 408 |
+
"""
|
| 409 |
+
Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
|
| 410 |
+
Args:
|
| 411 |
+
vocab_path (:obj:`str`):
|
| 412 |
+
The directory in which to save the vocabulary.
|
| 413 |
+
Returns:
|
| 414 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 415 |
+
"""
|
| 416 |
+
index = 0
|
| 417 |
+
vocab_file = vocab_path
|
| 418 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 419 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 420 |
+
if index != token_index:
|
| 421 |
+
index = token_index
|
| 422 |
+
writer.write(token + "\n")
|
| 423 |
+
index += 1
|
| 424 |
+
return (vocab_file,)
|
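Round-trip sketch for the SPE tokenizer defined above (illustrative; the vocab/split paths are placeholders and the token list is pre-split by hand):

    tok = SMILES_SPE_Tokenizer('new_vocab.txt', 'new_splits.txt')

    out = tok.encode(['N', 'C', 'C(=O)', 'O'])     # custom encode(): prepends id 2 ([CLS]) and appends id 3 ([SEP])
    print(out['input_ids'], out['attention_mask'])

    smiles = tok.decode(out['input_ids'])          # decode() stops at id 3 and skips special tokens
    print(smiles)                                  # 'NCC(=O)O'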
tr2d2-pep/scoring/tokenizer/new_splits.txt
ADDED
@@ -0,0 +1,159 @@
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
tr2d2-pep/scoring/tokenizer/new_vocab.txt
ADDED
@@ -0,0 +1,587 @@
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
cccc
|
| 96 |
+
ccn
|
| 97 |
+
cco
|
| 98 |
+
ccs
|
| 99 |
+
cn
|
| 100 |
+
cnc
|
| 101 |
+
cnn
|
| 102 |
+
co
|
| 103 |
+
coc
|
| 104 |
+
cs
|
| 105 |
+
csc
|
| 106 |
+
csn
|
| 107 |
+
e
|
| 108 |
+
g
|
| 109 |
+
i
|
| 110 |
+
l
|
| 111 |
+
n
|
| 112 |
+
nc
|
| 113 |
+
ncc
|
| 114 |
+
ncn
|
| 115 |
+
nco
|
| 116 |
+
ncs
|
| 117 |
+
nn
|
| 118 |
+
nnc
|
| 119 |
+
nnn
|
| 120 |
+
no
|
| 121 |
+
noc
|
| 122 |
+
non
|
| 123 |
+
ns
|
| 124 |
+
nsc
|
| 125 |
+
nsn
|
| 126 |
+
o
|
| 127 |
+
oc
|
| 128 |
+
occ
|
| 129 |
+
on
|
| 130 |
+
p
|
| 131 |
+
r
|
| 132 |
+
s
|
| 133 |
+
sc
|
| 134 |
+
scc
|
| 135 |
+
scn
|
| 136 |
+
sn
|
| 137 |
+
t
|
| 138 |
+
c1
|
| 139 |
+
c2
|
| 140 |
+
c3
|
| 141 |
+
c4
|
| 142 |
+
c5
|
| 143 |
+
c6
|
| 144 |
+
c7
|
| 145 |
+
c8
|
| 146 |
+
c9
|
| 147 |
+
n1
|
| 148 |
+
n2
|
| 149 |
+
n3
|
| 150 |
+
n4
|
| 151 |
+
n5
|
| 152 |
+
n6
|
| 153 |
+
n7
|
| 154 |
+
n8
|
| 155 |
+
n9
|
| 156 |
+
O1
|
| 157 |
+
O2
|
| 158 |
+
O3
|
| 159 |
+
O4
|
| 160 |
+
O5
|
| 161 |
+
O6
|
| 162 |
+
O7
|
| 163 |
+
O8
|
| 164 |
+
O9
|
| 165 |
+
(c1
|
| 166 |
+
(c2
|
| 167 |
+
c1)
|
| 168 |
+
c2)
|
| 169 |
+
(n1
|
| 170 |
+
(n2
|
| 171 |
+
n1)
|
| 172 |
+
n2)
|
| 173 |
+
(O1
|
| 174 |
+
(O2
|
| 175 |
+
O2)
|
| 176 |
+
=O
|
| 177 |
+
=C
|
| 178 |
+
=c
|
| 179 |
+
=N
|
| 180 |
+
=n
|
| 181 |
+
=CC
|
| 182 |
+
=CN
|
| 183 |
+
=Cc
|
| 184 |
+
=cc
|
| 185 |
+
=NC
|
| 186 |
+
=Nc
|
| 187 |
+
=nC
|
| 188 |
+
=nc
|
| 189 |
+
#C
|
| 190 |
+
#CC
|
| 191 |
+
#CN
|
| 192 |
+
#N
|
| 193 |
+
#NC
|
| 194 |
+
#NN
|
| 195 |
+
(C
|
| 196 |
+
C)
|
| 197 |
+
(O
|
| 198 |
+
O)
|
| 199 |
+
(N
|
| 200 |
+
N)
|
| 201 |
+
NP
|
| 202 |
+
PN
|
| 203 |
+
CP
|
| 204 |
+
PC
|
| 205 |
+
NS
|
| 206 |
+
SN
|
| 207 |
+
SP
|
| 208 |
+
PS
|
| 209 |
+
C(=O)
|
| 210 |
+
(/Br)
|
| 211 |
+
(/C#N)
|
| 212 |
+
(/C)
|
| 213 |
+
(/C=N)
|
| 214 |
+
(/C=O)
|
| 215 |
+
(/CBr)
|
| 216 |
+
(/CC)
|
| 217 |
+
(/CCC)
|
| 218 |
+
(/CCF)
|
| 219 |
+
(/CCN)
|
| 220 |
+
(/CCO)
|
| 221 |
+
(/CCl)
|
| 222 |
+
(/CI)
|
| 223 |
+
(/CN)
|
| 224 |
+
(/CO)
|
| 225 |
+
(/CS)
|
| 226 |
+
(/Cl)
|
| 227 |
+
(/F)
|
| 228 |
+
(/I)
|
| 229 |
+
(/N)
|
| 230 |
+
(/NC)
|
| 231 |
+
(/NCC)
|
| 232 |
+
(/NO)
|
| 233 |
+
(/O)
|
| 234 |
+
(/OC)
|
| 235 |
+
(/OCC)
|
| 236 |
+
(/S)
|
| 237 |
+
(/SC)
|
| 238 |
+
(=C)
|
| 239 |
+
(=C/C)
|
| 240 |
+
(=C/F)
|
| 241 |
+
(=C/I)
|
| 242 |
+
(=C/N)
|
| 243 |
+
(=C/O)
|
| 244 |
+
(=CBr)
|
| 245 |
+
(=CC)
|
| 246 |
+
(=CCF)
|
| 247 |
+
(=CCN)
|
| 248 |
+
(=CCO)
|
| 249 |
+
(=CCl)
|
| 250 |
+
(=CF)
|
| 251 |
+
(=CI)
|
| 252 |
+
(=CN)
|
| 253 |
+
(=CO)
|
| 254 |
+
(=C\\C)
|
| 255 |
+
(=C\\F)
|
| 256 |
+
(=C\\I)
|
| 257 |
+
(=C\\N)
|
| 258 |
+
(=C\\O)
|
| 259 |
+
(=N)
|
| 260 |
+
(=N/C)
|
| 261 |
+
(=N/N)
|
| 262 |
+
(=N/O)
|
| 263 |
+
(=NBr)
|
| 264 |
+
(=NC)
|
| 265 |
+
(=NCC)
|
| 266 |
+
(=NCl)
|
| 267 |
+
(=NN)
|
| 268 |
+
(=NO)
|
| 269 |
+
(=NOC)
|
| 270 |
+
(=N\\C)
|
| 271 |
+
(=N\\N)
|
| 272 |
+
(=N\\O)
|
| 273 |
+
(=O)
|
| 274 |
+
(=S)
|
| 275 |
+
(B)
|
| 276 |
+
(Br)
|
| 277 |
+
(C#C)
|
| 278 |
+
(C#CC)
|
| 279 |
+
(C#CI)
|
| 280 |
+
(C#CO)
|
| 281 |
+
(C#N)
|
| 282 |
+
(C#SN)
|
| 283 |
+
(C)
|
| 284 |
+
(C=C)
|
| 285 |
+
(C=CF)
|
| 286 |
+
(C=CI)
|
| 287 |
+
(C=N)
|
| 288 |
+
(C=NN)
|
| 289 |
+
(C=NO)
|
| 290 |
+
(C=O)
|
| 291 |
+
(C=S)
|
| 292 |
+
(CBr)
|
| 293 |
+
(CC#C)
|
| 294 |
+
(CC#N)
|
| 295 |
+
(CC)
|
| 296 |
+
(CC=C)
|
| 297 |
+
(CC=O)
|
| 298 |
+
(CCBr)
|
| 299 |
+
(CCC)
|
| 300 |
+
(CCCC)
|
| 301 |
+
(CCCF)
|
| 302 |
+
(CCCI)
|
| 303 |
+
(CCCN)
|
| 304 |
+
(CCCO)
|
| 305 |
+
(CCCS)
|
| 306 |
+
(CCCl)
|
| 307 |
+
(CCF)
|
| 308 |
+
(CCI)
|
| 309 |
+
(CCN)
|
| 310 |
+
(CCNC)
|
| 311 |
+
(CCNN)
|
| 312 |
+
(CCNO)
|
| 313 |
+
(CCO)
|
| 314 |
+
(CCOC)
|
| 315 |
+
(CCON)
|
| 316 |
+
(CCS)
|
| 317 |
+
(CCSC)
|
| 318 |
+
(CCl)
|
| 319 |
+
(CF)
|
| 320 |
+
(CI)
|
| 321 |
+
(CN)
|
| 322 |
+
(CN=O)
|
| 323 |
+
(CNC)
|
| 324 |
+
(CNCC)
|
| 325 |
+
(CNCO)
|
| 326 |
+
(CNN)
|
| 327 |
+
(CNNC)
|
| 328 |
+
(CNO)
|
| 329 |
+
(CNOC)
|
| 330 |
+
(CO)
|
| 331 |
+
(COC)
|
| 332 |
+
(COCC)
|
| 333 |
+
(COCI)
|
| 334 |
+
(COCN)
|
| 335 |
+
(COCO)
|
| 336 |
+
(COF)
|
| 337 |
+
(CON)
|
| 338 |
+
(COO)
|
| 339 |
+
(CS)
|
| 340 |
+
(CSC)
|
| 341 |
+
(CSCC)
|
| 342 |
+
(CSCF)
|
| 343 |
+
(CSO)
|
| 344 |
+
(Cl)
|
| 345 |
+
(F)
|
| 346 |
+
(I)
|
| 347 |
+
(N)
|
| 348 |
+
(N=N)
|
| 349 |
+
(N=NO)
|
| 350 |
+
(N=O)
|
| 351 |
+
(N=S)
|
| 352 |
+
(NBr)
|
| 353 |
+
(NC#N)
|
| 354 |
+
(NC)
|
| 355 |
+
(NC=N)
|
| 356 |
+
(NC=O)
|
| 357 |
+
(NC=S)
|
| 358 |
+
(NCBr)
|
| 359 |
+
(NCC)
|
| 360 |
+
(NCCC)
|
| 361 |
+
(NCCF)
|
| 362 |
+
(NCCN)
|
| 363 |
+
(NCCO)
|
| 364 |
+
(NCCS)
|
| 365 |
+
(NCCl)
|
| 366 |
+
(NCNC)
|
| 367 |
+
(NCO)
|
| 368 |
+
(NCS)
|
| 369 |
+
(NCl)
|
| 370 |
+
(NN)
|
| 371 |
+
(NN=O)
|
| 372 |
+
(NNC)
|
| 373 |
+
(NO)
|
| 374 |
+
(NOC)
|
| 375 |
+
(O)
|
| 376 |
+
(OC#N)
|
| 377 |
+
(OC)
|
| 378 |
+
(OC=C)
|
| 379 |
+
(OC=O)
|
| 380 |
+
(OC=S)
|
| 381 |
+
(OCBr)
|
| 382 |
+
(OCC)
|
| 383 |
+
(OCCC)
|
| 384 |
+
(OCCF)
|
| 385 |
+
(OCCI)
|
| 386 |
+
(OCCN)
|
| 387 |
+
(OCCO)
|
| 388 |
+
(OCCS)
|
| 389 |
+
(OCCl)
|
| 390 |
+
(OCF)
|
| 391 |
+
(OCI)
|
| 392 |
+
(OCO)
|
| 393 |
+
(OCOC)
|
| 394 |
+
(OCON)
|
| 395 |
+
(OCSC)
|
| 396 |
+
(OCl)
|
| 397 |
+
(OI)
|
| 398 |
+
(ON)
|
| 399 |
+
(OO)
|
| 400 |
+
(OOC)
|
| 401 |
+
(OOCC)
|
| 402 |
+
(OOSN)
|
| 403 |
+
(OSC)
|
| 404 |
+
(P)
|
| 405 |
+
(S)
|
| 406 |
+
(SC#N)
|
| 407 |
+
(SC)
|
| 408 |
+
(SCC)
|
| 409 |
+
(SCCC)
|
| 410 |
+
(SCCF)
|
| 411 |
+
(SCCN)
|
| 412 |
+
(SCCO)
|
| 413 |
+
(SCCS)
|
| 414 |
+
(SCCl)
|
| 415 |
+
(SCF)
|
| 416 |
+
(SCN)
|
| 417 |
+
(SCOC)
|
| 418 |
+
(SCSC)
|
| 419 |
+
(SCl)
|
| 420 |
+
(SI)
|
| 421 |
+
(SN)
|
| 422 |
+
(SN=O)
|
| 423 |
+
(SO)
|
| 424 |
+
(SOC)
|
| 425 |
+
(SOOO)
|
| 426 |
+
(SS)
|
| 427 |
+
(SSC)
|
| 428 |
+
(SSCC)
|
| 429 |
+
([At])
|
| 430 |
+
([O-])
|
| 431 |
+
([O])
|
| 432 |
+
([S-])
|
| 433 |
+
(\\Br)
|
| 434 |
+
(\\C#N)
|
| 435 |
+
(\\C)
|
| 436 |
+
(\\C=N)
|
| 437 |
+
(\\C=O)
|
| 438 |
+
(\\CBr)
|
| 439 |
+
(\\CC)
|
| 440 |
+
(\\CCC)
|
| 441 |
+
(\\CCO)
|
| 442 |
+
(\\CCl)
|
| 443 |
+
(\\CF)
|
| 444 |
+
(\\CN)
|
| 445 |
+
(\\CNC)
|
| 446 |
+
(\\CO)
|
| 447 |
+
(\\COC)
|
| 448 |
+
(\\Cl)
|
| 449 |
+
(\\F)
|
| 450 |
+
(\\I)
|
| 451 |
+
(\\N)
|
| 452 |
+
(\\NC)
|
| 453 |
+
(\\NCC)
|
| 454 |
+
(\\NN)
|
| 455 |
+
(\\NO)
|
| 456 |
+
(\\NOC)
|
| 457 |
+
(\\O)
|
| 458 |
+
(\\OC)
|
| 459 |
+
(\\OCC)
|
| 460 |
+
(\\ON)
|
| 461 |
+
(\\S)
|
| 462 |
+
(\\SC)
|
| 463 |
+
(\\SCC)
|
| 464 |
+
[Ag+]
|
| 465 |
+
[Ag-4]
|
| 466 |
+
[Ag]
|
| 467 |
+
[Al-3]
|
| 468 |
+
[Al]
|
| 469 |
+
[As+]
|
| 470 |
+
[AsH3]
|
| 471 |
+
[AsH]
|
| 472 |
+
[As]
|
| 473 |
+
[At]
|
| 474 |
+
[B-]
|
| 475 |
+
[B@-]
|
| 476 |
+
[B@@-]
|
| 477 |
+
[BH-]
|
| 478 |
+
[BH2-]
|
| 479 |
+
[BH3-]
|
| 480 |
+
[B]
|
| 481 |
+
[Ba]
|
| 482 |
+
[Br+2]
|
| 483 |
+
[BrH]
|
| 484 |
+
[Br]
|
| 485 |
+
[C+]
|
| 486 |
+
[C-]
|
| 487 |
+
[C@@H]
|
| 488 |
+
[C@@]
|
| 489 |
+
[C@H]
|
| 490 |
+
[C@]
|
| 491 |
+
[CH-]
|
| 492 |
+
[CH2]
|
| 493 |
+
[CH3]
|
| 494 |
+
[CH]
|
| 495 |
+
[C]
|
| 496 |
+
[CaH2]
|
| 497 |
+
[Ca]
|
| 498 |
+
[Cl+2]
|
| 499 |
+
[Cl+3]
|
| 500 |
+
[Cl+]
|
| 501 |
+
[Cs]
|
| 502 |
+
[FH]
|
| 503 |
+
[F]
|
| 504 |
+
[H]
|
| 505 |
+
[He]
|
| 506 |
+
[I+2]
|
| 507 |
+
[I+3]
|
| 508 |
+
[I+]
|
| 509 |
+
[IH]
|
| 510 |
+
[I]
|
| 511 |
+
[K]
|
| 512 |
+
[Kr]
|
| 513 |
+
[Li+]
|
| 514 |
+
[LiH]
|
| 515 |
+
[MgH2]
|
| 516 |
+
[Mg]
|
| 517 |
+
[N+]
|
| 518 |
+
[N-]
|
| 519 |
+
[N@+]
|
| 520 |
+
[N@@+]
|
| 521 |
+
[N@@]
|
| 522 |
+
[N@]
|
| 523 |
+
[NH+]
|
| 524 |
+
[NH-]
|
| 525 |
+
[NH2+]
|
| 526 |
+
[NH3]
|
| 527 |
+
[NH]
|
| 528 |
+
[N]
|
| 529 |
+
[Na]
|
| 530 |
+
[O+]
|
| 531 |
+
[O-]
|
| 532 |
+
[OH+]
|
| 533 |
+
[OH2]
|
| 534 |
+
[OH]
|
| 535 |
+
[O]
|
| 536 |
+
[P+]
|
| 537 |
+
[P@+]
|
| 538 |
+
[P@@+]
|
| 539 |
+
[P@@]
|
| 540 |
+
[P@]
|
| 541 |
+
[PH2]
|
| 542 |
+
[PH]
|
| 543 |
+
[P]
|
| 544 |
+
[Ra]
|
| 545 |
+
[Rb]
|
| 546 |
+
[S+]
|
| 547 |
+
[S-]
|
| 548 |
+
[S@+]
|
| 549 |
+
[S@@+]
|
| 550 |
+
[S@@]
|
| 551 |
+
[S@]
|
| 552 |
+
[SH+]
|
| 553 |
+
[SH2]
|
| 554 |
+
[SH]
|
| 555 |
+
[S]
|
| 556 |
+
[Se+]
|
| 557 |
+
[Se-2]
|
| 558 |
+
[SeH2]
|
| 559 |
+
[SeH]
|
| 560 |
+
[Se]
|
| 561 |
+
[Si@]
|
| 562 |
+
[SiH2]
|
| 563 |
+
[SiH]
|
| 564 |
+
[Si]
|
| 565 |
+
[SrH2]
|
| 566 |
+
[TeH]
|
| 567 |
+
[Te]
|
| 568 |
+
[Xe]
|
| 569 |
+
[Zn+2]
|
| 570 |
+
[Zn-2]
|
| 571 |
+
[Zn]
|
| 572 |
+
[b-]
|
| 573 |
+
[c+]
|
| 574 |
+
[c-]
|
| 575 |
+
[cH-]
|
| 576 |
+
[cH]
|
| 577 |
+
[c]
|
| 578 |
+
[n+]
|
| 579 |
+
[n-]
|
| 580 |
+
[nH]
|
| 581 |
+
[n]
|
| 582 |
+
[o+]
|
| 583 |
+
[s+]
|
| 584 |
+
[se+]
|
| 585 |
+
[se]
|
| 586 |
+
[te+]
|
| 587 |
+
[te]
|
tr2d2-pep/tokenizer/my_tokenizers.py
ADDED
@@ -0,0 +1,424 @@
| 1 |
+
import collections
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from transformers import PreTrainedTokenizer
|
| 6 |
+
from SmilesPE.tokenizer import SPE_Tokenizer
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def load_vocab(vocab_file):
|
| 10 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 11 |
+
vocab = collections.OrderedDict()
|
| 12 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 13 |
+
tokens = reader.readlines()
|
| 14 |
+
for index, token in enumerate(tokens):
|
| 15 |
+
token = token.rstrip("\n")
|
| 16 |
+
vocab[token] = index
|
| 17 |
+
return vocab
|
| 18 |
+
|
| 19 |
+
class Atomwise_Tokenizer(object):
|
| 20 |
+
"""Run atom-level SMILES tokenization"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
""" Constructs a atom-level Tokenizer.
|
| 24 |
+
"""
|
| 25 |
+
# self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 26 |
+
self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
|
| 27 |
+
|
| 28 |
+
self.regex = re.compile(self.regex_pattern)
|
| 29 |
+
|
| 30 |
+
def tokenize(self, text):
|
| 31 |
+
""" Basic Tokenization of a SMILES.
|
| 32 |
+
"""
|
| 33 |
+
tokens = [token for token in self.regex.findall(text)]
|
| 34 |
+
return tokens
|
| 35 |
+
|
| 36 |
+
class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
|
| 37 |
+
r"""
|
| 38 |
+
Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
|
| 39 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 40 |
+
should refer to the superclass for more information regarding methods.
|
| 41 |
+
Args:
|
| 42 |
+
vocab_file (:obj:`string`):
|
| 43 |
+
File containing the vocabulary.
|
| 44 |
+
spe_file (:obj:`string`):
|
| 45 |
+
File containing the trained SMILES Pair Encoding vocabulary.
|
| 46 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 47 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 48 |
+
token instead.
|
| 49 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 50 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 51 |
+
for sequence classification or for a text and a question for question answering.
|
| 52 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 53 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 54 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 55 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 56 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 57 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 58 |
+
special tokens.
|
| 59 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 60 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 61 |
+
modeling. This is the token which the model will try to predict.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, vocab_file, spe_file,
|
| 65 |
+
unk_token="[UNK]",
|
| 66 |
+
sep_token="[SEP]",
|
| 67 |
+
pad_token="[PAD]",
|
| 68 |
+
cls_token="[CLS]",
|
| 69 |
+
mask_token="[MASK]",
|
| 70 |
+
**kwargs):
|
| 71 |
+
if not os.path.isfile(vocab_file):
|
| 72 |
+
raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
|
| 73 |
+
if not os.path.isfile(spe_file):
|
| 74 |
+
raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
|
| 75 |
+
|
| 76 |
+
self.vocab = load_vocab(vocab_file)
|
| 77 |
+
self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
|
| 78 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 79 |
+
self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
|
| 80 |
+
|
| 81 |
+
super().__init__(
|
| 82 |
+
unk_token=unk_token,
|
| 83 |
+
sep_token=sep_token,
|
| 84 |
+
pad_token=pad_token,
|
| 85 |
+
cls_token=cls_token,
|
| 86 |
+
mask_token=mask_token,
|
| 87 |
+
**kwargs)
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def vocab_size(self):
|
| 91 |
+
return len(self.vocab)
|
| 92 |
+
|
| 93 |
+
def get_vocab(self):
|
| 94 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 95 |
+
|
| 96 |
+
def _tokenize(self, text):
|
| 97 |
+
return self.spe_tokenizer.tokenize(text).split(' ')
|
| 98 |
+
|
| 99 |
+
def _convert_token_to_id(self, token):
|
| 100 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 101 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 102 |
+
|
| 103 |
+
# changed encode and decode functions
|
| 104 |
+
def encode(self, token_array):
|
| 105 |
+
token_ids = []
|
| 106 |
+
token_ids.append(2)  # [CLS] id (index 2 in new_vocab.txt)
|
| 107 |
+
for token in token_array:
|
| 108 |
+
id = self._convert_token_to_id(token)
|
| 109 |
+
token_ids.append(id)
|
| 110 |
+
token_ids.append(3)  # [SEP] id (index 3 in new_vocab.txt)
|
| 111 |
+
token_ids = torch.tensor([token_ids])
|
| 112 |
+
attn_mask = torch.ones_like(token_ids)
|
| 113 |
+
return {'input_ids': token_ids, 'attention_mask': attn_mask}
|
| 114 |
+
|
| 115 |
+
def decode(self, token_ids, skip_special_tokens=True):
|
| 116 |
+
token_ids = token_ids.squeeze(0).cpu().tolist()
|
| 117 |
+
token_array = []
|
| 118 |
+
for idx in token_ids:
|
| 119 |
+
if idx == 3: # Stop decoding when token ID 3 is encountered
|
| 120 |
+
break
|
| 121 |
+
if skip_special_tokens and idx in self.all_special_ids:
|
| 122 |
+
continue
|
| 123 |
+
token = self._convert_id_to_token(idx)
|
| 124 |
+
token_array.append(token)
|
| 125 |
+
sequence = "".join(token_array)
|
| 126 |
+
return sequence
|
| 127 |
+
|
| 128 |
+
def batch_decode(self, batch_token_ids, skip_special_tokens=True):
|
| 129 |
+
sequences = []
|
| 130 |
+
for token_ids in batch_token_ids:
|
| 131 |
+
sequences.append(self.decode(token_ids))
|
| 132 |
+
return sequences
|
| 133 |
+
|
| 134 |
+
def get_token_split(self, token_ids):
|
| 135 |
+
if isinstance(token_ids, torch.Tensor):
|
| 136 |
+
token_ids = token_ids.cpu().tolist()
|
| 137 |
+
|
| 138 |
+
token_array = []
|
| 139 |
+
for seq_ids in token_ids:
|
| 140 |
+
seq_array = []
|
| 141 |
+
for id in seq_ids:
|
| 142 |
+
token = self._convert_id_to_token(id)
|
| 143 |
+
seq_array.append(token)
|
| 144 |
+
token_array.append(seq_array)
|
| 145 |
+
|
| 146 |
+
return token_array
|
| 147 |
+
|
| 148 |
+
def _convert_id_to_token(self, index):
|
| 149 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 150 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 151 |
+
|
| 152 |
+
def convert_tokens_to_string(self, tokens):
|
| 153 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 154 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 155 |
+
return out_string
|
| 156 |
+
|
| 157 |
+
def build_inputs_with_special_tokens(
|
| 158 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 159 |
+
) -> List[int]:
|
| 160 |
+
"""
|
| 161 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
|
| 162 |
+
by concatenating and adding special tokens.
|
| 163 |
+
A BERT sequence has the following format:
|
| 164 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 165 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 166 |
+
Args:
|
| 167 |
+
token_ids_0 (:obj:`List[int]`):
|
| 168 |
+
List of IDs to which the special tokens will be added
|
| 169 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 170 |
+
Optional second list of IDs for sequence pairs.
|
| 171 |
+
Returns:
|
| 172 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 173 |
+
"""
|
| 174 |
+
if token_ids_1 is None:
|
| 175 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 176 |
+
cls = [self.cls_token_id]
|
| 177 |
+
sep = [self.sep_token_id]
|
| 178 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 179 |
+
|
| 180 |
+
def get_special_tokens_mask(
|
| 181 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 182 |
+
) -> List[int]:
|
| 183 |
+
"""
|
| 184 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 185 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 186 |
+
Args:
|
| 187 |
+
token_ids_0 (:obj:`List[int]`):
|
| 188 |
+
List of ids.
|
| 189 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 190 |
+
Optional second list of IDs for sequence pairs.
|
| 191 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 192 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 193 |
+
Returns:
|
| 194 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
if already_has_special_tokens:
|
| 198 |
+
if token_ids_1 is not None:
|
| 199 |
+
raise ValueError(
|
| 200 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 201 |
+
"ids is already formated with special tokens for the model."
|
| 202 |
+
)
|
| 203 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 204 |
+
|
| 205 |
+
if token_ids_1 is not None:
|
| 206 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 207 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 208 |
+
|
| 209 |
+
def create_token_type_ids_from_sequences(
|
| 210 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 211 |
+
) -> List[int]:
|
| 212 |
+
"""
|
| 213 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 214 |
+
A BERT sequence pair mask has the following format:
|
| 215 |
+
::
|
| 216 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 217 |
+
| first sequence | second sequence |
|
| 218 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 219 |
+
Args:
|
| 220 |
+
token_ids_0 (:obj:`List[int]`):
|
| 221 |
+
List of ids.
|
| 222 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 223 |
+
Optional second list of IDs for sequence pairs.
|
| 224 |
+
Returns:
|
| 225 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 226 |
+
sequence(s).
|
| 227 |
+
"""
|
| 228 |
+
sep = [self.sep_token_id]
|
| 229 |
+
cls = [self.cls_token_id]
|
| 230 |
+
if token_ids_1 is None:
|
| 231 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 232 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 233 |
+
|
| 234 |
+
def save_vocabulary(self, vocab_path):
|
| 235 |
+
"""
|
| 236 |
+
Save the vocabulary to a file, writing one token per line in id order.
|
| 237 |
+
Args:
|
| 238 |
+
vocab_path (:obj:`str`):
|
| 239 |
+
The file in which to save the vocabulary.
|
| 240 |
+
Returns:
|
| 241 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 242 |
+
"""
|
| 243 |
+
index = 0
|
| 244 |
+
vocab_file = vocab_path
|
| 245 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 246 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 247 |
+
if index != token_index:
|
| 248 |
+
index = token_index
|
| 249 |
+
writer.write(token + "\n")
|
| 250 |
+
index += 1
|
| 251 |
+
return (vocab_file,)
|
| 252 |
+
|
| 253 |
+
class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
|
| 254 |
+
r"""
|
| 255 |
+
Constructs an atom-level SMILES tokenizer. Adapted from the SmilesPE project (https://github.com/XinhaoLi74/SmilesPE).
|
| 256 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 257 |
+
should refer to the superclass for more information regarding methods.
|
| 258 |
+
Args:
|
| 259 |
+
vocab_file (:obj:`string`):
|
| 260 |
+
File containing the vocabulary.
|
| 261 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 262 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 263 |
+
token instead.
|
| 264 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 265 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 266 |
+
for sequence classification or for a text and a question for question answering.
|
| 267 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 268 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 269 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 270 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 271 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 272 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 273 |
+
special tokens.
|
| 274 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 275 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 276 |
+
modeling. This is the token which the model will try to predict.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(
|
| 280 |
+
self,
|
| 281 |
+
vocab_file,
|
| 282 |
+
unk_token="[UNK]",
|
| 283 |
+
sep_token="[SEP]",
|
| 284 |
+
pad_token="[PAD]",
|
| 285 |
+
cls_token="[CLS]",
|
| 286 |
+
mask_token="[MASK]",
|
| 287 |
+
**kwargs
|
| 288 |
+
):
|
| 289 |
+
super().__init__(
|
| 290 |
+
unk_token=unk_token,
|
| 291 |
+
sep_token=sep_token,
|
| 292 |
+
pad_token=pad_token,
|
| 293 |
+
cls_token=cls_token,
|
| 294 |
+
mask_token=mask_token,
|
| 295 |
+
**kwargs,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if not os.path.isfile(vocab_file):
|
| 299 |
+
raise ValueError(
|
| 300 |
+
"Can't find a vocabulary file at path '{}'.".format(vocab_file)
|
| 301 |
+
)
|
| 302 |
+
self.vocab = load_vocab(vocab_file)
|
| 303 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 304 |
+
self.tokenizer = Atomwise_Tokenizer()
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def vocab_size(self):
|
| 308 |
+
return len(self.vocab)
|
| 309 |
+
|
| 310 |
+
def get_vocab(self):
|
| 311 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _tokenize(self, text):
|
| 315 |
+
return self.tokenizer.tokenize(text)
|
| 316 |
+
|
| 317 |
+
def _convert_token_to_id(self, token):
|
| 318 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 319 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 320 |
+
|
| 321 |
+
def _convert_id_to_token(self, index):
|
| 322 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 323 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 324 |
+
|
| 325 |
+
def convert_tokens_to_string(self, tokens):
|
| 326 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 327 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 328 |
+
return out_string
|
| 329 |
+
|
| 330 |
+
def build_inputs_with_special_tokens(
|
| 331 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 332 |
+
) -> List[int]:
|
| 333 |
+
"""
|
| 334 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
|
| 335 |
+
by concatenating and adding special tokens.
|
| 336 |
+
A BERT sequence has the following format:
|
| 337 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 338 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 339 |
+
Args:
|
| 340 |
+
token_ids_0 (:obj:`List[int]`):
|
| 341 |
+
List of IDs to which the special tokens will be added
|
| 342 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 343 |
+
Optional second list of IDs for sequence pairs.
|
| 344 |
+
Returns:
|
| 345 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 346 |
+
"""
|
| 347 |
+
if token_ids_1 is None:
|
| 348 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 349 |
+
cls = [self.cls_token_id]
|
| 350 |
+
sep = [self.sep_token_id]
|
| 351 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 352 |
+
|
| 353 |
+
def get_special_tokens_mask(
|
| 354 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 355 |
+
) -> List[int]:
|
| 356 |
+
"""
|
| 357 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 358 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 359 |
+
Args:
|
| 360 |
+
token_ids_0 (:obj:`List[int]`):
|
| 361 |
+
List of ids.
|
| 362 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 363 |
+
Optional second list of IDs for sequence pairs.
|
| 364 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 365 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 366 |
+
Returns:
|
| 367 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
if already_has_special_tokens:
|
| 371 |
+
if token_ids_1 is not None:
|
| 372 |
+
raise ValueError(
|
| 373 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 374 |
+
"ids is already formated with special tokens for the model."
|
| 375 |
+
)
|
| 376 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 377 |
+
|
| 378 |
+
if token_ids_1 is not None:
|
| 379 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 380 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 381 |
+
|
| 382 |
+
def create_token_type_ids_from_sequences(
|
| 383 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 384 |
+
) -> List[int]:
|
| 385 |
+
"""
|
| 386 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 387 |
+
A BERT sequence pair mask has the following format:
|
| 388 |
+
::
|
| 389 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 390 |
+
| first sequence | second sequence |
|
| 391 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 392 |
+
Args:
|
| 393 |
+
token_ids_0 (:obj:`List[int]`):
|
| 394 |
+
List of ids.
|
| 395 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 396 |
+
Optional second list of IDs for sequence pairs.
|
| 397 |
+
Returns:
|
| 398 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 399 |
+
sequence(s).
|
| 400 |
+
"""
|
| 401 |
+
sep = [self.sep_token_id]
|
| 402 |
+
cls = [self.cls_token_id]
|
| 403 |
+
if token_ids_1 is None:
|
| 404 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 405 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 406 |
+
|
| 407 |
+
def save_vocabulary(self, vocab_path):
|
| 408 |
+
"""
|
| 409 |
+
Save the vocabulary to a file, writing one token per line in id order.
|
| 410 |
+
Args:
|
| 411 |
+
vocab_path (:obj:`str`):
|
| 412 |
+
The file in which to save the vocabulary.
|
| 413 |
+
Returns:
|
| 414 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 415 |
+
"""
|
| 416 |
+
index = 0
|
| 417 |
+
vocab_file = vocab_path
|
| 418 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 419 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 420 |
+
if index != token_index:
|
| 421 |
+
index = token_index
|
| 422 |
+
writer.write(token + "\n")
|
| 423 |
+
index += 1
|
| 424 |
+
return (vocab_file,)
|
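
A minimal usage sketch of `SMILES_SPE_Tokenizer` with the `new_splits.txt` / `new_vocab.txt` files added below. This is illustrative only: the import path assumes the script runs from inside `tr2d2-pep/`, `SmilesPE`, `transformers`, and `torch` must be installed, and the SMILES string is an example, not from the repo.

```python
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer   # illustrative import path

tokenizer = SMILES_SPE_Tokenizer(
    vocab_file="tokenizer/new_vocab.txt",
    spe_file="tokenizer/new_splits.txt",
)

smiles = "CC(=O)N[C@@H](Cc1ccccc1)C(=O)O"   # example input
tokens = tokenizer._tokenize(smiles)         # SPE tokens (list of sub-SMILES strings)
enc = tokenizer.encode(tokens)               # prepends [CLS] (id 2), appends [SEP] (id 3)
print(enc["input_ids"].shape)                # torch.Size([1, len(tokens) + 2])
print(tokenizer.decode(enc["input_ids"]))    # re-joins tokens; any [UNK]s are skipped
```
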
tr2d2-pep/tokenizer/new_splits.txt
ADDED
|
@@ -0,0 +1,159 @@
|
| 1 |
+
c 1
|
| 2 |
+
c 2
|
| 3 |
+
c 3
|
| 4 |
+
c 4
|
| 5 |
+
c 5
|
| 6 |
+
c 6
|
| 7 |
+
c 7
|
| 8 |
+
c 8
|
| 9 |
+
c 9
|
| 10 |
+
( c1
|
| 11 |
+
( c2
|
| 12 |
+
c1 )
|
| 13 |
+
c2 )
|
| 14 |
+
n 1
|
| 15 |
+
n 2
|
| 16 |
+
n 3
|
| 17 |
+
n 4
|
| 18 |
+
n 5
|
| 19 |
+
n 6
|
| 20 |
+
n 7
|
| 21 |
+
n 8
|
| 22 |
+
n 9
|
| 23 |
+
( n1
|
| 24 |
+
( n2
|
| 25 |
+
n1 )
|
| 26 |
+
n2 )
|
| 27 |
+
O 1
|
| 28 |
+
O 2
|
| 29 |
+
O 3
|
| 30 |
+
O 4
|
| 31 |
+
O 5
|
| 32 |
+
O 6
|
| 33 |
+
O 7
|
| 34 |
+
O 8
|
| 35 |
+
O 9
|
| 36 |
+
( O1
|
| 37 |
+
( O2
|
| 38 |
+
O2 )
|
| 39 |
+
O2 )
|
| 40 |
+
= O
|
| 41 |
+
= C
|
| 42 |
+
= c
|
| 43 |
+
= N
|
| 44 |
+
= n
|
| 45 |
+
=C C
|
| 46 |
+
=C N
|
| 47 |
+
=C c
|
| 48 |
+
=c c
|
| 49 |
+
=N C
|
| 50 |
+
=N c
|
| 51 |
+
=n C
|
| 52 |
+
=n c
|
| 53 |
+
# N
|
| 54 |
+
# C
|
| 55 |
+
#N C
|
| 56 |
+
#C C
|
| 57 |
+
#C N
|
| 58 |
+
#N N
|
| 59 |
+
( C
|
| 60 |
+
C )
|
| 61 |
+
( O
|
| 62 |
+
O )
|
| 63 |
+
( N
|
| 64 |
+
N )
|
| 65 |
+
Br c
|
| 66 |
+
( =O
|
| 67 |
+
(=O )
|
| 68 |
+
C (=O)
|
| 69 |
+
C =O
|
| 70 |
+
C =N
|
| 71 |
+
C #N
|
| 72 |
+
C #C
|
| 73 |
+
C C
|
| 74 |
+
CC C
|
| 75 |
+
CC N
|
| 76 |
+
CC O
|
| 77 |
+
CC S
|
| 78 |
+
CC c
|
| 79 |
+
CC n
|
| 80 |
+
C N
|
| 81 |
+
CN C
|
| 82 |
+
CN c
|
| 83 |
+
C O
|
| 84 |
+
CO C
|
| 85 |
+
CO N
|
| 86 |
+
CO c
|
| 87 |
+
C S
|
| 88 |
+
CS C
|
| 89 |
+
CS S
|
| 90 |
+
CS c
|
| 91 |
+
C c
|
| 92 |
+
Cl c
|
| 93 |
+
C n
|
| 94 |
+
F c
|
| 95 |
+
N C
|
| 96 |
+
NC C
|
| 97 |
+
NC c
|
| 98 |
+
N N
|
| 99 |
+
N O
|
| 100 |
+
N c
|
| 101 |
+
N n
|
| 102 |
+
O C
|
| 103 |
+
OC C
|
| 104 |
+
OC O
|
| 105 |
+
OC c
|
| 106 |
+
O N
|
| 107 |
+
O O
|
| 108 |
+
O c
|
| 109 |
+
S C
|
| 110 |
+
SC C
|
| 111 |
+
SC c
|
| 112 |
+
S S
|
| 113 |
+
S c
|
| 114 |
+
c c
|
| 115 |
+
cc c
|
| 116 |
+
cc n
|
| 117 |
+
cc o
|
| 118 |
+
cc s
|
| 119 |
+
cc cc
|
| 120 |
+
c n
|
| 121 |
+
cn c
|
| 122 |
+
cn n
|
| 123 |
+
c o
|
| 124 |
+
co c
|
| 125 |
+
c s
|
| 126 |
+
cs c
|
| 127 |
+
cs n
|
| 128 |
+
n c
|
| 129 |
+
nc c
|
| 130 |
+
nc n
|
| 131 |
+
nc o
|
| 132 |
+
nc s
|
| 133 |
+
n n
|
| 134 |
+
nn c
|
| 135 |
+
nn n
|
| 136 |
+
n o
|
| 137 |
+
no c
|
| 138 |
+
no n
|
| 139 |
+
n s
|
| 140 |
+
ns c
|
| 141 |
+
ns n
|
| 142 |
+
o c
|
| 143 |
+
oc c
|
| 144 |
+
o n
|
| 145 |
+
s c
|
| 146 |
+
sc c
|
| 147 |
+
sc n
|
| 148 |
+
s n
|
| 149 |
+
N P
|
| 150 |
+
P N
|
| 151 |
+
C P
|
| 152 |
+
P C
|
| 153 |
+
N S
|
| 154 |
+
S N
|
| 155 |
+
C S
|
| 156 |
+
S C
|
| 157 |
+
S P
|
| 158 |
+
P S
|
| 159 |
+
C I
|
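
Each line of `new_splits.txt` is a SMILES Pair Encoding merge rule: two adjacent tokens separated by a space that get fused into one vocabulary token. The sketch below only illustrates how such rules act on an atom-level tokenization; it is not the SmilesPE implementation, and the path and example tokens are assumptions.

```python
def apply_merges(tokens, splits_path="tokenizer/new_splits.txt"):
    """Greedy, single-pass illustration of SPE merge rules (not the real SmilesPE logic)."""
    with open(splits_path, encoding="utf-8") as f:
        merges = [tuple(line.rstrip("\n").split(" ", 1)) for line in f if line.strip()]
    for a, b in merges:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == a and tokens[i + 1] == b:
                tokens[i:i + 2] = [a + b]    # e.g. "C" + "(=O)" -> "C(=O)"
            else:
                i += 1
    return tokens

print(apply_merges(["N", "C", "(=O)", "C", "C"]))   # -> ['N', 'C(=O)', 'CC'] with these rules
```
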
tr2d2-pep/tokenizer/new_vocab.txt
ADDED
|
@@ -0,0 +1,587 @@
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[SEP]
|
| 5 |
+
[MASK]
|
| 6 |
+
#
|
| 7 |
+
%
|
| 8 |
+
(
|
| 9 |
+
)
|
| 10 |
+
+
|
| 11 |
+
-
|
| 12 |
+
/
|
| 13 |
+
0
|
| 14 |
+
1
|
| 15 |
+
2
|
| 16 |
+
3
|
| 17 |
+
4
|
| 18 |
+
5
|
| 19 |
+
6
|
| 20 |
+
7
|
| 21 |
+
8
|
| 22 |
+
9
|
| 23 |
+
=
|
| 24 |
+
@
|
| 25 |
+
A
|
| 26 |
+
B
|
| 27 |
+
Br
|
| 28 |
+
Brc
|
| 29 |
+
C
|
| 30 |
+
CC
|
| 31 |
+
CCC
|
| 32 |
+
CCN
|
| 33 |
+
CCO
|
| 34 |
+
CCS
|
| 35 |
+
CCc
|
| 36 |
+
CCn
|
| 37 |
+
CN
|
| 38 |
+
CNC
|
| 39 |
+
CNc
|
| 40 |
+
CO
|
| 41 |
+
COC
|
| 42 |
+
CON
|
| 43 |
+
COc
|
| 44 |
+
CS
|
| 45 |
+
CSC
|
| 46 |
+
CSS
|
| 47 |
+
CSc
|
| 48 |
+
Cc
|
| 49 |
+
Cl
|
| 50 |
+
Clc
|
| 51 |
+
Cn
|
| 52 |
+
F
|
| 53 |
+
Fc
|
| 54 |
+
H
|
| 55 |
+
I
|
| 56 |
+
K
|
| 57 |
+
L
|
| 58 |
+
M
|
| 59 |
+
N
|
| 60 |
+
NC
|
| 61 |
+
NCC
|
| 62 |
+
NCc
|
| 63 |
+
NN
|
| 64 |
+
NO
|
| 65 |
+
Nc
|
| 66 |
+
Nn
|
| 67 |
+
O
|
| 68 |
+
OC
|
| 69 |
+
OCC
|
| 70 |
+
OCO
|
| 71 |
+
OCc
|
| 72 |
+
ON
|
| 73 |
+
OO
|
| 74 |
+
Oc
|
| 75 |
+
P
|
| 76 |
+
R
|
| 77 |
+
S
|
| 78 |
+
SC
|
| 79 |
+
SCC
|
| 80 |
+
SCc
|
| 81 |
+
SS
|
| 82 |
+
Sc
|
| 83 |
+
T
|
| 84 |
+
X
|
| 85 |
+
Z
|
| 86 |
+
[
|
| 87 |
+
\\
|
| 88 |
+
(/
|
| 89 |
+
]
|
| 90 |
+
a
|
| 91 |
+
b
|
| 92 |
+
c
|
| 93 |
+
cc
|
| 94 |
+
ccc
|
| 95 |
+
cccc
|
| 96 |
+
ccn
|
| 97 |
+
cco
|
| 98 |
+
ccs
|
| 99 |
+
cn
|
| 100 |
+
cnc
|
| 101 |
+
cnn
|
| 102 |
+
co
|
| 103 |
+
coc
|
| 104 |
+
cs
|
| 105 |
+
csc
|
| 106 |
+
csn
|
| 107 |
+
e
|
| 108 |
+
g
|
| 109 |
+
i
|
| 110 |
+
l
|
| 111 |
+
n
|
| 112 |
+
nc
|
| 113 |
+
ncc
|
| 114 |
+
ncn
|
| 115 |
+
nco
|
| 116 |
+
ncs
|
| 117 |
+
nn
|
| 118 |
+
nnc
|
| 119 |
+
nnn
|
| 120 |
+
no
|
| 121 |
+
noc
|
| 122 |
+
non
|
| 123 |
+
ns
|
| 124 |
+
nsc
|
| 125 |
+
nsn
|
| 126 |
+
o
|
| 127 |
+
oc
|
| 128 |
+
occ
|
| 129 |
+
on
|
| 130 |
+
p
|
| 131 |
+
r
|
| 132 |
+
s
|
| 133 |
+
sc
|
| 134 |
+
scc
|
| 135 |
+
scn
|
| 136 |
+
sn
|
| 137 |
+
t
|
| 138 |
+
c1
|
| 139 |
+
c2
|
| 140 |
+
c3
|
| 141 |
+
c4
|
| 142 |
+
c5
|
| 143 |
+
c6
|
| 144 |
+
c7
|
| 145 |
+
c8
|
| 146 |
+
c9
|
| 147 |
+
n1
|
| 148 |
+
n2
|
| 149 |
+
n3
|
| 150 |
+
n4
|
| 151 |
+
n5
|
| 152 |
+
n6
|
| 153 |
+
n7
|
| 154 |
+
n8
|
| 155 |
+
n9
|
| 156 |
+
O1
|
| 157 |
+
O2
|
| 158 |
+
O3
|
| 159 |
+
O4
|
| 160 |
+
O5
|
| 161 |
+
O6
|
| 162 |
+
O7
|
| 163 |
+
O8
|
| 164 |
+
O9
|
| 165 |
+
(c1
|
| 166 |
+
(c2
|
| 167 |
+
c1)
|
| 168 |
+
c2)
|
| 169 |
+
(n1
|
| 170 |
+
(n2
|
| 171 |
+
n1)
|
| 172 |
+
n2)
|
| 173 |
+
(O1
|
| 174 |
+
(O2
|
| 175 |
+
O2)
|
| 176 |
+
=O
|
| 177 |
+
=C
|
| 178 |
+
=c
|
| 179 |
+
=N
|
| 180 |
+
=n
|
| 181 |
+
=CC
|
| 182 |
+
=CN
|
| 183 |
+
=Cc
|
| 184 |
+
=cc
|
| 185 |
+
=NC
|
| 186 |
+
=Nc
|
| 187 |
+
=nC
|
| 188 |
+
=nc
|
| 189 |
+
#C
|
| 190 |
+
#CC
|
| 191 |
+
#CN
|
| 192 |
+
#N
|
| 193 |
+
#NC
|
| 194 |
+
#NN
|
| 195 |
+
(C
|
| 196 |
+
C)
|
| 197 |
+
(O
|
| 198 |
+
O)
|
| 199 |
+
(N
|
| 200 |
+
N)
|
| 201 |
+
NP
|
| 202 |
+
PN
|
| 203 |
+
CP
|
| 204 |
+
PC
|
| 205 |
+
NS
|
| 206 |
+
SN
|
| 207 |
+
SP
|
| 208 |
+
PS
|
| 209 |
+
C(=O)
|
| 210 |
+
(/Br)
|
| 211 |
+
(/C#N)
|
| 212 |
+
(/C)
|
| 213 |
+
(/C=N)
|
| 214 |
+
(/C=O)
|
| 215 |
+
(/CBr)
|
| 216 |
+
(/CC)
|
| 217 |
+
(/CCC)
|
| 218 |
+
(/CCF)
|
| 219 |
+
(/CCN)
|
| 220 |
+
(/CCO)
|
| 221 |
+
(/CCl)
|
| 222 |
+
(/CI)
|
| 223 |
+
(/CN)
|
| 224 |
+
(/CO)
|
| 225 |
+
(/CS)
|
| 226 |
+
(/Cl)
|
| 227 |
+
(/F)
|
| 228 |
+
(/I)
|
| 229 |
+
(/N)
|
| 230 |
+
(/NC)
|
| 231 |
+
(/NCC)
|
| 232 |
+
(/NO)
|
| 233 |
+
(/O)
|
| 234 |
+
(/OC)
|
| 235 |
+
(/OCC)
|
| 236 |
+
(/S)
|
| 237 |
+
(/SC)
|
| 238 |
+
(=C)
|
| 239 |
+
(=C/C)
|
| 240 |
+
(=C/F)
|
| 241 |
+
(=C/I)
|
| 242 |
+
(=C/N)
|
| 243 |
+
(=C/O)
|
| 244 |
+
(=CBr)
|
| 245 |
+
(=CC)
|
| 246 |
+
(=CCF)
|
| 247 |
+
(=CCN)
|
| 248 |
+
(=CCO)
|
| 249 |
+
(=CCl)
|
| 250 |
+
(=CF)
|
| 251 |
+
(=CI)
|
| 252 |
+
(=CN)
|
| 253 |
+
(=CO)
|
| 254 |
+
(=C\\C)
|
| 255 |
+
(=C\\F)
|
| 256 |
+
(=C\\I)
|
| 257 |
+
(=C\\N)
|
| 258 |
+
(=C\\O)
|
| 259 |
+
(=N)
|
| 260 |
+
(=N/C)
|
| 261 |
+
(=N/N)
|
| 262 |
+
(=N/O)
|
| 263 |
+
(=NBr)
|
| 264 |
+
(=NC)
|
| 265 |
+
(=NCC)
|
| 266 |
+
(=NCl)
|
| 267 |
+
(=NN)
|
| 268 |
+
(=NO)
|
| 269 |
+
(=NOC)
|
| 270 |
+
(=N\\C)
|
| 271 |
+
(=N\\N)
|
| 272 |
+
(=N\\O)
|
| 273 |
+
(=O)
|
| 274 |
+
(=S)
|
| 275 |
+
(B)
|
| 276 |
+
(Br)
|
| 277 |
+
(C#C)
|
| 278 |
+
(C#CC)
|
| 279 |
+
(C#CI)
|
| 280 |
+
(C#CO)
|
| 281 |
+
(C#N)
|
| 282 |
+
(C#SN)
|
| 283 |
+
(C)
|
| 284 |
+
(C=C)
|
| 285 |
+
(C=CF)
|
| 286 |
+
(C=CI)
|
| 287 |
+
(C=N)
|
| 288 |
+
(C=NN)
|
| 289 |
+
(C=NO)
|
| 290 |
+
(C=O)
|
| 291 |
+
(C=S)
|
| 292 |
+
(CBr)
|
| 293 |
+
(CC#C)
|
| 294 |
+
(CC#N)
|
| 295 |
+
(CC)
|
| 296 |
+
(CC=C)
|
| 297 |
+
(CC=O)
|
| 298 |
+
(CCBr)
|
| 299 |
+
(CCC)
|
| 300 |
+
(CCCC)
|
| 301 |
+
(CCCF)
|
| 302 |
+
(CCCI)
|
| 303 |
+
(CCCN)
|
| 304 |
+
(CCCO)
|
| 305 |
+
(CCCS)
|
| 306 |
+
(CCCl)
|
| 307 |
+
(CCF)
|
| 308 |
+
(CCI)
|
| 309 |
+
(CCN)
|
| 310 |
+
(CCNC)
|
| 311 |
+
(CCNN)
|
| 312 |
+
(CCNO)
|
| 313 |
+
(CCO)
|
| 314 |
+
(CCOC)
|
| 315 |
+
(CCON)
|
| 316 |
+
(CCS)
|
| 317 |
+
(CCSC)
|
| 318 |
+
(CCl)
|
| 319 |
+
(CF)
|
| 320 |
+
(CI)
|
| 321 |
+
(CN)
|
| 322 |
+
(CN=O)
|
| 323 |
+
(CNC)
|
| 324 |
+
(CNCC)
|
| 325 |
+
(CNCO)
|
| 326 |
+
(CNN)
|
| 327 |
+
(CNNC)
|
| 328 |
+
(CNO)
|
| 329 |
+
(CNOC)
|
| 330 |
+
(CO)
|
| 331 |
+
(COC)
|
| 332 |
+
(COCC)
|
| 333 |
+
(COCI)
|
| 334 |
+
(COCN)
|
| 335 |
+
(COCO)
|
| 336 |
+
(COF)
|
| 337 |
+
(CON)
|
| 338 |
+
(COO)
|
| 339 |
+
(CS)
|
| 340 |
+
(CSC)
|
| 341 |
+
(CSCC)
|
| 342 |
+
(CSCF)
|
| 343 |
+
(CSO)
|
| 344 |
+
(Cl)
|
| 345 |
+
(F)
|
| 346 |
+
(I)
|
| 347 |
+
(N)
|
| 348 |
+
(N=N)
|
| 349 |
+
(N=NO)
|
| 350 |
+
(N=O)
|
| 351 |
+
(N=S)
|
| 352 |
+
(NBr)
|
| 353 |
+
(NC#N)
|
| 354 |
+
(NC)
|
| 355 |
+
(NC=N)
|
| 356 |
+
(NC=O)
|
| 357 |
+
(NC=S)
|
| 358 |
+
(NCBr)
|
| 359 |
+
(NCC)
|
| 360 |
+
(NCCC)
|
| 361 |
+
(NCCF)
|
| 362 |
+
(NCCN)
|
| 363 |
+
(NCCO)
|
| 364 |
+
(NCCS)
|
| 365 |
+
(NCCl)
|
| 366 |
+
(NCNC)
|
| 367 |
+
(NCO)
|
| 368 |
+
(NCS)
|
| 369 |
+
(NCl)
|
| 370 |
+
(NN)
|
| 371 |
+
(NN=O)
|
| 372 |
+
(NNC)
|
| 373 |
+
(NO)
|
| 374 |
+
(NOC)
|
| 375 |
+
(O)
|
| 376 |
+
(OC#N)
|
| 377 |
+
(OC)
|
| 378 |
+
(OC=C)
|
| 379 |
+
(OC=O)
|
| 380 |
+
(OC=S)
|
| 381 |
+
(OCBr)
|
| 382 |
+
(OCC)
|
| 383 |
+
(OCCC)
|
| 384 |
+
(OCCF)
|
| 385 |
+
(OCCI)
|
| 386 |
+
(OCCN)
|
| 387 |
+
(OCCO)
|
| 388 |
+
(OCCS)
|
| 389 |
+
(OCCl)
|
| 390 |
+
(OCF)
|
| 391 |
+
(OCI)
|
| 392 |
+
(OCO)
|
| 393 |
+
(OCOC)
|
| 394 |
+
(OCON)
|
| 395 |
+
(OCSC)
|
| 396 |
+
(OCl)
|
| 397 |
+
(OI)
|
| 398 |
+
(ON)
|
| 399 |
+
(OO)
|
| 400 |
+
(OOC)
|
| 401 |
+
(OOCC)
|
| 402 |
+
(OOSN)
|
| 403 |
+
(OSC)
|
| 404 |
+
(P)
|
| 405 |
+
(S)
|
| 406 |
+
(SC#N)
|
| 407 |
+
(SC)
|
| 408 |
+
(SCC)
|
| 409 |
+
(SCCC)
|
| 410 |
+
(SCCF)
|
| 411 |
+
(SCCN)
|
| 412 |
+
(SCCO)
|
| 413 |
+
(SCCS)
|
| 414 |
+
(SCCl)
|
| 415 |
+
(SCF)
|
| 416 |
+
(SCN)
|
| 417 |
+
(SCOC)
|
| 418 |
+
(SCSC)
|
| 419 |
+
(SCl)
|
| 420 |
+
(SI)
|
| 421 |
+
(SN)
|
| 422 |
+
(SN=O)
|
| 423 |
+
(SO)
|
| 424 |
+
(SOC)
|
| 425 |
+
(SOOO)
|
| 426 |
+
(SS)
|
| 427 |
+
(SSC)
|
| 428 |
+
(SSCC)
|
| 429 |
+
([At])
|
| 430 |
+
([O-])
|
| 431 |
+
([O])
|
| 432 |
+
([S-])
|
| 433 |
+
(\\Br)
|
| 434 |
+
(\\C#N)
|
| 435 |
+
(\\C)
|
| 436 |
+
(\\C=N)
|
| 437 |
+
(\\C=O)
|
| 438 |
+
(\\CBr)
|
| 439 |
+
(\\CC)
|
| 440 |
+
(\\CCC)
|
| 441 |
+
(\\CCO)
|
| 442 |
+
(\\CCl)
|
| 443 |
+
(\\CF)
|
| 444 |
+
(\\CN)
|
| 445 |
+
(\\CNC)
|
| 446 |
+
(\\CO)
|
| 447 |
+
(\\COC)
|
| 448 |
+
(\\Cl)
|
| 449 |
+
(\\F)
|
| 450 |
+
(\\I)
|
| 451 |
+
(\\N)
|
| 452 |
+
(\\NC)
|
| 453 |
+
(\\NCC)
|
| 454 |
+
(\\NN)
|
| 455 |
+
(\\NO)
|
| 456 |
+
(\\NOC)
|
| 457 |
+
(\\O)
|
| 458 |
+
(\\OC)
|
| 459 |
+
(\\OCC)
|
| 460 |
+
(\\ON)
|
| 461 |
+
(\\S)
|
| 462 |
+
(\\SC)
|
| 463 |
+
(\\SCC)
|
| 464 |
+
[Ag+]
|
| 465 |
+
[Ag-4]
|
| 466 |
+
[Ag]
|
| 467 |
+
[Al-3]
|
| 468 |
+
[Al]
|
| 469 |
+
[As+]
|
| 470 |
+
[AsH3]
|
| 471 |
+
[AsH]
|
| 472 |
+
[As]
|
| 473 |
+
[At]
|
| 474 |
+
[B-]
|
| 475 |
+
[B@-]
|
| 476 |
+
[B@@-]
|
| 477 |
+
[BH-]
|
| 478 |
+
[BH2-]
|
| 479 |
+
[BH3-]
|
| 480 |
+
[B]
|
| 481 |
+
[Ba]
|
| 482 |
+
[Br+2]
|
| 483 |
+
[BrH]
|
| 484 |
+
[Br]
|
| 485 |
+
[C+]
|
| 486 |
+
[C-]
|
| 487 |
+
[C@@H]
|
| 488 |
+
[C@@]
|
| 489 |
+
[C@H]
|
| 490 |
+
[C@]
|
| 491 |
+
[CH-]
|
| 492 |
+
[CH2]
|
| 493 |
+
[CH3]
|
| 494 |
+
[CH]
|
| 495 |
+
[C]
|
| 496 |
+
[CaH2]
|
| 497 |
+
[Ca]
|
| 498 |
+
[Cl+2]
|
| 499 |
+
[Cl+3]
|
| 500 |
+
[Cl+]
|
| 501 |
+
[Cs]
|
| 502 |
+
[FH]
|
| 503 |
+
[F]
|
| 504 |
+
[H]
|
| 505 |
+
[He]
|
| 506 |
+
[I+2]
|
| 507 |
+
[I+3]
|
| 508 |
+
[I+]
|
| 509 |
+
[IH]
|
| 510 |
+
[I]
|
| 511 |
+
[K]
|
| 512 |
+
[Kr]
|
| 513 |
+
[Li+]
|
| 514 |
+
[LiH]
|
| 515 |
+
[MgH2]
|
| 516 |
+
[Mg]
|
| 517 |
+
[N+]
|
| 518 |
+
[N-]
|
| 519 |
+
[N@+]
|
| 520 |
+
[N@@+]
|
| 521 |
+
[N@@]
|
| 522 |
+
[N@]
|
| 523 |
+
[NH+]
|
| 524 |
+
[NH-]
|
| 525 |
+
[NH2+]
|
| 526 |
+
[NH3]
|
| 527 |
+
[NH]
|
| 528 |
+
[N]
|
| 529 |
+
[Na]
|
| 530 |
+
[O+]
|
| 531 |
+
[O-]
|
| 532 |
+
[OH+]
|
| 533 |
+
[OH2]
|
| 534 |
+
[OH]
|
| 535 |
+
[O]
|
| 536 |
+
[P+]
|
| 537 |
+
[P@+]
|
| 538 |
+
[P@@+]
|
| 539 |
+
[P@@]
|
| 540 |
+
[P@]
|
| 541 |
+
[PH2]
|
| 542 |
+
[PH]
|
| 543 |
+
[P]
|
| 544 |
+
[Ra]
|
| 545 |
+
[Rb]
|
| 546 |
+
[S+]
|
| 547 |
+
[S-]
|
| 548 |
+
[S@+]
|
| 549 |
+
[S@@+]
|
| 550 |
+
[S@@]
|
| 551 |
+
[S@]
|
| 552 |
+
[SH+]
|
| 553 |
+
[SH2]
|
| 554 |
+
[SH]
|
| 555 |
+
[S]
|
| 556 |
+
[Se+]
|
| 557 |
+
[Se-2]
|
| 558 |
+
[SeH2]
|
| 559 |
+
[SeH]
|
| 560 |
+
[Se]
|
| 561 |
+
[Si@]
|
| 562 |
+
[SiH2]
|
| 563 |
+
[SiH]
|
| 564 |
+
[Si]
|
| 565 |
+
[SrH2]
|
| 566 |
+
[TeH]
|
| 567 |
+
[Te]
|
| 568 |
+
[Xe]
|
| 569 |
+
[Zn+2]
|
| 570 |
+
[Zn-2]
|
| 571 |
+
[Zn]
|
| 572 |
+
[b-]
|
| 573 |
+
[c+]
|
| 574 |
+
[c-]
|
| 575 |
+
[cH-]
|
| 576 |
+
[cH]
|
| 577 |
+
[c]
|
| 578 |
+
[n+]
|
| 579 |
+
[n-]
|
| 580 |
+
[nH]
|
| 581 |
+
[n]
|
| 582 |
+
[o+]
|
| 583 |
+
[s+]
|
| 584 |
+
[se+]
|
| 585 |
+
[se]
|
| 586 |
+
[te+]
|
| 587 |
+
[te]
|
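
The first five entries of `new_vocab.txt` are the special tokens, so their ids are fixed by file order; this is what the hard-coded ids 2 and 3 in `encode`/`decode` above rely on. A quick check (import path and file path are illustrative):

```python
from tokenizer.my_tokenizers import load_vocab   # illustrative import path

vocab = load_vocab("tokenizer/new_vocab.txt")
for tok in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]:
    print(tok, vocab[tok])   # expected ids: 0, 1, 2, 3, 4
```
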
tr2d2-pep/utils/app.py
ADDED
|
@@ -0,0 +1,1255 @@
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from io import StringIO
|
| 5 |
+
import rdkit
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem, Draw
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.patches as patches
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
import tempfile
|
| 14 |
+
from rdkit import Chem
|
| 15 |
+
|
| 16 |
+
class PeptideAnalyzer:
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.bond_patterns = [
|
| 19 |
+
(r'OC\(=O\)', 'ester'), # Ester bond
|
| 20 |
+
(r'N\(C\)C\(=O\)', 'n_methyl'), # N-methylated peptide bond
|
| 21 |
+
(r'N[0-9]C\(=O\)', 'proline'), # Proline peptide bond
|
| 22 |
+
(r'NC\(=O\)', 'peptide'), # Standard peptide bond
|
| 23 |
+
(r'C\(=O\)N\(C\)', 'n_methyl_reverse'), # Reverse N-methylated
|
| 24 |
+
(r'C\(=O\)N[12]?', 'peptide_reverse') # Reverse peptide bond
|
| 25 |
+
]
|
| 26 |
+
# Three to one letter code mapping
|
| 27 |
+
self.three_to_one = {
|
| 28 |
+
'Ala': 'A', 'Cys': 'C', 'Asp': 'D', 'Glu': 'E',
|
| 29 |
+
'Phe': 'F', 'Gly': 'G', 'His': 'H', 'Ile': 'I',
|
| 30 |
+
'Lys': 'K', 'Leu': 'L', 'Met': 'M', 'Asn': 'N',
|
| 31 |
+
'Pro': 'P', 'Gln': 'Q', 'Arg': 'R', 'Ser': 'S',
|
| 32 |
+
'Thr': 'T', 'Val': 'V', 'Trp': 'W', 'Tyr': 'Y'
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
def is_peptide(self, smiles):
|
| 36 |
+
"""Check if the SMILES represents a peptide structure"""
|
| 37 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 38 |
+
if mol is None:
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
# Look for peptide bonds: NC(=O) pattern
|
| 42 |
+
peptide_bond_pattern = Chem.MolFromSmarts('[NH][C](=O)')
|
| 43 |
+
if mol.HasSubstructMatch(peptide_bond_pattern):
|
| 44 |
+
return True
|
| 45 |
+
|
| 46 |
+
# Look for N-methylated peptide bonds: N(C)C(=O) pattern
|
| 47 |
+
n_methyl_pattern = Chem.MolFromSmarts('[N;H0;$(NC)](C)[C](=O)')
|
| 48 |
+
if mol.HasSubstructMatch(n_methyl_pattern):
|
| 49 |
+
return True
|
| 50 |
+
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
def is_cyclic(self, smiles):
|
| 54 |
+
"""Improved cyclic peptide detection"""
|
| 55 |
+
# Check for C-terminal carboxyl
|
| 56 |
+
if smiles.endswith('C(=O)O'):
|
| 57 |
+
return False, [], []
|
| 58 |
+
|
| 59 |
+
# Find all numbers used in ring closures
|
| 60 |
+
ring_numbers = re.findall(r'(?:^|[^c])[0-9](?=[A-Z@\(\)])', smiles)
|
| 61 |
+
|
| 62 |
+
# Find aromatic ring numbers
|
| 63 |
+
aromatic_matches = re.findall(r'c[0-9](?:ccccc|c\[nH\]c)[0-9]', smiles)
|
| 64 |
+
aromatic_cycles = []
|
| 65 |
+
for match in aromatic_matches:
|
| 66 |
+
numbers = re.findall(r'[0-9]', match)
|
| 67 |
+
aromatic_cycles.extend(numbers)
|
| 68 |
+
|
| 69 |
+
# Numbers that aren't part of aromatic rings are peptide cycles
|
| 70 |
+
peptide_cycles = [n for n in ring_numbers if n not in aromatic_cycles]
|
| 71 |
+
|
| 72 |
+
is_cyclic = len(peptide_cycles) > 0 and not smiles.endswith('C(=O)O')
|
| 73 |
+
return is_cyclic, peptide_cycles, aromatic_cycles
|
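
A small sketch of the two predicates above (assumes RDKit is installed and that `PeptideAnalyzer` from this file is importable; the SMILES is an illustrative linear dipeptide, not taken from the repo):

```python
from utils.app import PeptideAnalyzer   # illustrative import path

analyzer = PeptideAnalyzer()
ala_leu = "CC(C)C[C@@H](NC(=O)[C@@H](N)C)C(=O)O"   # linear Ala-Leu-like dipeptide

print(analyzer.is_peptide(ala_leu))   # True: the NC(=O) amide SMARTS matches
print(analyzer.is_cyclic(ala_leu))    # (False, [], []): the SMILES ends in C(=O)O
```
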
| 74 |
+
|
| 75 |
+
def split_on_bonds(self, smiles):
|
| 76 |
+
"""Split SMILES into segments with simplified Pro handling"""
|
| 77 |
+
positions = []
|
| 78 |
+
used = set()
|
| 79 |
+
|
| 80 |
+
# Find Gly pattern first
|
| 81 |
+
gly_pattern = r'NCC\(=O\)'
|
| 82 |
+
for match in re.finditer(gly_pattern, smiles):
|
| 83 |
+
if not any(p in range(match.start(), match.end()) for p in used):
|
| 84 |
+
positions.append({
|
| 85 |
+
'start': match.start(),
|
| 86 |
+
'end': match.end(),
|
| 87 |
+
'type': 'gly',
|
| 88 |
+
'pattern': match.group()
|
| 89 |
+
})
|
| 90 |
+
used.update(range(match.start(), match.end()))
|
| 91 |
+
|
| 92 |
+
for pattern, bond_type in self.bond_patterns:
|
| 93 |
+
for match in re.finditer(pattern, smiles):
|
| 94 |
+
if not any(p in range(match.start(), match.end()) for p in used):
|
| 95 |
+
positions.append({
|
| 96 |
+
'start': match.start(),
|
| 97 |
+
'end': match.end(),
|
| 98 |
+
'type': bond_type,
|
| 99 |
+
'pattern': match.group()
|
| 100 |
+
})
|
| 101 |
+
used.update(range(match.start(), match.end()))
|
| 102 |
+
|
| 103 |
+
# Sort by position
|
| 104 |
+
positions.sort(key=lambda x: x['start'])
|
| 105 |
+
|
| 106 |
+
# Create segments
|
| 107 |
+
segments = []
|
| 108 |
+
|
| 109 |
+
if positions:
|
| 110 |
+
# First segment
|
| 111 |
+
if positions[0]['start'] > 0:
|
| 112 |
+
segments.append({
|
| 113 |
+
'content': smiles[0:positions[0]['start']],
|
| 114 |
+
'bond_after': positions[0]['pattern']
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
# Process segments
|
| 118 |
+
for i in range(len(positions)-1):
|
| 119 |
+
current = positions[i]
|
| 120 |
+
next_pos = positions[i+1]
|
| 121 |
+
|
| 122 |
+
if current['type'] == 'gly':
|
| 123 |
+
segments.append({
|
| 124 |
+
'content': 'NCC(=O)',
|
| 125 |
+
'bond_before': positions[i-1]['pattern'] if i > 0 else None,
|
| 126 |
+
'bond_after': next_pos['pattern']
|
| 127 |
+
})
|
| 128 |
+
else:
|
| 129 |
+
content = smiles[current['end']:next_pos['start']]
|
| 130 |
+
if content:
|
| 131 |
+
segments.append({
|
| 132 |
+
'content': content,
|
| 133 |
+
'bond_before': current['pattern'],
|
| 134 |
+
'bond_after': next_pos['pattern']
|
| 135 |
+
})
|
| 136 |
+
|
| 137 |
+
# Last segment
|
| 138 |
+
if positions[-1]['end'] < len(smiles):
|
| 139 |
+
segments.append({
|
| 140 |
+
'content': smiles[positions[-1]['end']:],
|
| 141 |
+
'bond_before': positions[-1]['pattern']
|
| 142 |
+
})
|
| 143 |
+
|
| 144 |
+
return segments
|
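
`split_on_bonds` can also be exercised directly; each returned segment carries the residue content plus the bond patterns on either side (the SMILES is illustrative, and the exact segmentation depends on the regex patterns above):

```python
from utils.app import PeptideAnalyzer   # illustrative import path

segments = PeptideAnalyzer().split_on_bonds("CC(C)C[C@@H](NC(=O)[C@@H](N)C)C(=O)O")
for seg in segments:
    # bond_before / bond_after are absent on the terminal segments
    print(seg.get("bond_before"), "|", seg["content"], "|", seg.get("bond_after"))
```
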
| 145 |
+
|
| 146 |
+
def clean_terminal_carboxyl(self, segment):
|
| 147 |
+
"""Remove C-terminal carboxyl only if it's the true terminus"""
|
| 148 |
+
content = segment['content']
|
| 149 |
+
|
| 150 |
+
# Only clean if:
|
| 151 |
+
# 1. Contains C(=O)O
|
| 152 |
+
# 2. No bond_after exists (meaning it's the last segment)
|
| 153 |
+
# 3. C(=O)O is at the end of the content
|
| 154 |
+
if 'C(=O)O' in content and not segment.get('bond_after'):
|
| 155 |
+
print('recognized?')
|
| 156 |
+
# Remove C(=O)O pattern regardless of position
|
| 157 |
+
cleaned = re.sub(r'\(C\(=O\)O\)', '', content)
|
| 158 |
+
# Remove any leftover empty parentheses
|
| 159 |
+
cleaned = re.sub(r'\(\)', '', cleaned)
|
| 160 |
+
print(cleaned)
|
| 161 |
+
return cleaned
|
| 162 |
+
return content
|
| 163 |
+
|
| 164 |
+
def identify_residue(self, segment):
|
| 165 |
+
"""Identify residue with Pro reconstruction"""
|
| 166 |
+
# Only clean terminal carboxyl if this is the last segment
|
| 167 |
+
content = self.clean_terminal_carboxyl(segment)
|
| 168 |
+
mods = self.get_modifications(segment)
|
| 169 |
+
|
| 170 |
+
# UAA pattern matching section - before regular residues
|
| 171 |
+
# Phenylglycine and derivatives
|
| 172 |
+
if 'c1ccccc1' in content:
|
| 173 |
+
if '[C@@H](c1ccccc1)' in content or '[C@H](c1ccccc1)' in content:
|
| 174 |
+
return '4', mods # Base phenylglycine
|
| 175 |
+
|
| 176 |
+
# 4-substituted phenylalanines
|
| 177 |
+
if 'Cc1ccc' in content:
|
| 178 |
+
if 'OMe' in content or 'OCc1ccc' in content:
|
| 179 |
+
return '0A1', mods # 4-methoxy-Phenylalanine
|
| 180 |
+
elif 'Clc1ccc' in content:
|
| 181 |
+
return '200', mods # 4-chloro-Phenylalanine
|
| 182 |
+
elif 'Brc1ccc' in content:
|
| 183 |
+
return '4BF', mods # 4-Bromo-phenylalanine
|
| 184 |
+
elif 'C#Nc1ccc' in content:
|
| 185 |
+
return '4CF', mods # 4-cyano-phenylalanine
|
| 186 |
+
elif 'Ic1ccc' in content:
|
| 187 |
+
return 'PHI', mods # 4-Iodo-phenylalanine
|
| 188 |
+
elif 'Fc1ccc' in content:
|
| 189 |
+
return 'PFF', mods # 4-Fluoro-phenylalanine
|
| 190 |
+
|
| 191 |
+
# Modified tryptophans
|
| 192 |
+
if 'c[nH]c2' in content:
|
| 193 |
+
if 'Oc2cccc2' in content:
|
| 194 |
+
return '0AF', mods # 7-hydroxy-tryptophan
|
| 195 |
+
elif 'Fc2cccc2' in content:
|
| 196 |
+
return '4FW', mods # 4-fluoro-tryptophan
|
| 197 |
+
elif 'Clc2cccc2' in content:
|
| 198 |
+
return '6CW', mods # 6-chloro-tryptophan
|
| 199 |
+
elif 'Brc2cccc2' in content:
|
| 200 |
+
return 'BTR', mods # 6-bromo-tryptophan
|
| 201 |
+
elif 'COc2cccc2' in content:
|
| 202 |
+
return 'MOT5', mods # 5-Methoxy-tryptophan
|
| 203 |
+
elif 'Cc2cccc2' in content:
|
| 204 |
+
return 'MTR5', mods # 5-Methyl-tryptophan
|
| 205 |
+
|
| 206 |
+
# Special amino acids
|
| 207 |
+
if 'CC(C)(C)[C@@H]' in content or 'CC(C)(C)[C@H]' in content:
|
| 208 |
+
return 'BUG', mods # tert-Leucine
|
| 209 |
+
|
| 210 |
+
if 'CCCNC(=N)N' in content:
|
| 211 |
+
return 'CIR', mods # Citrulline
|
| 212 |
+
|
| 213 |
+
if '[SeH]' in content:
|
| 214 |
+
return 'CSE', mods # Selenocysteine
|
| 215 |
+
|
| 216 |
+
if '[NH3]CC[C@@H]' in content or '[NH3]CC[C@H]' in content:
|
| 217 |
+
return 'DAB', mods # Diaminobutyric acid
|
| 218 |
+
|
| 219 |
+
if 'C1CCCCC1' in content:
|
| 220 |
+
if 'C1CCCCC1[C@@H]' in content or 'C1CCCCC1[C@H]' in content:
|
| 221 |
+
return 'CHG', mods # Cyclohexylglycine
|
| 222 |
+
elif 'C1CCCCC1C[C@@H]' in content or 'C1CCCCC1C[C@H]' in content:
|
| 223 |
+
return 'ALC', mods # 3-cyclohexyl-alanine
|
| 224 |
+
|
| 225 |
+
# Naphthalene derivatives
|
| 226 |
+
if 'c1cccc2c1cccc2' in content:
|
| 227 |
+
if 'c1cccc2c1cccc2[C@@H]' in content or 'c1cccc2c1cccc2[C@H]' in content:
|
| 228 |
+
return 'NAL', mods # 2-Naphthyl-alanine
|
| 229 |
+
|
| 230 |
+
# Heteroaromatic derivatives
|
| 231 |
+
if 'c1cncc' in content:
|
| 232 |
+
return 'PYR4', mods # 3-(4-Pyridyl)-alanine
|
| 233 |
+
if 'c1cscc' in content:
|
| 234 |
+
return 'THA3', mods # 3-(3-thienyl)-alanine
|
| 235 |
+
if 'c1nnc' in content:
|
| 236 |
+
return 'TRZ4', mods # 3-(1,2,4-Triazol-1-yl)-alanine
|
| 237 |
+
|
| 238 |
+
# Modified serines and threonines
|
| 239 |
+
if 'OP(O)(O)O' in content:
|
| 240 |
+
if '[C@@H](COP' in content or '[C@H](COP' in content:
|
| 241 |
+
return 'SEP', mods # phosphoserine
|
| 242 |
+
elif '[C@@H](OP' in content or '[C@H](OP' in content:
|
| 243 |
+
return 'TPO', mods # phosphothreonine
|
| 244 |
+
|
| 245 |
+
# Specialized ring systems
|
| 246 |
+
if 'c1c2ccccc2cc2c1cccc2' in content:
|
| 247 |
+
return 'ANTH', mods # 3-(9-anthryl)-alanine
|
| 248 |
+
if 'c1csc2c1cccc2' in content:
|
| 249 |
+
return 'BTH3', mods # 3-(3-benzothienyl)-alanine
|
| 250 |
+
if '[C@]12C[C@H]3C[C@@H](C2)C[C@@H](C1)C3' in content:
|
| 251 |
+
return 'ADAM', mods # Adamantane
|
| 252 |
+
|
| 253 |
+
# Fluorinated derivatives
|
| 254 |
+
if 'FC(F)(F)' in content:
|
| 255 |
+
if 'CC(F)(F)F' in content:
|
| 256 |
+
return 'FLA', mods # Trifluoro-alanine
|
| 257 |
+
if 'C(F)(F)F)c1' in content:
|
| 258 |
+
if 'c1ccccc1C(F)(F)F' in content:
|
| 259 |
+
return 'TFG2', mods # 2-(Trifluoromethyl)-phenylglycine
|
| 260 |
+
if 'c1cccc(c1)C(F)(F)F' in content:
|
| 261 |
+
return 'TFG3', mods # 3-(Trifluoromethyl)-phenylglycine
|
| 262 |
+
if 'c1ccc(cc1)C(F)(F)F' in content:
|
| 263 |
+
return 'TFG4', mods # 4-(Trifluoromethyl)-phenylglycine
|
| 264 |
+
|
| 265 |
+
# Multiple halogen patterns
|
| 266 |
+
if 'F' in content and 'c1' in content:
|
| 267 |
+
if 'c1ccc(c(c1)F)F' in content:
|
| 268 |
+
return 'F2F', mods # 3,4-Difluoro-phenylalanine
|
| 269 |
+
if 'cc(F)cc(c1)F' in content:
|
| 270 |
+
return 'WFP', mods # 3,5-Difluoro-phenylalanine
|
| 271 |
+
if 'Cl' in content and 'c1' in content:
|
| 272 |
+
if 'c1ccc(cc1Cl)Cl' in content:
|
| 273 |
+
return 'CP24', mods # 2,4-dichloro-phenylalanine
|
| 274 |
+
if 'c1ccc(c(c1)Cl)Cl' in content:
|
| 275 |
+
return 'CP34', mods # 3,4-dichloro-phenylalanine
|
| 276 |
+
|
| 277 |
+
# Hydroxy and amino derivatives
|
| 278 |
+
if 'O' in content and 'c1' in content:
|
| 279 |
+
if 'c1cc(O)cc(c1)O' in content:
|
| 280 |
+
return '3FG', mods # (2s)-amino(3,5-dihydroxyphenyl)-ethanoic acid
|
| 281 |
+
if 'c1ccc(c(c1)O)O' in content:
|
| 282 |
+
return 'DAH', mods # 3,4-Dihydroxy-phenylalanine
|
| 283 |
+
|
| 284 |
+
# Cyclic amino acids
|
| 285 |
+
if 'C1CCCC1' in content:
|
| 286 |
+
return 'CPA3', mods # 3-Cyclopentyl-alanine
|
| 287 |
+
if 'C1CCCCC1' in content:
|
| 288 |
+
if 'CC1CCCCC1' in content:
|
| 289 |
+
return 'ALC', mods # 3-cyclohexyl-alanine
|
| 290 |
+
else:
|
| 291 |
+
return 'CHG', mods # Cyclohexylglycine
|
| 292 |
+
|
| 293 |
+
# Chain-length variants
|
| 294 |
+
if 'CCC[C@@H]' in content or 'CCC[C@H]' in content:
|
| 295 |
+
return 'NLE', mods # Norleucine
|
| 296 |
+
if 'CC[C@@H]' in content or 'CC[C@H]' in content:
|
| 297 |
+
if not any(x in content for x in ['CC(C)', 'COC', 'CN(']):
|
| 298 |
+
return 'ABA', mods # 2-Aminobutyric acid
|
| 299 |
+
|
| 300 |
+
# Modified histidines
|
| 301 |
+
if 'c1cnc' in content:
|
| 302 |
+
if '[C@@H]1CN[C@@H](N1)F' in content:
|
| 303 |
+
return '2HF', mods # 2-fluoro-l-histidine
|
| 304 |
+
if 'c1cnc([nH]1)F' in content:
|
| 305 |
+
return '2HF1', mods # 2-fluoro-l-histidine variant
|
| 306 |
+
if 'c1c[nH]c(n1)F' in content:
|
| 307 |
+
return '2HF2', mods # 2-fluoro-l-histidine variant
|
| 308 |
+
|
| 309 |
+
# Sulfur and selenium containing
|
| 310 |
+
if '[SeH]' in content:
|
| 311 |
+
return 'CSE', mods # Selenocysteine
|
| 312 |
+
if 'S' in content:
|
| 313 |
+
if 'CSCc1ccccc1' in content:
|
| 314 |
+
return 'BCS', mods # benzylcysteine
|
| 315 |
+
if 'CCSC' in content:
|
| 316 |
+
return 'ESC', mods # Ethionine
|
| 317 |
+
if 'CCS' in content:
|
| 318 |
+
return 'HCS', mods # homocysteine
|
| 319 |
+
|
| 320 |
+
# Additional modifications
|
| 321 |
+
if 'CN=[N]=N' in content:
|
| 322 |
+
return 'AZDA', mods # azido-alanine
|
| 323 |
+
if '[NH]=[C](=[NH2])=[NH2]' in content:
|
| 324 |
+
if 'CCC[NH]=' in content:
|
| 325 |
+
return 'AGM', mods # 5-methyl-arginine
|
| 326 |
+
if 'CC[NH]=' in content:
|
| 327 |
+
return 'GDPR', mods # 2-Amino-3-guanidinopropionic acid
|
| 328 |
+
|
| 329 |
+
if 'CCON' in content:
|
| 330 |
+
return 'CAN', mods # canaline
|
| 331 |
+
if '[C@@H]1C=C[C@@H](C=C1)' in content:
|
| 332 |
+
return 'ACZ', mods # cis-amiclenomycin
|
| 333 |
+
if 'CCC(=O)[NH3]' in content:
|
| 334 |
+
return 'ONL', mods # 5-oxo-l-norleucine
|
| 335 |
+
if 'c1ccncc1' in content:
|
| 336 |
+
return 'PYR4', mods # 3-(4-Pyridyl)-alanine
|
| 337 |
+
if 'c1ccco1' in content:
|
| 338 |
+
return 'FUA2', mods # (2-furyl)-alanine
|
| 339 |
+
|
| 340 |
+
if 'c1ccc' in content:
|
| 341 |
+
if 'c1ccc(cc1)c1ccccc1' in content:
|
| 342 |
+
return 'BIF', mods # 4,4-biphenylalanine
|
| 343 |
+
if 'c1ccc(cc1)C(=O)c1ccccc1' in content:
|
| 344 |
+
return 'PBF', mods # 4-benzoyl-phenylalanine
|
| 345 |
+
if 'c1ccc(cc1)C(C)(C)C' in content:
|
| 346 |
+
return 'TBP4', mods # 4-tert-butyl-phenylalanine
|
| 347 |
+
if 'c1ccc(cc1)[C](=[NH2])=[NH2]' in content:
|
| 348 |
+
return '0BN', mods # 4-carbamimidoyl-l-phenylalanine
|
| 349 |
+
if 'c1cccc(c1)[C](=[NH2])=[NH2]' in content:
|
| 350 |
+
return 'APM', mods # m-amidinophenyl-3-alanine
|
| 351 |
+
|
| 352 |
+
# Multiple hydroxy patterns
|
| 353 |
+
if 'O' in content:
|
| 354 |
+
if '[C@H]([C@H](C)O)O' in content:
|
| 355 |
+
return 'ILX', mods # 4,5-dihydroxy-isoleucine
|
| 356 |
+
if '[C@H]([C@@H](C)O)O' in content:
|
| 357 |
+
return 'ALO', mods # Allo-threonine
|
| 358 |
+
if '[C@H](COP(O)(O)O)' in content:
|
| 359 |
+
return 'SEP', mods # phosphoserine
|
| 360 |
+
if '[C@H]([C@@H](C)OP(O)(O)O)' in content:
|
| 361 |
+
return 'TPO', mods # phosphothreonine
|
| 362 |
+
if '[C@H](c1ccc(O)cc1)O' in content:
|
| 363 |
+
return 'OMX', mods # (betar)-beta-hydroxy-l-tyrosine
|
| 364 |
+
if '[C@H](c1ccc(c(Cl)c1)O)O' in content:
|
| 365 |
+
return 'OMY', mods # (betar)-3-chloro-beta-hydroxy-l-tyrosine
|
| 366 |
+
|
| 367 |
+
# Heterocyclic patterns
|
| 368 |
+
if 'n1' in content:
|
| 369 |
+
if 'n1cccn1' in content:
|
| 370 |
+
return 'PYZ1', mods # 3-(1-Pyrazolyl)-alanine
|
| 371 |
+
if 'n1nncn1' in content:
|
| 372 |
+
return 'TEZA', mods # 3-(2-Tetrazolyl)-alanine
|
| 373 |
+
if 'c2c(n1)cccc2' in content:
|
| 374 |
+
return 'QU32', mods # 3-(2-Quinolyl)-alanine
|
| 375 |
+
if 'c1cnc2c(c1)cccc2' in content:
|
| 376 |
+
return 'QU33', mods # 3-(3-quinolyl)-alanine
|
| 377 |
+
if 'c1ccnc2c1cccc2' in content:
|
| 378 |
+
return 'QU34', mods # 3-(4-quinolyl)-alanine
|
| 379 |
+
if 'c1ccc2c(c1)nccc2' in content:
|
| 380 |
+
return 'QU35', mods # 3-(5-Quinolyl)-alanine
|
| 381 |
+
if 'c1ccc2c(c1)cncc2' in content:
|
| 382 |
+
return 'QU36', mods # 3-(6-Quinolyl)-alanine
|
| 383 |
+
if 'c1cnc2c(n1)cccc2' in content:
|
| 384 |
+
return 'QX32', mods # 3-(2-quinoxalyl)-alanine
|
| 385 |
+
|
| 386 |
+
# Multiple nitrogen patterns
|
| 387 |
+
if 'N' in content:
|
| 388 |
+
if '[NH3]CC[C@@H]' in content:
|
| 389 |
+
return 'DAB', mods # Diaminobutyric acid
|
| 390 |
+
if '[NH3]C[C@@H]' in content:
|
| 391 |
+
return 'DPP', mods # 2,3-Diaminopropanoic acid
|
| 392 |
+
if '[NH3]CCCCCC[C@@H]' in content:
|
| 393 |
+
return 'HHK', mods # (2s)-2,8-diaminooctanoic acid
|
| 394 |
+
if 'CCC[NH]=[C](=[NH2])=[NH2]' in content:
|
| 395 |
+
return 'GBUT', mods # 2-Amino-4-guanidinobutyric acid
|
| 396 |
+
if '[NH]=[C](=S)=[NH2]' in content:
|
| 397 |
+
return 'THIC', mods # Thio-citrulline
|
| 398 |
+
|
| 399 |
+
# Chain modified amino acids
|
| 400 |
+
if 'CC' in content:
|
| 401 |
+
if 'CCCC[C@@H]' in content:
|
| 402 |
+
return 'AHP', mods # 2-Aminoheptanoic acid
|
| 403 |
+
if 'CCC([C@@H])(C)C' in content:
|
| 404 |
+
return 'I2M', mods # 3-methyl-l-alloisoleucine
|
| 405 |
+
if 'CC[C@H]([C@@H])C' in content:
|
| 406 |
+
return 'IIL', mods # Allo-Isoleucine
|
| 407 |
+
if '[C@H](CCC(C)C)' in content:
|
| 408 |
+
return 'HLEU', mods # Homoleucine
|
| 409 |
+
if '[C@@H]([C@@H](C)O)C' in content:
|
| 410 |
+
return 'HLU', mods # beta-hydroxyleucine
|
| 411 |
+
|
| 412 |
+
# Modified glutamate/aspartate patterns
|
| 413 |
+
if '[C@@H]' in content:
|
| 414 |
+
if '[C@@H](C[C@@H](F))' in content:
|
| 415 |
+
return 'FGA4', mods # 4-Fluoro-glutamic acid
|
| 416 |
+
if '[C@@H](C[C@@H](O))' in content:
|
| 417 |
+
return '3GL', mods # 4-hydroxy-glutamic-acid
|
| 418 |
+
if '[C@@H](C[C@H](C))' in content:
|
| 419 |
+
return 'LME', mods # (3r)-3-methyl-l-glutamic acid
|
| 420 |
+
if '[C@@H](CC[C@H](C))' in content:
|
| 421 |
+
return 'MEG', mods # (3s)-3-methyl-l-glutamic acid
|
| 422 |
+
|
| 423 |
+
# Sulfur and selenium modifications
|
| 424 |
+
if 'S' in content:
|
| 425 |
+
if 'SCC[C@@H]' in content:
|
| 426 |
+
return 'HSER', mods # homoserine
|
| 427 |
+
if 'SCCN' in content:
|
| 428 |
+
return 'SLZ', mods # thialysine
|
| 429 |
+
if 'SC(=O)' in content:
|
| 430 |
+
return 'CSA', mods # s-acetonylcysteine
|
| 431 |
+
if '[S@@](=O)' in content:
|
| 432 |
+
return 'SME', mods # Methionine sulfoxide
|
| 433 |
+
if 'S(=O)(=O)' in content:
|
| 434 |
+
return 'OMT', mods # Methionine sulfone
|
| 435 |
+
|
| 436 |
+
# Double bond containing
|
| 437 |
+
if 'C=' in content:
|
| 438 |
+
if 'C=CC[C@@H]' in content or 'C=CC[C@H]' in content:  # three-carbon allyl side chain; a bare vinyl falls through to LVG below
|
| 439 |
+
return '2AG', mods # 2-Allyl-glycine
|
| 440 |
+
if 'C=C[C@@H]' in content:
|
| 441 |
+
return 'LVG', mods # vinylglycine
|
| 442 |
+
if 'C=Cc1ccccc1' in content:
|
| 443 |
+
return 'STYA', mods # Styrylalanine
|
| 444 |
+
|
| 445 |
+
# Special cases
|
| 446 |
+
if '[C@@H]1Cc2c(C1)cccc2' in content:
|
| 447 |
+
return 'IGL', mods # alpha-amino-2-indanacetic acid
|
| 448 |
+
if '[C](=[C](=O)=O)=O' in content:
|
| 449 |
+
return '26P', mods # 2-amino-6-oxopimelic acid
|
| 450 |
+
if '[C](=[C](=O)=O)=C' in content:
|
| 451 |
+
return '2NP', mods # l-2-amino-6-methylene-pimelic acid
|
| 452 |
+
if 'c2cnc[nH]2' in content:
|
| 453 |
+
return 'HIS', mods # histidine core
|
| 454 |
+
if 'c1cccc2c1cc(O)cc2' in content:
|
| 455 |
+
return 'NAO1', mods # 5-hydroxy-1-naphthalene
|
| 456 |
+
if 'c1ccc2c(c1)cc(O)cc2' in content:
|
| 457 |
+
return 'NAO2', mods # 6-hydroxy-2-naphthalene
|
| 458 |
+
|
| 459 |
+
# Proline (P) - flexible ring numbers
|
| 460 |
+
if any([
|
| 461 |
+
# Check for any ring number in bond patterns
|
| 462 |
+
(segment.get('bond_after', '').startswith(f'N{n}C(=O)') and 'CCC' in content and
|
| 463 |
+
any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
|
| 464 |
+
for n in '123456789'
|
| 465 |
+
]) or any([
|
| 466 |
+
# Check ending patterns with any ring number
|
| 467 |
+
(f'CCCN{n}' in content and content.endswith('=O') and
|
| 468 |
+
any(f'[C@@H]{n}' in content or f'[C@H]{n}' in content for n in '123456789'))
|
| 469 |
+
for n in '123456789'
|
| 470 |
+
]) or any([
|
| 471 |
+
# Handle CCC[C@H]n patterns
|
| 472 |
+
(content == f'CCC[C@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
|
| 473 |
+
(content == f'CCC[C@@H]{n}' and segment.get('bond_before', '').startswith(f'C(=O)N{n}')) or
|
| 474 |
+
# N-terminal Pro with any ring number
|
| 475 |
+
(f'N{n}CCC[C@H]{n}' in content) or
|
| 476 |
+
(f'N{n}CCC[C@@H]{n}' in content)
|
| 477 |
+
for n in '123456789'
|
| 478 |
+
]):
|
| 479 |
+
return 'Pro', mods
|
| 480 |
+
|
| 481 |
+
# Tryptophan (W) - more specific indole pattern
|
| 482 |
+
if re.search(r'c[0-9]c\[nH\]c[0-9]ccccc[0-9][0-9]', content) and \
|
| 483 |
+
'c[nH]c' in content.replace(' ', ''):
|
| 484 |
+
return 'Trp', mods
|
| 485 |
+
|
| 486 |
+
# Lysine (K) - both patterns
|
| 487 |
+
if '[C@@H](CCCCN)' in content or '[C@H](CCCCN)' in content:
|
| 488 |
+
return 'Lys', mods
|
| 489 |
+
|
| 490 |
+
# Arginine (R) - both patterns
|
| 491 |
+
if '[C@@H](CCCNC(=N)N)' in content or '[C@H](CCCNC(=N)N)' in content:
|
| 492 |
+
return 'Arg', mods
|
| 493 |
+
|
| 494 |
+
if ('C[C@H](CCCC)' in content or 'C[C@@H](CCCC)' in content) and 'CC(C)' not in content:
|
| 495 |
+
return 'Nle', mods
|
| 496 |
+
|
| 497 |
+
# Ornithine (Orn) - 3-carbon chain with NH2
|
| 498 |
+
if ('C[C@H](CCCN)' in content or 'C[C@@H](CCCN)' in content) and 'CC(C)' not in content:
|
| 499 |
+
return 'Orn', mods
|
| 500 |
+
|
| 501 |
+
# 2-Naphthylalanine (2Nal) - distinct from Phe pattern
|
| 502 |
+
if ('Cc3cc2ccccc2c3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 503 |
+
return '2Nal', mods
|
| 504 |
+
|
| 505 |
+
# Cyclohexylalanine (Cha)
|
| 506 |
+
if 'N2CCCCC2' in content or 'CCCCC2' in content:
|
| 507 |
+
return 'Cha', mods
|
| 508 |
+
|
| 509 |
+
# Aminobutyric acid (Abu) - 2-carbon chain
|
| 510 |
+
if ('C[C@H](CC)' in content or 'C[C@@H](CC)' in content) and not any(p in content for p in ['CC(C)', 'CCCC', 'CCC(C)']):
|
| 511 |
+
return 'Abu', mods
|
| 512 |
+
|
| 513 |
+
# Pipecolic acid (Pip) - 6-membered ring like Pro
|
| 514 |
+
if ('N3CCCCC3' in content or 'CCCCC3' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 515 |
+
return 'Pip', mods
|
| 516 |
+
|
| 517 |
+
# Cyclohexylglycine (Chg) - direct cyclohexyl without CH2
|
| 518 |
+
if ('C[C@H](C1CCCCC1)' in content or 'C[C@@H](C1CCCCC1)' in content):
|
| 519 |
+
return 'Chg', mods
|
| 520 |
+
|
| 521 |
+
# 4-Fluorophenylalanine (4F-Phe)
|
| 522 |
+
if ('Cc2ccc(F)cc2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 523 |
+
return '4F-Phe', mods
|
| 524 |
+
|
| 525 |
+
# Regular residue identification
|
| 526 |
+
if ('NCC(=O)' in content) or (content == 'C'):
|
| 527 |
+
# Middle case - between bonds
|
| 528 |
+
if segment.get('bond_before') and segment.get('bond_after'):
|
| 529 |
+
if ('C(=O)N' in segment['bond_before'] or 'C(=O)N(C)' in segment['bond_before']):
|
| 530 |
+
return 'Gly', mods
|
| 531 |
+
# Terminal case - at the end
|
| 532 |
+
elif segment.get('bond_before') and segment.get('bond_before').startswith('C(=O)N'):
|
| 533 |
+
return 'Gly', mods
|
| 534 |
+
|
| 535 |
+
if 'CC(C)C[C@H]' in content or 'CC(C)C[C@@H]' in content:
|
| 536 |
+
return 'Leu', mods
|
| 537 |
+
if '[C@@H](CC(C)C)' in content or '[C@H](CC(C)C)' in content:
|
| 538 |
+
return 'Leu', mods
|
| 539 |
+
|
| 540 |
+
if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content:
|
| 541 |
+
return 'Thr', mods
|
| 542 |
+
|
| 543 |
+
if '[C@H](Cc2ccccc2)' in content or '[C@@H](Cc2ccccc2)' in content:
|
| 544 |
+
return 'Phe', mods
|
| 545 |
+
|
| 546 |
+
if ('[C@H](C(C)C)' in content or # With outer parentheses
|
| 547 |
+
'[C@@H](C(C)C)' in content or # With outer parentheses
|
| 548 |
+
'[C@H]C(C)C' in content or # Without outer parentheses
|
| 549 |
+
'[C@@H]C(C)C' in content): # Without outer parentheses
|
| 550 |
+
if not any(p in content for p in ['CC(C)C[C@H]', 'CC(C)C[C@@H]']): # Still check not Leu
|
| 551 |
+
return 'Val', mods
|
| 552 |
+
|
| 553 |
+
if '[C@H](COC(C)(C)C)' in content or '[C@@H](COC(C)(C)C)' in content:
|
| 554 |
+
return 'O-tBu', mods
|
| 555 |
+
|
| 556 |
+
if any([
|
| 557 |
+
'CC[C@H](C)' in content,
|
| 558 |
+
'CC[C@@H](C)' in content,
|
| 559 |
+
'C(C)C[C@H]' in content and 'CC(C)C' not in content,
|
| 560 |
+
'C(C)C[C@@H]' in content and 'CC(C)C' not in content
|
| 561 |
+
]):
|
| 562 |
+
return 'Ile', mods
|
| 563 |
+
|
| 564 |
+
if ('[C@H](C)' in content or '[C@@H](C)' in content):
|
| 565 |
+
if not any(p in content for p in ['C(C)C', 'COC', 'CN(', 'C(C)O', 'CC[C@H]', 'CC[C@@H]']):
|
| 566 |
+
return 'Ala', mods
|
| 567 |
+
|
| 568 |
+
# Tyrosine (Tyr) - 4-hydroxybenzyl side chain
|
| 569 |
+
if re.search(r'Cc[0-9]ccc\(O\)cc[0-9]', content):
|
| 570 |
+
return 'Tyr', mods
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
# Serine (Ser) - Hydroxymethyl side chain
|
| 574 |
+
if '[C@H](CO)' in content or '[C@@H](CO)' in content:
|
| 575 |
+
if not ('C(C)O' in content or 'COC' in content):
|
| 576 |
+
return 'Ser', mods
|
| 577 |
+
|
| 578 |
+
# Threonine (Thr) - 1-hydroxyethyl side chain
|
| 579 |
+
if '[C@@H]([C@@H](C)O)' in content or '[C@H]([C@H](C)O)' in content or '[C@@H](C)O' in content or '[C@H](C)O' in content:
|
| 580 |
+
return 'Thr', mods
|
| 581 |
+
|
| 582 |
+
# Cysteine (Cys) - Thiol side chain
|
| 583 |
+
if '[C@H](CS)' in content or '[C@@H](CS)' in content:
|
| 584 |
+
return 'Cys', mods
|
| 585 |
+
|
| 586 |
+
# Methionine (Met) - Methylthioethyl side chain
|
| 587 |
+
if ('C[C@H](CCSC)' in content or 'C[C@@H](CCSC)' in content):
|
| 588 |
+
return 'Met', mods
|
| 589 |
+
|
| 590 |
+
# Asparagine (Asn) - Carbamoylmethyl side chain
|
| 591 |
+
if ('CC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 592 |
+
return 'Asn', mods
|
| 593 |
+
|
| 594 |
+
# Glutamine (Gln) - Carbamoylethyl side chain
|
| 595 |
+
if ('CCC(=O)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 596 |
+
return 'Gln', mods
|
| 597 |
+
|
| 598 |
+
# Aspartic acid (Asp) - Carboxymethyl side chain
|
| 599 |
+
if ('CC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 600 |
+
return 'Asp', mods
|
| 601 |
+
|
| 602 |
+
# Glutamic acid (Glu) - Carboxyethyl side chain
|
| 603 |
+
if ('CCC(=O)O' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 604 |
+
return 'Glu', mods
|
| 605 |
+
|
| 606 |
+
# Arginine (Arg) - 3-guanidinopropyl side chain
|
| 607 |
+
if ('CCCNC(=N)N' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 608 |
+
return 'Arg', mods
|
| 609 |
+
|
| 610 |
+
# Histidine (His) - Imidazole side chain
|
| 611 |
+
if ('Cc2cnc[nH]2' in content) and ('C[C@H]' in content or 'C[C@@H]' in content):
|
| 612 |
+
return 'His', mods
|
| 613 |
+
|
| 614 |
+
return None, mods
|
| 615 |
+
|
| 616 |
+
def get_modifications(self, segment):
|
| 617 |
+
"""Get modifications based on bond types"""
|
| 618 |
+
mods = []
|
| 619 |
+
if segment.get('bond_after'):
|
| 620 |
+
if 'N(C)' in segment['bond_after'] or segment['bond_after'].startswith('C(=O)N(C)'):
|
| 621 |
+
mods.append('N-Me')
|
| 622 |
+
if 'OC(=O)' in segment['bond_after']:
|
| 623 |
+
mods.append('O-linked')
|
| 624 |
+
return mods
|
| 625 |
+
|
| 626 |
+
def analyze_structure(self, smiles):
|
| 627 |
+
"""Main analysis function with debug output"""
|
| 628 |
+
print("\nAnalyzing structure:", smiles)
|
| 629 |
+
|
| 630 |
+
# Split into segments
|
| 631 |
+
segments = self.split_on_bonds(smiles)
|
| 632 |
+
|
| 633 |
+
print("\nSegment Analysis:")
|
| 634 |
+
sequence = []
|
| 635 |
+
for i, segment in enumerate(segments):
|
| 636 |
+
print(f"\nSegment {i}:")
|
| 637 |
+
print(f"Content: {segment['content']}")
|
| 638 |
+
print(f"Bond before: {segment.get('bond_before', 'None')}")
|
| 639 |
+
print(f"Bond after: {segment.get('bond_after', 'None')}")
|
| 640 |
+
|
| 641 |
+
residue, mods = self.identify_residue(segment)
|
| 642 |
+
if residue:
|
| 643 |
+
if mods:
|
| 644 |
+
sequence.append(f"{residue}({','.join(mods)})")
|
| 645 |
+
else:
|
| 646 |
+
sequence.append(residue)
|
| 647 |
+
print(f"Identified as: {residue}")
|
| 648 |
+
print(f"Modifications: {mods}")
|
| 649 |
+
else:
|
| 650 |
+
print(f"Warning: Could not identify residue in segment: {segment['content']}")
|
| 651 |
+
|
| 652 |
+
# Check if cyclic
|
| 653 |
+
is_cyclic, peptide_cycles, aromatic_cycles = self.is_cyclic(smiles)
|
| 654 |
+
three_letter = '-'.join(sequence)
|
| 655 |
+
one_letter = ''.join(self.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence)
|
| 656 |
+
|
| 657 |
+
if is_cyclic:
|
| 658 |
+
three_letter = f"cyclo({three_letter})"
|
| 659 |
+
one_letter = f"cyclo({one_letter})"
|
| 660 |
+
|
| 661 |
+
print(f"\nFinal sequence: {three_letter}")
|
| 662 |
+
print(f"One-letter code: {one_letter}")
|
| 663 |
+
print(f"Is cyclic: {is_cyclic}")
|
| 664 |
+
#print(f"Peptide cycles: {peptide_cycles}")
|
| 665 |
+
#print(f"Aromatic cycles: {aromatic_cycles}")
|
| 666 |
+
|
| 667 |
+
return three_letter, len(segments)
|
| 668 |
+
"""return {
|
| 669 |
+
'three_letter': three_letter,
|
| 670 |
+
#'one_letter': one_letter,
|
| 671 |
+
'is_cyclic': is_cyclic
|
| 672 |
+
}"""
|
| 673 |
+
|
| 674 |
+
def return_sequence(self, smiles):
|
| 675 |
+
"""Main analysis function with debug output"""
|
| 676 |
+
print("\nAnalyzing structure:", smiles)
|
| 677 |
+
|
| 678 |
+
# Split into segments
|
| 679 |
+
segments = self.split_on_bonds(smiles)
|
| 680 |
+
|
| 681 |
+
print("\nSegment Analysis:")
|
| 682 |
+
sequence = []
|
| 683 |
+
for i, segment in enumerate(segments):
|
| 684 |
+
print(f"\nSegment {i}:")
|
| 685 |
+
print(f"Content: {segment['content']}")
|
| 686 |
+
print(f"Bond before: {segment.get('bond_before', 'None')}")
|
| 687 |
+
print(f"Bond after: {segment.get('bond_after', 'None')}")
|
| 688 |
+
|
| 689 |
+
residue, mods = self.identify_residue(segment)
|
| 690 |
+
if residue:
|
| 691 |
+
if mods:
|
| 692 |
+
sequence.append(f"{residue}({','.join(mods)})")
|
| 693 |
+
else:
|
| 694 |
+
sequence.append(residue)
|
| 695 |
+
print(f"Identified as: {residue}")
|
| 696 |
+
print(f"Modifications: {mods}")
|
| 697 |
+
else:
|
| 698 |
+
print(f"Warning: Could not identify residue in segment: {segment['content']}")
|
| 699 |
+
|
| 700 |
+
return sequence
|
| 701 |
+
|
| 702 |
+
"""
|
| 703 |
+
def annotate_cyclic_structure(mol, sequence):
|
| 704 |
+
'''Create annotated 2D structure with clear, non-overlapping residue labels'''
|
| 705 |
+
# Generate 2D coordinates
|
| 706 |
+
# Generate 2D coordinates
|
| 707 |
+
AllChem.Compute2DCoords(mol)
|
| 708 |
+
|
| 709 |
+
# Create drawer with larger size for annotations
|
| 710 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000) # Even larger size
|
| 711 |
+
|
| 712 |
+
# Get residue list and reverse it to match structural representation
|
| 713 |
+
if sequence.startswith('cyclo('):
|
| 714 |
+
residues = sequence[6:-1].split('-')
|
| 715 |
+
else:
|
| 716 |
+
residues = sequence.split('-')
|
| 717 |
+
residues = list(reversed(residues)) # Reverse the sequence
|
| 718 |
+
|
| 719 |
+
# Draw molecule first to get its bounds
|
| 720 |
+
drawer.drawOptions().addAtomIndices = False
|
| 721 |
+
drawer.DrawMolecule(mol)
|
| 722 |
+
drawer.FinishDrawing()
|
| 723 |
+
|
| 724 |
+
# Convert to PIL Image
|
| 725 |
+
img = Image.open(BytesIO(drawer.GetDrawingText()))
|
| 726 |
+
draw = ImageDraw.Draw(img)
|
| 727 |
+
|
| 728 |
+
try:
|
| 729 |
+
# Try to use DejaVuSans as it's commonly available on Linux systems
|
| 730 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 731 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 732 |
+
except OSError:
|
| 733 |
+
try:
|
| 734 |
+
# Fallback to Arial if available (common on Windows)
|
| 735 |
+
font = ImageFont.truetype("arial.ttf", 60)
|
| 736 |
+
small_font = ImageFont.truetype("arial.ttf", 60)
|
| 737 |
+
except OSError:
|
| 738 |
+
# If no TrueType fonts are available, fall back to default
|
| 739 |
+
print("Warning: TrueType fonts not available, using default font")
|
| 740 |
+
font = ImageFont.load_default()
|
| 741 |
+
small_font = ImageFont.load_default()
|
| 742 |
+
# Get molecule bounds
|
| 743 |
+
conf = mol.GetConformer()
|
| 744 |
+
positions = []
|
| 745 |
+
for i in range(mol.GetNumAtoms()):
|
| 746 |
+
pos = conf.GetAtomPosition(i)
|
| 747 |
+
positions.append((pos.x, pos.y))
|
| 748 |
+
|
| 749 |
+
x_coords = [p[0] for p in positions]
|
| 750 |
+
y_coords = [p[1] for p in positions]
|
| 751 |
+
min_x, max_x = min(x_coords), max(x_coords)
|
| 752 |
+
min_y, max_y = min(y_coords), max(y_coords)
|
| 753 |
+
|
| 754 |
+
# Calculate scaling factors
|
| 755 |
+
scale = 150 # Increased scale factor
|
| 756 |
+
center_x = 1000 # Image center
|
| 757 |
+
center_y = 1000
|
| 758 |
+
|
| 759 |
+
# Add residue labels in a circular arrangement around the structure
|
| 760 |
+
n_residues = len(residues)
|
| 761 |
+
radius = 700 # Distance of labels from center
|
| 762 |
+
|
| 763 |
+
# Start from the rightmost point (3 o'clock position) and go counterclockwise
|
| 764 |
+
# Offset by -3 positions to align with structure
|
| 765 |
+
offset = 0 # Adjust this value to match the structure alignment
|
| 766 |
+
for i, residue in enumerate(residues):
|
| 767 |
+
# Calculate position in a circle around the structure
|
| 768 |
+
# Start from 0 (3 o'clock) and go counterclockwise
|
| 769 |
+
angle = -(2 * np.pi * ((i + offset) % n_residues) / n_residues)
|
| 770 |
+
|
| 771 |
+
# Calculate label position
|
| 772 |
+
label_x = center_x + radius * np.cos(angle)
|
| 773 |
+
label_y = center_y + radius * np.sin(angle)
|
| 774 |
+
|
| 775 |
+
# Draw residue label
|
| 776 |
+
text = f"{i+1}. {residue}"
|
| 777 |
+
bbox = draw.textbbox((label_x, label_y), text, font=font)
|
| 778 |
+
padding = 10
|
| 779 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 780 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 781 |
+
fill='white', outline='white')
|
| 782 |
+
draw.text((label_x, label_y), text,
|
| 783 |
+
font=font, fill='black', anchor="mm")
|
| 784 |
+
|
| 785 |
+
# Add sequence at the top with white background
|
| 786 |
+
seq_text = f"Sequence: {sequence}"
|
| 787 |
+
bbox = draw.textbbox((center_x, 100), seq_text, font=small_font)
|
| 788 |
+
padding = 10
|
| 789 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 790 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 791 |
+
fill='white', outline='white')
|
| 792 |
+
draw.text((center_x, 100), seq_text,
|
| 793 |
+
font=small_font, fill='black', anchor="mm")
|
| 794 |
+
|
| 795 |
+
return img
|
| 796 |
+
|
| 797 |
+
"""
|
| 798 |
+
def annotate_cyclic_structure(mol, sequence):
|
| 799 |
+
"""Create structure visualization with just the sequence header"""
|
| 800 |
+
# Generate 2D coordinates
|
| 801 |
+
AllChem.Compute2DCoords(mol)
|
| 802 |
+
|
| 803 |
+
# Create drawer with larger size for annotations
|
| 804 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DCairo(2000, 2000)
|
| 805 |
+
|
| 806 |
+
# Draw molecule first
|
| 807 |
+
drawer.drawOptions().addAtomIndices = False
|
| 808 |
+
drawer.DrawMolecule(mol)
|
| 809 |
+
drawer.FinishDrawing()
|
| 810 |
+
|
| 811 |
+
# Convert to PIL Image
|
| 812 |
+
img = Image.open(BytesIO(drawer.GetDrawingText()))
|
| 813 |
+
draw = ImageDraw.Draw(img)
|
| 814 |
+
try:
|
| 815 |
+
small_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 60)
|
| 816 |
+
except OSError:
|
| 817 |
+
try:
|
| 818 |
+
small_font = ImageFont.truetype("arial.ttf", 60)
|
| 819 |
+
except OSError:
|
| 820 |
+
print("Warning: TrueType fonts not available, using default font")
|
| 821 |
+
small_font = ImageFont.load_default()
|
| 822 |
+
|
| 823 |
+
# Add just the sequence header at the top
|
| 824 |
+
seq_text = f"Sequence: {sequence}"
|
| 825 |
+
bbox = draw.textbbox((1000, 100), seq_text, font=small_font)
|
| 826 |
+
padding = 10
|
| 827 |
+
draw.rectangle([bbox[0]-padding, bbox[1]-padding,
|
| 828 |
+
bbox[2]+padding, bbox[3]+padding],
|
| 829 |
+
fill='white', outline='white')
|
| 830 |
+
draw.text((1000, 100), seq_text,
|
| 831 |
+
font=small_font, fill='black', anchor="mm")
|
| 832 |
+
|
| 833 |
+
return img
|
| 834 |
+
|
| 835 |
+
def create_enhanced_linear_viz(sequence, smiles):
|
| 836 |
+
"""Create an enhanced linear representation using PeptideAnalyzer"""
|
| 837 |
+
analyzer = PeptideAnalyzer() # Create analyzer instance
|
| 838 |
+
|
| 839 |
+
# Create figure with two subplots
|
| 840 |
+
fig = plt.figure(figsize=(15, 10))
|
| 841 |
+
gs = fig.add_gridspec(2, 1, height_ratios=[1, 2])
|
| 842 |
+
ax_struct = fig.add_subplot(gs[0])
|
| 843 |
+
ax_detail = fig.add_subplot(gs[1])
|
| 844 |
+
|
| 845 |
+
# Parse sequence and get residues
|
| 846 |
+
if sequence.startswith('cyclo('):
|
| 847 |
+
residues = sequence[6:-1].split('-')
|
| 848 |
+
else:
|
| 849 |
+
residues = sequence.split('-')
|
| 850 |
+
|
| 851 |
+
# Get segments using analyzer
|
| 852 |
+
segments = analyzer.split_on_bonds(smiles)
|
| 853 |
+
|
| 854 |
+
# Debug print
|
| 855 |
+
print(f"Number of residues: {len(residues)}")
|
| 856 |
+
print(f"Number of segments: {len(segments)}")
|
| 857 |
+
|
| 858 |
+
# Top subplot - Basic structure
|
| 859 |
+
ax_struct.set_xlim(0, 10)
|
| 860 |
+
ax_struct.set_ylim(0, 2)
|
| 861 |
+
|
| 862 |
+
num_residues = len(residues)
|
| 863 |
+
spacing = 9.0 / (num_residues - 1) if num_residues > 1 else 9.0
|
| 864 |
+
|
| 865 |
+
# Draw basic structure
|
| 866 |
+
y_pos = 1.5
|
| 867 |
+
for i in range(num_residues):
|
| 868 |
+
x_pos = 0.5 + i * spacing
|
| 869 |
+
|
| 870 |
+
# Draw amino acid box
|
| 871 |
+
rect = patches.Rectangle((x_pos-0.3, y_pos-0.2), 0.6, 0.4,
|
| 872 |
+
facecolor='lightblue', edgecolor='black')
|
| 873 |
+
ax_struct.add_patch(rect)
|
| 874 |
+
|
| 875 |
+
# Draw connecting bonds if not the last residue
|
| 876 |
+
if i < num_residues - 1:
|
| 877 |
+
segment = segments[i] if i < len(segments) else None
|
| 878 |
+
if segment:
|
| 879 |
+
# Determine bond type from segment info
|
| 880 |
+
bond_type = 'ester' if 'O-linked' in segment.get('bond_after', '') else 'peptide'
|
| 881 |
+
is_n_methylated = 'N-Me' in segment.get('bond_after', '')
|
| 882 |
+
|
| 883 |
+
bond_color = 'red' if bond_type == 'ester' else 'black'
|
| 884 |
+
linestyle = '--' if bond_type == 'ester' else '-'
|
| 885 |
+
|
| 886 |
+
# Draw bond line
|
| 887 |
+
ax_struct.plot([x_pos+0.3, x_pos+spacing-0.3], [y_pos, y_pos],
|
| 888 |
+
color=bond_color, linestyle=linestyle, linewidth=2)
|
| 889 |
+
|
| 890 |
+
# Add bond type label
|
| 891 |
+
mid_x = x_pos + spacing/2
|
| 892 |
+
bond_label = f"{bond_type}"
|
| 893 |
+
if is_n_methylated:
|
| 894 |
+
bond_label += "\n(N-Me)"
|
| 895 |
+
ax_struct.text(mid_x, y_pos+0.1, bond_label,
|
| 896 |
+
ha='center', va='bottom', fontsize=10,
|
| 897 |
+
color=bond_color)
|
| 898 |
+
|
| 899 |
+
# Add residue label
|
| 900 |
+
ax_struct.text(x_pos, y_pos-0.5, residues[i],
|
| 901 |
+
ha='center', va='top', fontsize=14)
|
| 902 |
+
|
| 903 |
+
# Bottom subplot - Detailed breakdown
|
| 904 |
+
ax_detail.set_ylim(0, len(segments)+1)
|
| 905 |
+
ax_detail.set_xlim(0, 1)
|
| 906 |
+
|
| 907 |
+
# Create detailed breakdown
|
| 908 |
+
segment_y = len(segments) # Start from top
|
| 909 |
+
for i, segment in enumerate(segments):
|
| 910 |
+
y = segment_y - i
|
| 911 |
+
|
| 912 |
+
# Check if this is a bond or residue
|
| 913 |
+
residue, mods = analyzer.identify_residue(segment)
|
| 914 |
+
if residue:
|
| 915 |
+
text = f"Residue {i+1}: {residue}"
|
| 916 |
+
if mods:
|
| 917 |
+
text += f" ({', '.join(mods)})"
|
| 918 |
+
color = 'blue'
|
| 919 |
+
else:
|
| 920 |
+
# Must be a bond
|
| 921 |
+
text = f"Bond {i}: "
|
| 922 |
+
if 'O-linked' in segment.get('bond_after', ''):
|
| 923 |
+
text += "ester"
|
| 924 |
+
elif 'N-Me' in segment.get('bond_after', ''):
|
| 925 |
+
text += "peptide (N-methylated)"
|
| 926 |
+
else:
|
| 927 |
+
text += "peptide"
|
| 928 |
+
color = 'red'
|
| 929 |
+
|
| 930 |
+
# Add segment analysis
|
| 931 |
+
ax_detail.text(0.05, y, text, fontsize=12, color=color)
|
| 932 |
+
ax_detail.text(0.5, y, f"SMILES: {segment.get('content', '')}", fontsize=10, color='gray')
|
| 933 |
+
|
| 934 |
+
# If cyclic, add connection indicator
|
| 935 |
+
if sequence.startswith('cyclo('):
|
| 936 |
+
ax_struct.annotate('', xy=(9.5, y_pos), xytext=(0.5, y_pos),
|
| 937 |
+
arrowprops=dict(arrowstyle='<->', color='red', lw=2))
|
| 938 |
+
ax_struct.text(5, y_pos+0.3, 'Cyclic Connection',
|
| 939 |
+
ha='center', color='red', fontsize=14)
|
| 940 |
+
|
| 941 |
+
# Add titles and adjust layout
|
| 942 |
+
ax_struct.set_title("Peptide Structure Overview", pad=20)
|
| 943 |
+
ax_detail.set_title("Segment Analysis Breakdown", pad=20)
|
| 944 |
+
|
| 945 |
+
# Remove axes
|
| 946 |
+
for ax in [ax_struct, ax_detail]:
|
| 947 |
+
ax.set_xticks([])
|
| 948 |
+
ax.set_yticks([])
|
| 949 |
+
ax.axis('off')
|
| 950 |
+
|
| 951 |
+
plt.tight_layout()
|
| 952 |
+
return fig
|
| 953 |
+
|
| 954 |
+
class PeptideStructureGenerator:
|
| 955 |
+
"""A class to generate 3D structures of peptides using different embedding methods"""
|
| 956 |
+
|
| 957 |
+
@staticmethod
|
| 958 |
+
def prepare_molecule(smiles):
|
| 959 |
+
"""Prepare molecule with proper hydrogen handling"""
|
| 960 |
+
mol = Chem.MolFromSmiles(smiles, sanitize=False)
|
| 961 |
+
if mol is None:
|
| 962 |
+
raise ValueError("Failed to create molecule from SMILES")
|
| 963 |
+
|
| 964 |
+
# Calculate valence for each atom
|
| 965 |
+
for atom in mol.GetAtoms():
|
| 966 |
+
atom.UpdatePropertyCache(strict=False)
|
| 967 |
+
|
| 968 |
+
# Sanitize with reduced requirements
|
| 969 |
+
Chem.SanitizeMol(mol,
|
| 970 |
+
sanitizeOps=Chem.SANITIZE_FINDRADICALS|
|
| 971 |
+
Chem.SANITIZE_KEKULIZE|
|
| 972 |
+
Chem.SANITIZE_SETAROMATICITY|
|
| 973 |
+
Chem.SANITIZE_SETCONJUGATION|
|
| 974 |
+
Chem.SANITIZE_SETHYBRIDIZATION|
|
| 975 |
+
Chem.SANITIZE_CLEANUPCHIRALITY)
|
| 976 |
+
|
| 977 |
+
mol = Chem.AddHs(mol)
|
| 978 |
+
return mol
|
| 979 |
+
|
| 980 |
+
@staticmethod
|
| 981 |
+
def get_etkdg_params(attempt=0):
|
| 982 |
+
"""Get ETKDG parameters with optional modifications based on attempt number"""
|
| 983 |
+
params = AllChem.ETKDGv3()
|
| 984 |
+
params.randomSeed = -1
|
| 985 |
+
params.maxIterations = 200
|
| 986 |
+
params.numThreads = 4 # Reduced for web interface
|
| 987 |
+
params.useBasicKnowledge = True
|
| 988 |
+
params.enforceChirality = True
|
| 989 |
+
params.useExpTorsionAnglePrefs = True
|
| 990 |
+
params.useSmallRingTorsions = True
|
| 991 |
+
params.useMacrocycleTorsions = True
|
| 992 |
+
params.ETversion = 2
|
| 993 |
+
params.pruneRmsThresh = -1
|
| 994 |
+
params.embedRmsThresh = 0.5
|
| 995 |
+
|
| 996 |
+
if attempt > 10:
|
| 997 |
+
params.bondLength = 1.5 + (attempt - 10) * 0.02
|
| 998 |
+
params.useExpTorsionAnglePrefs = False
|
| 999 |
+
|
| 1000 |
+
return params
|
| 1001 |
+
|
| 1002 |
+
def generate_structure_etkdg(self, smiles, max_attempts=20):
|
| 1003 |
+
"""Generate 3D structure using ETKDG without UFF optimization"""
|
| 1004 |
+
success = False
|
| 1005 |
+
mol = None
|
| 1006 |
+
|
| 1007 |
+
for attempt in range(max_attempts):
|
| 1008 |
+
try:
|
| 1009 |
+
mol = self.prepare_molecule(smiles)
|
| 1010 |
+
params = self.get_etkdg_params(attempt)
|
| 1011 |
+
|
| 1012 |
+
if AllChem.EmbedMolecule(mol, params) == 0:
|
| 1013 |
+
success = True
|
| 1014 |
+
break
|
| 1015 |
+
except Exception as e:
|
| 1016 |
+
continue
|
| 1017 |
+
|
| 1018 |
+
if not success:
|
| 1019 |
+
raise ValueError("Failed to generate structure with ETKDG")
|
| 1020 |
+
|
| 1021 |
+
return mol
|
| 1022 |
+
|
| 1023 |
+
def generate_structure_uff(self, smiles, max_attempts=20):
|
| 1024 |
+
"""Generate 3D structure using ETKDG followed by UFF optimization"""
|
| 1025 |
+
best_mol = None
|
| 1026 |
+
lowest_energy = float('inf')
|
| 1027 |
+
|
| 1028 |
+
for attempt in range(max_attempts):
|
| 1029 |
+
try:
|
| 1030 |
+
test_mol = self.prepare_molecule(smiles)
|
| 1031 |
+
params = self.get_etkdg_params(attempt)
|
| 1032 |
+
|
| 1033 |
+
if AllChem.EmbedMolecule(test_mol, params) == 0:
|
| 1034 |
+
res = AllChem.UFFOptimizeMolecule(test_mol, maxIters=2000,
|
| 1035 |
+
vdwThresh=10.0, confId=0,
|
| 1036 |
+
ignoreInterfragInteractions=True)
|
| 1037 |
+
|
| 1038 |
+
if res == 0:
|
| 1039 |
+
ff = AllChem.UFFGetMoleculeForceField(test_mol)
|
| 1040 |
+
if ff:
|
| 1041 |
+
current_energy = ff.CalcEnergy()
|
| 1042 |
+
if current_energy < lowest_energy:
|
| 1043 |
+
lowest_energy = current_energy
|
| 1044 |
+
best_mol = Chem.Mol(test_mol)
|
| 1045 |
+
except Exception:
|
| 1046 |
+
continue
|
| 1047 |
+
|
| 1048 |
+
if best_mol is None:
|
| 1049 |
+
raise ValueError("Failed to generate optimized structure")
|
| 1050 |
+
|
| 1051 |
+
return best_mol
|
| 1052 |
+
|
| 1053 |
+
@staticmethod
|
| 1054 |
+
def mol_to_sdf_bytes(mol):
|
| 1055 |
+
"""Convert RDKit molecule to SDF file bytes"""
|
| 1056 |
+
# First write to StringIO in text mode
|
| 1057 |
+
sio = StringIO()
|
| 1058 |
+
writer = Chem.SDWriter(sio)
|
| 1059 |
+
writer.write(mol)
|
| 1060 |
+
writer.close()
|
| 1061 |
+
|
| 1062 |
+
# Convert the string to bytes
|
| 1063 |
+
return sio.getvalue().encode('utf-8')
|
| 1064 |
+
|
| 1065 |
+
def process_input(smiles_input=None, file_obj=None, show_linear=False,
|
| 1066 |
+
show_segment_details=False, generate_3d=False, use_uff=False):
|
| 1067 |
+
"""Process input and create visualizations using PeptideAnalyzer"""
|
| 1068 |
+
analyzer = PeptideAnalyzer()
|
| 1069 |
+
temp_dir = tempfile.mkdtemp() if generate_3d else None
|
| 1070 |
+
structure_files = []
|
| 1071 |
+
|
| 1072 |
+
# Handle direct SMILES input
|
| 1073 |
+
if smiles_input:
|
| 1074 |
+
smiles = smiles_input.strip()
|
| 1075 |
+
|
| 1076 |
+
# First check if it's a peptide using analyzer's method
|
| 1077 |
+
if not analyzer.is_peptide(smiles):
|
| 1078 |
+
return "Error: Input SMILES does not appear to be a peptide structure.", None, None
|
| 1079 |
+
|
| 1080 |
+
try:
|
| 1081 |
+
# Create molecule
|
| 1082 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 1083 |
+
if mol is None:
|
| 1084 |
+
return "Error: Invalid SMILES notation.", None, None
|
| 1085 |
+
|
| 1086 |
+
# Generate 3D structures if requested
|
| 1087 |
+
if generate_3d:
|
| 1088 |
+
generator = PeptideStructureGenerator()
|
| 1089 |
+
|
| 1090 |
+
try:
|
| 1091 |
+
# Generate ETKDG structure
|
| 1092 |
+
mol_etkdg = generator.generate_structure_etkdg(smiles)
|
| 1093 |
+
etkdg_path = os.path.join(temp_dir, "structure_etkdg.sdf")
|
| 1094 |
+
writer = Chem.SDWriter(etkdg_path)
|
| 1095 |
+
writer.write(mol_etkdg)
|
| 1096 |
+
writer.close()
|
| 1097 |
+
structure_files.append(etkdg_path)
|
| 1098 |
+
|
| 1099 |
+
# Generate UFF structure if requested
|
| 1100 |
+
if use_uff:
|
| 1101 |
+
mol_uff = generator.generate_structure_uff(smiles)
|
| 1102 |
+
uff_path = os.path.join(temp_dir, "structure_uff.sdf")
|
| 1103 |
+
writer = Chem.SDWriter(uff_path)
|
| 1104 |
+
writer.write(mol_uff)
|
| 1105 |
+
writer.close()
|
| 1106 |
+
structure_files.append(uff_path)
|
| 1107 |
+
|
| 1108 |
+
except Exception as e:
|
| 1109 |
+
return f"Error generating 3D structures: {str(e)}", None, None, None
|
| 1110 |
+
|
| 1111 |
+
# Use analyzer to get sequence
|
| 1112 |
+
segments = analyzer.split_on_bonds(smiles)
|
| 1113 |
+
|
| 1114 |
+
# Process segments and build sequence
|
| 1115 |
+
sequence_parts = []
|
| 1116 |
+
output_text = ""
|
| 1117 |
+
|
| 1118 |
+
# Only include segment analysis in output if requested
|
| 1119 |
+
if show_segment_details:
|
| 1120 |
+
output_text += "Segment Analysis:\n"
|
| 1121 |
+
for i, segment in enumerate(segments):
|
| 1122 |
+
output_text += f"\nSegment {i}:\n"
|
| 1123 |
+
output_text += f"Content: {segment['content']}\n"
|
| 1124 |
+
output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
|
| 1125 |
+
output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
|
| 1126 |
+
|
| 1127 |
+
residue, mods = analyzer.identify_residue(segment)
|
| 1128 |
+
if residue:
|
| 1129 |
+
if mods:
|
| 1130 |
+
sequence_parts.append(f"{residue}({','.join(mods)})")
|
| 1131 |
+
else:
|
| 1132 |
+
sequence_parts.append(residue)
|
| 1133 |
+
output_text += f"Identified as: {residue}\n"
|
| 1134 |
+
output_text += f"Modifications: {mods}\n"
|
| 1135 |
+
else:
|
| 1136 |
+
output_text += f"Warning: Could not identify residue in segment: {segment['content']}\n"
|
| 1137 |
+
output_text += "\n"
|
| 1138 |
+
else:
|
| 1139 |
+
# Just build sequence without detailed analysis in output
|
| 1140 |
+
for segment in segments:
|
| 1141 |
+
residue, mods = analyzer.identify_residue(segment)
|
| 1142 |
+
if residue:
|
| 1143 |
+
if mods:
|
| 1144 |
+
sequence_parts.append(f"{residue}({','.join(mods)})")
|
| 1145 |
+
else:
|
| 1146 |
+
sequence_parts.append(residue)
|
| 1147 |
+
|
| 1148 |
+
# Check if cyclic using analyzer's method
|
| 1149 |
+
is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
|
| 1150 |
+
three_letter = '-'.join(sequence_parts)
|
| 1151 |
+
one_letter = ''.join(analyzer.three_to_one.get(aa.split('(')[0], 'X') for aa in sequence_parts)
|
| 1152 |
+
|
| 1153 |
+
if is_cyclic:
|
| 1154 |
+
three_letter = f"cyclo({three_letter})"
|
| 1155 |
+
one_letter = f"cyclo({one_letter})"
|
| 1156 |
+
|
| 1157 |
+
# Create cyclic structure visualization
|
| 1158 |
+
img_cyclic = annotate_cyclic_structure(mol, three_letter)
|
| 1159 |
+
|
| 1160 |
+
# Create linear representation if requested
|
| 1161 |
+
img_linear = None
|
| 1162 |
+
if show_linear:
|
| 1163 |
+
fig_linear = create_enhanced_linear_viz(three_letter, smiles)
|
| 1164 |
+
buf = BytesIO()
|
| 1165 |
+
fig_linear.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
| 1166 |
+
buf.seek(0)
|
| 1167 |
+
img_linear = Image.open(buf)
|
| 1168 |
+
plt.close(fig_linear)
|
| 1169 |
+
|
| 1170 |
+
# Add summary to output
|
| 1171 |
+
summary = "Summary:\n"
|
| 1172 |
+
summary += f"Sequence: {three_letter}\n"
|
| 1173 |
+
summary += f"One-letter code: {one_letter}\n"
|
| 1174 |
+
summary += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
|
| 1175 |
+
#if is_cyclic:
|
| 1176 |
+
#summary += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
|
| 1177 |
+
#summary += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
|
| 1178 |
+
|
| 1179 |
+
if structure_files:
|
| 1180 |
+
summary += "\n3D Structures Generated:\n"
|
| 1181 |
+
for filepath in structure_files:
|
| 1182 |
+
summary += f"- {os.path.basename(filepath)}\n"
|
| 1183 |
+
|
| 1184 |
+
return summary + output_text, img_cyclic, img_linear, structure_files if structure_files else None
|
| 1185 |
+
|
| 1186 |
+
except Exception as e:
|
| 1187 |
+
return f"Error processing SMILES: {str(e)}", None, None, None
|
| 1188 |
+
|
| 1189 |
+
# Handle file input
|
| 1190 |
+
if file_obj is not None:
|
| 1191 |
+
try:
|
| 1192 |
+
# Handle file content
|
| 1193 |
+
if hasattr(file_obj, 'name'):
|
| 1194 |
+
with open(file_obj.name, 'r') as f:
|
| 1195 |
+
content = f.read()
|
| 1196 |
+
else:
|
| 1197 |
+
content = file_obj.decode('utf-8') if isinstance(file_obj, bytes) else str(file_obj)
|
| 1198 |
+
|
| 1199 |
+
output_text = ""
|
| 1200 |
+
for line in content.splitlines():
|
| 1201 |
+
smiles = line.strip()
|
| 1202 |
+
if smiles:
|
| 1203 |
+
# Check if it's a peptide
|
| 1204 |
+
if not analyzer.is_peptide(smiles):
|
| 1205 |
+
output_text += f"Skipping non-peptide SMILES: {smiles}\n"
|
| 1206 |
+
continue
|
| 1207 |
+
|
| 1208 |
+
# Process this SMILES
|
| 1209 |
+
segments = analyzer.split_on_bonds(smiles)
|
| 1210 |
+
sequence_parts = []
|
| 1211 |
+
|
| 1212 |
+
# Add segment details if requested
|
| 1213 |
+
if show_segment_details:
|
| 1214 |
+
output_text += f"\nSegment Analysis for SMILES: {smiles}\n"
|
| 1215 |
+
for i, segment in enumerate(segments):
|
| 1216 |
+
output_text += f"\nSegment {i}:\n"
|
| 1217 |
+
output_text += f"Content: {segment['content']}\n"
|
| 1218 |
+
output_text += f"Bond before: {segment.get('bond_before', 'None')}\n"
|
| 1219 |
+
output_text += f"Bond after: {segment.get('bond_after', 'None')}\n"
|
| 1220 |
+
residue, mods = analyzer.identify_residue(segment)
|
| 1221 |
+
if residue:
|
| 1222 |
+
if mods:
|
| 1223 |
+
sequence_parts.append(f"{residue}({','.join(mods)})")
|
| 1224 |
+
else:
|
| 1225 |
+
sequence_parts.append(residue)
|
| 1226 |
+
output_text += f"Identified as: {residue}\n"
|
| 1227 |
+
output_text += f"Modifications: {mods}\n"
|
| 1228 |
+
else:
|
| 1229 |
+
for segment in segments:
|
| 1230 |
+
residue, mods = analyzer.identify_residue(segment)
|
| 1231 |
+
if residue:
|
| 1232 |
+
if mods:
|
| 1233 |
+
sequence_parts.append(f"{residue}({','.join(mods)})")
|
| 1234 |
+
else:
|
| 1235 |
+
sequence_parts.append(residue)
|
| 1236 |
+
|
| 1237 |
+
# Get cyclicity and create sequence
|
| 1238 |
+
is_cyclic, peptide_cycles, aromatic_cycles = analyzer.is_cyclic(smiles)
|
| 1239 |
+
sequence = f"cyclo({'-'.join(sequence_parts)})" if is_cyclic else '-'.join(sequence_parts)
|
| 1240 |
+
|
| 1241 |
+
output_text += f"\nSummary for SMILES: {smiles}\n"
|
| 1242 |
+
output_text += f"Sequence: {sequence}\n"
|
| 1243 |
+
output_text += f"Is Cyclic: {'Yes' if is_cyclic else 'No'}\n"
|
| 1244 |
+
if is_cyclic:
|
| 1245 |
+
output_text += f"Peptide Cycles: {', '.join(peptide_cycles)}\n"
|
| 1246 |
+
#output_text += f"Aromatic Cycles: {', '.join(aromatic_cycles)}\n"
|
| 1247 |
+
output_text += "-" * 50 + "\n"
|
| 1248 |
+
|
| 1249 |
+
return output_text, None, None, None
|
| 1250 |
+
|
| 1251 |
+
except Exception as e:
|
| 1252 |
+
return f"Error processing file: {str(e)}", None, None
|
| 1253 |
+
|
| 1254 |
+
return "No input provided.", None, None
|
| 1255 |
+
|
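A minimal usage sketch for the `PeptideAnalyzer` defined in `tr2d2-pep/utils/app.py` (the import path and the example SMILES are assumptions made for illustration; `analyze_structure` also prints its per-segment debug output):

```python
# Sketch only: assumes tr2d2-pep/utils/ is on PYTHONPATH and that the
# illustrative SMILES below parses as a linear tripeptide.
from app import PeptideAnalyzer

analyzer = PeptideAnalyzer()
smiles = "NCC(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)O"  # hypothetical Gly-Leu-Ala

if analyzer.is_peptide(smiles):
    three_letter, n_segments = analyzer.analyze_structure(smiles)
    print(three_letter, n_segments)  # detected sequence and number of segments
```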
tr2d2-pep/utils/timer.py
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
import time, torch
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
|
| 5 |
+
class StepTimer:
|
| 6 |
+
def __init__(self, device=None):
|
| 7 |
+
self.times = defaultdict(list)
|
| 8 |
+
self.device = device
|
| 9 |
+
self._use_cuda_sync = (
|
| 10 |
+
isinstance(device, torch.device) and device.type == "cuda"
|
| 11 |
+
) or (isinstance(device, str) and "cuda" in device)
|
| 12 |
+
|
| 13 |
+
@contextmanager
|
| 14 |
+
def section(self, name):
|
| 15 |
+
if self._use_cuda_sync:
|
| 16 |
+
torch.cuda.synchronize()
|
| 17 |
+
t0 = time.perf_counter()
|
| 18 |
+
try:
|
| 19 |
+
yield
|
| 20 |
+
finally:
|
| 21 |
+
if self._use_cuda_sync:
|
| 22 |
+
torch.cuda.synchronize()
|
| 23 |
+
dt = time.perf_counter() - t0
|
| 24 |
+
self.times[name].append(dt)
|
| 25 |
+
|
| 26 |
+
def summary(self, top_k=None):
|
| 27 |
+
# returns (name, count, total, mean, p50, p95)
|
| 28 |
+
import numpy as np
|
| 29 |
+
rows = []
|
| 30 |
+
for k, v in self.times.items():
|
| 31 |
+
a = np.array(v, dtype=float)
|
| 32 |
+
rows.append((k, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95)))
|
| 33 |
+
rows.sort(key=lambda r: r[2], reverse=True) # by total time
|
| 34 |
+
return rows[:top_k] if top_k else rows
|
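A short usage sketch for `StepTimer` (the device choice and the timed workload are placeholders):

```python
# Sketch only: times a toy workload and prints per-section statistics.
import torch
from timer import StepTimer  # assumes tr2d2-pep/utils/ is on PYTHONPATH

device = "cuda" if torch.cuda.is_available() else "cpu"
timer = StepTimer(device=device)

for _ in range(10):
    with timer.section("matmul"):  # synchronizes CUDA around the block when on GPU
        x = torch.randn(512, 512, device=device)
        _ = x @ x

# summary() rows are (name, count, total, mean, p50, p95), sorted by total time
for name, count, total, mean, p50, p95 in timer.summary(top_k=5):
    print(f"{name}: n={count} total={total:.3f}s mean={mean*1e3:.2f}ms p95={p95*1e3:.2f}ms")
```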
tr2d2-pep/utils/utils.py
ADDED
|
@@ -0,0 +1,135 @@
|
| 1 |
+
"""Console logger utilities.
|
| 2 |
+
|
| 3 |
+
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
|
| 4 |
+
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import fsspec
|
| 9 |
+
import lightning
|
| 10 |
+
import torch
|
| 11 |
+
from timm.scheduler import CosineLRScheduler
|
| 12 |
+
import argparse
|
| 13 |
+
import numpy as np
|
| 14 |
+
import random
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
def sample_categorical_logits(logits, dtype=torch.float64):
|
| 18 |
+
# do not require logits to be log-softmaxed
|
| 19 |
+
gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
|
| 20 |
+
return (logits + gumbel_noise).argmax(dim=-1)
|
| 21 |
+
|
| 22 |
+
def fsspec_exists(filename):
|
| 23 |
+
"""Check if a file exists using fsspec."""
|
| 24 |
+
fs, _ = fsspec.core.url_to_fs(filename)
|
| 25 |
+
return fs.exists(filename)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def fsspec_listdir(dirname):
|
| 29 |
+
"""Listdir in manner compatible with fsspec."""
|
| 30 |
+
fs, _ = fsspec.core.url_to_fs(dirname)
|
| 31 |
+
return fs.ls(dirname)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fsspec_mkdirs(dirname, exist_ok=True):
|
| 35 |
+
"""Mkdirs in manner compatible with fsspec."""
|
| 36 |
+
fs, _ = fsspec.core.url_to_fs(dirname)
|
| 37 |
+
fs.makedirs(dirname, exist_ok=exist_ok)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def print_nans(tensor, name):
|
| 41 |
+
if torch.isnan(tensor).any():
|
| 42 |
+
print(name, tensor)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CosineDecayWarmupLRScheduler(
|
| 46 |
+
CosineLRScheduler,
|
| 47 |
+
torch.optim.lr_scheduler._LRScheduler):
|
| 48 |
+
|
| 49 |
+
def __init__(self, *args, **kwargs):
|
| 50 |
+
super().__init__(*args, **kwargs)
|
| 51 |
+
self._last_epoch = -1
|
| 52 |
+
self.step(epoch=0)
|
| 53 |
+
|
| 54 |
+
def step(self, epoch=None):
|
| 55 |
+
if epoch is None:
|
| 56 |
+
self._last_epoch += 1
|
| 57 |
+
else:
|
| 58 |
+
self._last_epoch = epoch
|
| 59 |
+
# We call either step or step_update, depending on
|
| 60 |
+
# whether we're using the scheduler every epoch or every
|
| 61 |
+
# step.
|
| 62 |
+
# Otherwise, lightning will always call step (i.e.,
|
| 63 |
+
# meant for each epoch), and if we set scheduler
|
| 64 |
+
# interval to "step", then the learning rate update will
|
| 65 |
+
# be wrong.
|
| 66 |
+
if self.t_in_epochs:
|
| 67 |
+
super().step(epoch=self._last_epoch)
|
| 68 |
+
else:
|
| 69 |
+
super().step_update(num_updates=self._last_epoch)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LoggingContext:
|
| 73 |
+
"""Context manager for selective logging."""
|
| 74 |
+
def __init__(self, logger, level=None, handler=None, close=True):
|
| 75 |
+
self.logger = logger
|
| 76 |
+
self.level = level
|
| 77 |
+
self.handler = handler
|
| 78 |
+
self.close = close
|
| 79 |
+
|
| 80 |
+
def __enter__(self):
|
| 81 |
+
if self.level is not None:
|
| 82 |
+
self.old_level = self.logger.level
|
| 83 |
+
self.logger.setLevel(self.level)
|
| 84 |
+
if self.handler:
|
| 85 |
+
self.logger.addHandler(self.handler)
|
| 86 |
+
|
| 87 |
+
def __exit__(self, et, ev, tb):
|
| 88 |
+
if self.level is not None:
|
| 89 |
+
self.logger.setLevel(self.old_level)
|
| 90 |
+
if self.handler:
|
| 91 |
+
self.logger.removeHandler(self.handler)
|
| 92 |
+
if self.handler and self.close:
|
| 93 |
+
self.handler.close()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
|
| 97 |
+
"""Initializes multi-GPU-friendly python logger."""
|
| 98 |
+
|
| 99 |
+
logger = logging.getLogger(name)
|
| 100 |
+
logger.setLevel(level)
|
| 101 |
+
|
| 102 |
+
# this ensures all logging levels get marked with the rank zero decorator
|
| 103 |
+
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
|
| 104 |
+
for level in ('debug', 'info', 'warning', 'error',
|
| 105 |
+
'exception', 'fatal', 'critical'):
|
| 106 |
+
setattr(logger,
|
| 107 |
+
level,
|
| 108 |
+
lightning.pytorch.utilities.rank_zero_only(
|
| 109 |
+
getattr(logger, level)))
|
| 110 |
+
|
| 111 |
+
return logger
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def str2bool(v):
|
| 115 |
+
if isinstance(v, bool):
|
| 116 |
+
return v
|
| 117 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 118 |
+
return True
|
| 119 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 120 |
+
return False
|
| 121 |
+
else:
|
| 122 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def set_seed(seed, use_cuda):
|
| 126 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 127 |
+
np.random.seed(seed)
|
| 128 |
+
random.seed(seed)
|
| 129 |
+
torch.manual_seed(seed)
|
| 130 |
+
# torch.backends.cudnn.deterministic = True
|
| 131 |
+
if use_cuda:
|
| 132 |
+
torch.cuda.manual_seed(seed)
|
| 133 |
+
torch.cuda.manual_seed_all(seed)
|
| 134 |
+
print(f'=> Seed of the run set to {seed}')
|
| 135 |
+
|
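A quick sanity check for the Gumbel-max sampler `sample_categorical_logits` above: drawing many samples from fixed logits should reproduce the softmax probabilities (a sketch; the sample count and logits are illustrative, and the import assumes `tr2d2-pep/utils/` is on PYTHONPATH with the file's own dependencies installed):

```python
# Sketch only: compare empirical sampling frequencies against softmax probabilities.
import torch
from utils import sample_categorical_logits

base = torch.tensor([1.0, 0.0, -1.0])
logits = base.expand(10_000, 3)            # 10k independent draws from the same logits
draws = sample_categorical_logits(logits)  # Gumbel-max argmax per row
freqs = torch.bincount(draws, minlength=3).float() / draws.shape[0]

print(freqs)                               # empirical frequencies
print(torch.softmax(base, dim=-1))         # should be close to the frequencies above
```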